Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2

This commit is contained in:
ManishMadan2882
2026-03-16 01:57:14 +05:30
344 changed files with 48707 additions and 8636 deletions

View File

@@ -1,9 +1,28 @@
API_KEY=<LLM api key (for example, an OpenAI key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete this section if you don't use Azure)
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for a multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
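For example, a minimal remote-embeddings setup pointing at a hypothetical OpenAI-compatible endpoint could look like this (values are illustrative, not defaults):

EMBEDDINGS_BASE_URL=https://embeddings.example.com/v1
EMBEDDINGS_KEY=<your embeddings API key>

Leaving both values empty keeps the default local SentenceTransformer behavior.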

View File

@@ -7,6 +7,9 @@ on:
pull_request:
types: [ opened, synchronize ]
permissions:
contents: read
jobs:
ruff:
runs-on: ubuntu-latest

View File

@@ -1,5 +1,9 @@
name: Run python tests with pytest
on: [push, pull_request]
permissions:
contents: read
jobs:
pytest_and_coverage:
name: Run tests and count coverage

View File

@@ -9,6 +9,10 @@ on:
- '.vale.ini'
- '.github/styles/**'
permissions:
contents: read
pull-requests: write
jobs:
vale:
runs-on: ubuntu-latest

.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
experiments/
experiments
# C extensions
@@ -70,6 +71,7 @@ instance/
# Sphinx documentation
docs/_build/
docs/public/_pagefind/
# PyBuilder
target/
@@ -146,6 +148,10 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**
frontend/node_modules
frontend/dist
frontend/dist-ssr

View File

@@ -1,2 +1,6 @@
# Allow lines to be as long as 120 characters.
line-length = 120
[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]

View File

@@ -26,13 +26,6 @@
</div>
<div align="center">
<br>
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
<br>
<br>
</div>
<div align="center">
<br>
@@ -53,24 +46,11 @@
</ul>
## Roadmap
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] SharePoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [x] Deep Agents (October 2025)
- [x] Prompt Templating (October 2025)
- [x] Full API tooling (Dec 2025)
- [ ] Agent scheduling (Jan 2026)
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues; it helps us improve DocsGPT!
@@ -165,9 +145,17 @@ We as members, contributors, and leaders, pledge to make participation in our co
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
<p>This project is supported by:</p>
## This project is supported by:
<p>
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
</a>
</p>
<p>
<a href="https://get.neon.com/docsgpt">
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
</a>
</p>

View File

@@ -1,11 +0,0 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -7,7 +7,7 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and set up symlink
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
apt-get update && apt-get install -y --no-install-recommends \
python3.12 \
libgl1 \
libglib2.0-0 \
poppler-utils \
&& \
ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*

View File

@@ -1,11 +1,17 @@
import logging
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}
@classmethod

View File

@@ -7,11 +7,16 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -21,8 +26,9 @@ class BaseAgent(ABC):
self,
endpoint: str,
llm_name: str,
gpt_model: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -34,11 +40,13 @@ class BaseAgent(ABC):
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
self.gpt_model = gpt_model
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
@@ -52,17 +60,27 @@ class BaseAgent(ABC):
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
@@ -115,10 +133,10 @@ class BaseAgent(ABC):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
if key not in ("filled_by_llm", "value", "required")
}
params["required"].append(k)
if v.get("required", False):
params["required"].append(k)
return params
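To illustrate the behavior change above (a minimal sketch; the property-dict shape is inferred from this diff and the field names are hypothetical):

# Given tool action properties like:
properties = {
    "user_id": {"type": "string", "required": True, "filled_by_llm": True},
    "note": {"type": "string", "value": "prefilled", "required": False},
}
# The new filtering strips the internal keys ("filled_by_llm", "value", "required")
# from each property, and only appends names whose "required" flag is truthy:
# params["properties"] -> {"user_id": {"type": "string"}, "note": {"type": "string"}}
# params["required"]   -> ["user_id"]   (previously every parameter was marked required)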
def _prepare_tools(self, tools_dict):
@@ -214,7 +232,11 @@ class BaseAgent(ABC):
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
@@ -227,36 +249,86 @@ class BaseAgent(ABC):
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"url": action_config["url"],
"method": action_config["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
if tool_config.get("encrypted_credentials") and self.user:
decrypted = decrypt_credentials(
tool_config["encrypted_credentials"], self.user
)
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
print(
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
@@ -264,7 +336,11 @@ class BaseAgent(ABC):
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
@@ -275,15 +351,176 @@ class BaseAgent(ABC):
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Args:
messages: List of message dicts
Returns:
Total token count
"""
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
"""
Truncate text by removing content from the middle, preserving start and end.
Args:
text: Text to truncate
max_tokens: Maximum tokens allowed
Returns:
Truncated text with middle removed if needed
"""
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
# Estimate chars per token (roughly 4 chars per token for English)
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin
if target_chars <= 0:
return ""
# Split: keep 40% from start, 40% from end, remove middle
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
def _build_messages(
self,
system_prompt: str,
query: str,
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
)
messages = [{"role": "system", "content": system_prompt}]
for i in self.chat_history:
for i in working_history:
if "prompt" in i and "response" in i:
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
@@ -315,8 +552,69 @@ class BaseAgent(ABC):
messages.append({"role": "user", "content": query})
return messages
def _truncate_history_to_fit(
self,
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
return []
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
if "prompt" in message and "response" in message:
message_tokens += num_tokens_from_string(message["prompt"])
message_tokens += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_str = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
message_tokens += num_tokens_from_string(tool_str)
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message) # Maintain chronological order
else:
break
if len(truncated) < len(history):
logger.info(
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
f"to fit within {max_tokens:,} token budget"
)
return truncated
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
gen_kwargs = {"model": self.gpt_model, "messages": messages}
# Pre-flight context validation (logs warnings; never raises)
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")

View File

@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.gpt_model,
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
model=self.model_id, messages=messages, tools=None
)
if log_context:
@@ -235,4 +235,4 @@ class ReActAgent(BaseAgent):
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)
return "".join(collected)

View File

@@ -0,0 +1,323 @@
import base64
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, Union
from urllib.parse import quote, urlencode
logger = logging.getLogger(__name__)
class ContentType(str, Enum):
"""Supported content types for request bodies."""
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
TEXT_PLAIN = "text/plain"
XML = "application/xml"
OCTET_STREAM = "application/octet-stream"
class RequestBodySerializer:
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
@staticmethod
def serialize(
body_data: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Union[str, bytes], Dict[str, str]]:
"""
Serialize body data to appropriate format.
Args:
body_data: Dictionary of body parameters
content_type: Content-Type header value
encoding_rules: OpenAPI Encoding Object rules per field
Returns:
Tuple of (serialized_body, updated_headers_dict)
Raises:
ValueError: If serialization fails
"""
if not body_data:
return None, {}
try:
content_type_lower = content_type.lower().split(";")[0].strip()
if content_type_lower == ContentType.JSON:
return RequestBodySerializer._serialize_json(body_data)
elif content_type_lower == ContentType.FORM_URLENCODED:
return RequestBodySerializer._serialize_form_urlencoded(
body_data, encoding_rules
)
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
return RequestBodySerializer._serialize_multipart_form_data(
body_data, encoding_rules
)
elif content_type_lower == ContentType.TEXT_PLAIN:
return RequestBodySerializer._serialize_text_plain(body_data)
elif content_type_lower == ContentType.XML:
return RequestBodySerializer._serialize_xml(body_data)
elif content_type_lower == ContentType.OCTET_STREAM:
return RequestBodySerializer._serialize_octet_stream(body_data)
else:
logger.warning(
f"Unknown content type: {content_type}, treating as JSON"
)
return RequestBodySerializer._serialize_json(body_data)
except Exception as e:
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
raise ValueError(f"Failed to serialize request body: {str(e)}")
@staticmethod
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as JSON per OpenAPI spec."""
try:
serialized = json.dumps(
body_data, separators=(",", ":"), ensure_ascii=False
)
headers = {"Content-Type": ContentType.JSON.value}
return serialized, headers
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
@staticmethod
def _serialize_form_urlencoded(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
encoding_rules = encoding_rules or {}
params = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
style = rule.get("style", "form")
explode = rule.get("explode", style == "form")
content_type = rule.get("contentType", "text/plain")
serialized_value = RequestBodySerializer._serialize_form_value(
value, style, explode, content_type, key
)
if isinstance(serialized_value, list):
for sv in serialized_value:
params.append((key, sv))
else:
params.append((key, serialized_value))
# Use standard urlencode (replaces space with +)
serialized = urlencode(params, safe="")
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
return serialized, headers
@staticmethod
def _serialize_form_value(
value: Any, style: str, explode: bool, content_type: str, key: str
) -> Union[str, list]:
"""Serialize individual form value with encoding rules."""
if isinstance(value, dict):
if content_type == "application/json":
return json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
return RequestBodySerializer._dict_to_xml(value)
else:
if style == "deepObject" and explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
elif explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
else:
pairs = [f"{k},{v}" for k, v in value.items()]
return RequestBodySerializer._percent_encode(",".join(pairs))
elif isinstance(value, (list, tuple)):
if explode:
return [
RequestBodySerializer._percent_encode(str(item)) for item in value
]
else:
return RequestBodySerializer._percent_encode(
",".join(str(v) for v in value)
)
else:
return RequestBodySerializer._percent_encode(str(value))
@staticmethod
def _serialize_multipart_form_data(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[bytes, Dict[str, str]]:
"""
Serialize body as multipart/form-data per RFC7578.
Supports file uploads and encoding rules.
"""
import secrets
encoding_rules = encoding_rules or {}
boundary = f"----DocsGPT{secrets.token_hex(16)}"
parts = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
content_type = rule.get("contentType", "text/plain")
headers_rule = rule.get("headers", {})
part = RequestBodySerializer._create_multipart_part(
key, value, content_type, headers_rule
)
parts.append(part)
body_bytes = f"--{boundary}\r\n".encode("utf-8")
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
}
return body_bytes, headers
@staticmethod
def _create_multipart_part(
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
) -> str:
"""Create a single multipart/form-data part."""
headers = [
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
]
if isinstance(value, bytes):
if content_type == "application/octet-stream":
value_encoded = base64.b64encode(value).decode("utf-8")
else:
value_encoded = value.decode("utf-8", errors="replace")
headers.append(f"Content-Type: {content_type}")
headers.append("Content-Transfer-Encoding: base64")
elif isinstance(value, dict):
if content_type == "application/json":
value_encoded = json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
value_encoded = RequestBodySerializer._dict_to_xml(value)
else:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
elif isinstance(value, str) and content_type != "text/plain":
try:
if content_type == "application/json":
json.loads(value)
value_encoded = value
elif content_type == "application/xml":
value_encoded = value
else:
value_encoded = str(value)
except json.JSONDecodeError:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
else:
value_encoded = str(value)
if content_type != "text/plain":
headers.append(f"Content-Type: {content_type}")
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
return part
@staticmethod
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as plain text."""
if len(body_data) == 1:
value = list(body_data.values())[0]
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
else:
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
@staticmethod
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as XML."""
xml_str = RequestBodySerializer._dict_to_xml(body_data)
return xml_str, {"Content-Type": ContentType.XML.value}
@staticmethod
def _serialize_octet_stream(
body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
"""Serialize body as binary octet stream."""
if isinstance(body_data, bytes):
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
elif isinstance(body_data, str):
return body_data.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
else:
serialized = json.dumps(body_data)
return serialized.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
@staticmethod
def _percent_encode(value: str, safe_chars: str = "") -> str:
"""
Percent-encode per RFC3986.
Args:
value: String to encode
safe_chars: Additional characters to not encode
"""
return quote(value, safe=safe_chars)
@staticmethod
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
"""
Convert dict to simple XML format.
"""
def build_xml(obj: Any, name: str) -> str:
if isinstance(obj, dict):
inner = "".join(build_xml(v, k) for k, v in obj.items())
return f"<{name}>{inner}</{name}>"
elif isinstance(obj, (list, tuple)):
items = "".join(
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
for item in obj
)
return items
else:
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
root = build_xml(data, root_name)
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
@staticmethod
def _escape_xml(value: str) -> str:
"""Escape XML special characters."""
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
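A quick usage sketch of the serializer with the defaults shown above:

body, headers = RequestBodySerializer.serialize(
    {"q": "docsgpt", "tags": ["a", "b"]},
    ContentType.FORM_URLENCODED,
)
# With the default form style, lists explode into repeated keys:
# body    -> "q=docsgpt&tags=a&tags=b"
# headers -> {"Content-Type": "application/x-www-form-urlencoded"}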

View File

@@ -1,72 +1,280 @@
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 90 # seconds
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.headers = config.get("headers", {})
self.query_params = config.get("query_params", {})
self.body_content_type = config.get("body_content_type", ContentType.JSON)
self.body_encoding_rules = config.get("body_encoding_rules", {})
def execute_action(self, action_name, **kwargs):
"""Execute an API action with the given arguments."""
return self._make_api_call(
self.url, self.method, self.headers, self.query_params, kwargs
self.url,
self.method,
self.headers,
self.query_params,
kwargs,
self.body_content_type,
self.body_encoding_rules,
)
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
def _make_api_call(
self,
url: str,
method: str,
headers: Dict[str, str],
query_params: Dict[str, Any],
body: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
Make an API call with proper body serialization and error handling.
Args:
url: API endpoint URL
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
headers: Request headers dict
query_params: URL query parameters
body: Request body as dict
content_type: Content-Type for serialization
encoding_rules: OpenAPI encoding rules
Returns:
Dict with status_code, data, and message
"""
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
k: v for k, v in query_params.items() if k not in path_params_used
}
if remaining_params:
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
serialized_body, body_headers = RequestBodySerializer.serialize(
body, content_type, encoding_rules
)
request_headers.update(body_headers)
except ValueError as e:
logger.error(f"Body serialization failed: {str(e)}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
"status_code": None,
"message": f"Body serialization error: {str(e)}",
"data": None,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
serialized_body = None
if "Content-Type" not in request_headers and method not in [
"GET",
"HEAD",
"DELETE",
]:
request_headers["Content-Type"] = ContentType.JSON
logger.debug(
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response.raise_for_status()
data = self._parse_response(response)
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {
"status_code": None,
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
"data": None,
}
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {str(e)}")
return {
"status_code": None,
"message": f"Connection error: {str(e)}",
"data": None,
}
except requests.exceptions.HTTPError as e:
logger.error(f"HTTP error {response.status_code}: {str(e)}")
try:
error_data = response.json()
except (json.JSONDecodeError, ValueError):
error_data = response.text
return {
"status_code": response.status_code,
"message": f"HTTP Error {response.status_code}",
"data": error_data,
}
except requests.exceptions.RequestException as e:
logger.error(f"Request failed: {str(e)}")
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
"data": None,
}
except Exception as e:
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
return {
"status_code": None,
"message": f"Unexpected error: {str(e)}",
"data": None,
}
def _parse_response(self, response: requests.Response) -> Any:
"""
Parse response based on Content-Type header.
Supports: JSON, XML, plain text, binary data.
"""
content_type = response.headers.get("Content-Type", "").lower()
if not response.content:
return None
# JSON response
if "application/json" in content_type:
try:
return response.json()
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON response: {str(e)}")
return response.text
# XML response
elif "application/xml" in content_type or "text/xml" in content_type:
return response.text
# Plain text response
elif "text/plain" in content_type or "text/html" in content_type:
return response.text
# Binary/unknown response
else:
# Try to decode as text first, fall back to base64
try:
return response.text
except (UnicodeDecodeError, AttributeError):
import base64
return base64.b64encode(response.content).decode("utf-8")
def get_actions_metadata(self):
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
return []
def get_config_requirements(self):
"""Return configuration requirements for the tool."""
return {}
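A minimal usage sketch (the endpoint and credentials are illustrative):

tool = APITool({
    "url": "https://api.example.com/items/{item_id}",
    "method": "GET",
    "headers": {"Authorization": "Bearer <token>"},
    "query_params": {"item_id": 42, "verbose": "true"},
})
result = tool.execute_action("get_item")
# {item_id} is substituted into the path; the remaining params become ?verbose=true.
# result -> {"status_code": 200, "data": ..., "message": "API call successful."}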

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -41,7 +46,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
logger.debug("Performing Brave web search for: %s", query)
url = f"{self.base_url}/web/search"
@@ -94,7 +99,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
logger.debug("Performing Brave image search for: %s", query)
url = f"{self.base_url}/images/search"
@@ -177,6 +182,10 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,5 +1,14 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
class DuckDuckGoSearchTool(Tool):
@@ -10,71 +19,123 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
try:
results = DDGS().text(
def operation():
client = self._get_ddgs_client()
return client.text(
query,
max_results=max_results,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Web search")
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"parameters": {
"type": "object",
"properties": {
@@ -84,7 +145,15 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
},
},
"required": ["query"],
@@ -92,17 +161,43 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
"description": "Image search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
},
},
"required": ["query"],

View File

@@ -1,20 +1,12 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
@@ -24,10 +16,18 @@ from fastmcp.client.transports import (
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
@@ -56,6 +56,7 @@ class MCPTool(Tool):
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
@@ -76,23 +77,40 @@ class MCPTool(Tool):
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
self.query_mode = config.get("query_mode", False)
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
if explicit:
return explicit.rstrip("/")
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
if connector_base:
parsed = urlparse(connector_base)
if parsed.scheme and parsed.netloc:
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
@@ -109,11 +127,10 @@ class MCPTool(Tool):
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
if time.time() - cached_data["created_at"] < 300:
self._client = cached_data["client"]
return
else:
@@ -123,15 +140,25 @@ class MCPTool(Tool):
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
if self.query_mode:
auth = NonInteractiveOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
db=db,
user_id=self.user_id,
)
else:
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
@@ -169,6 +196,8 @@ class MCPTool(Tool):
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "stdio":
raise ValueError("STDIO transport is disabled")
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
@@ -231,38 +260,53 @@ class MCPTool(Tool):
else:
raise Exception(f"Unknown operation: {operation}")
_ERROR_MAP = [
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
(ConnectionRefusedError, lambda *_: "Connection refused"),
]
_ERROR_PATTERNS = {
("403", "Forbidden"): "Access denied (403 Forbidden)",
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
("ECONNREFUSED",): "Connection refused",
("SSL", "certificate"): "SSL/TLS error",
}
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
future = executor.submit(
self._run_in_new_loop, operation, *args, **kwargs
)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
return self._run_in_new_loop(operation, *args, **kwargs)
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
raise self._map_error(operation, e) from e
def _run_in_new_loop(self, operation, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
def _map_error(self, operation: str, exc: Exception) -> Exception:
for exc_type, msg_fn in self._ERROR_MAP:
if isinstance(exc, exc_type):
return Exception(msg_fn(operation, self.timeout, exc))
error_msg = str(exc)
for patterns, friendly in self._ERROR_PATTERNS.items():
if any(p.lower() in error_msg.lower() for p in patterns):
return Exception(friendly)
logger.error("MCP %s failed: %s", operation, exc)
return exc
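For example, given an MCPTool instance tool, the new mapping turns raw transport errors into friendlier messages:

err = tool._map_error("call_tool", Exception("HTTP 401 Unauthorized"))
print(err)  # Authentication failed (401 Unauthorized)
err = tool._map_error("call_tool", ConnectionRefusedError())
print(err)  # Connection refused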
def discover_tools(self) -> List[Dict]:
"""
@@ -283,16 +327,6 @@ class MCPTool(Tool):
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
@@ -308,7 +342,37 @@ class MCPTool(Tool):
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
error_msg = str(e)
lower_msg = error_msg.lower()
is_auth_error = (
"401" in error_msg
or "unauthorized" in lower_msg
or "session expired" in lower_msg
or "re-authorize" in lower_msg
)
if is_auth_error:
if self.auth_type == "oauth":
raise Exception(
f"Action '{action_name}' failed: OAuth session expired. "
"Please re-authorize this MCP server in tool settings."
) from e
global _mcp_clients_cache
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as retry_e:
raise Exception(
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
"Your credentials may have expired — please re-authorize in tool settings."
) from retry_e
raise Exception(
f"Failed to execute action '{action_name}': {error_msg}"
) from e
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
@@ -331,23 +395,35 @@ class MCPTool(Tool):
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"message": "No server URL configured",
"tools_count": 0,
}
try:
parsed = urlparse(self.server_url)
if parsed.scheme not in ("http", "https"):
return {
"success": False,
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
"tools_count": 0,
}
except Exception:
return {
"success": False,
"message": "Invalid URL format",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
self._setup_client()
except Exception as e:
return {
"success": False,
"message": f"Client init failed: {str(e)}",
"tools_count": 0,
}
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
@@ -358,56 +434,94 @@ class MCPTool(Tool):
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
ping_ok = False
ping_error = None
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
ping_ok = True
except Exception as e:
ping_error = str(e)
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
tools = self.discover_tools()
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"message": f"Connection failed: {ping_error or str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
if not tools and not ping_ok:
return {
"success": False,
"message": f"Connection failed: {ping_error or 'No tools found'}",
"tools_count": 0,
}
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": tool.get("name", "unknown"),
"description": tool.get("description", ""),
}
for tool in tools
],
}
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
server_url=self.server_url, user_id=self.user_id, db_client=db
)
loop = asyncio.new_event_loop()
try:
tokens = loop.run_until_complete(storage.get_tokens())
finally:
loop.close()
if tokens and tokens.access_token:
self.query_mode = True
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
tools = self.discover_tools()
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in tools
],
}
except Exception as e:
logger.warning("OAuth token validation failed: %s", e)
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
return self._start_oauth_task()
def _start_oauth_task(self) -> Dict:
task_config = self.config.copy()
task_config.pop("query_mode", None)
result = mcp_oauth_task.delay(task_config, self.user_id)
return {
"success": False,
"requires_oauth": True,
"task_id": result.id,
"message": "OAuth authorization required.",
"tools_count": 0,
}
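# A caller that receives requires_oauth can poll the Redis status key written
# by the OAuth provider below. Minimal sketch, assuming access to the same
# Redis instance the worker writes to (key name taken from redirect_handler):
#
# import json, time, redis
#
# r = redis.Redis()
#
# def wait_for_auth_url(task_id: str, timeout: float = 60.0):
#     deadline = time.time() + timeout
#     while time.time() < deadline:
#         raw = r.get(f"mcp_oauth_status:{task_id}")
#         if raw:
#             status = json.loads(raw)
#             if status.get("status") == "requires_redirect":
#                 return status.get("authorization_url")
#         time.sleep(1)
#     return None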
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -453,107 +567,88 @@ class MCPTool(Tool):
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"label": "Server URL",
"description": "URL of the remote MCP server",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
"secret": False,
"order": 1,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"label": "Authentication Type",
"description": "Authentication method for the MCP server",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
"secret": False,
"order": 2,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"api_key": {
"type": "string",
"label": "API Key",
"description": "API key for authentication",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
"secret": True,
"order": 3,
"depends_on": {"auth_type": "api_key"},
},
"api_key_header": {
"type": "string",
"label": "API Key Header",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
"required": False,
"secret": False,
"order": 4,
"depends_on": {"auth_type": "api_key"},
},
"bearer_token": {
"type": "string",
"label": "Bearer Token",
"description": "Bearer token for authentication",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "bearer"},
},
"username": {
"type": "string",
"label": "Username",
"description": "Username for basic authentication",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "basic"},
},
"password": {
"type": "string",
"label": "Password",
"description": "Password for basic authentication",
"required": False,
"secret": True,
"order": 4,
"depends_on": {"auth_type": "basic"},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"label": "OAuth Scopes",
"description": "Comma-separated OAuth scopes to request",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "oauth"},
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"type": "number",
"label": "Timeout (seconds)",
"description": "Request timeout in seconds (1-300)",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
"secret": False,
"order": 10,
},
}
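# A minimal config sketch satisfying these requirements for a bearer-token
# server (placeholder values):
#
# mcp_config = {
#     "server_url": "https://api.example.com/mcp",  # required
#     "auth_type": "bearer",
#     "bearer_token": "<token>",  # only used when auth_type == "bearer"
#     "timeout": 30,  # optional, 1-300 seconds
# }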
@@ -575,23 +670,8 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
@@ -614,7 +694,10 @@ class DocsGPTOAuth(OAuthClientProvider):
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
server_url=self.server_base_url,
user_id=self.user_id,
db_client=self.db,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
)
super().__init__(
@@ -646,22 +729,20 @@ class DocsGPTOAuth(OAuthClientProvider):
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
logger.info("Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -681,7 +762,7 @@ class DocsGPTOAuth(OAuthClientProvider):
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -706,7 +787,7 @@ class DocsGPTOAuth(OAuthClientProvider):
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -726,14 +807,44 @@ class DocsGPTOAuth(OAuthClientProvider):
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
raise Exception("OAuth timeout: no code received within 5 minutes")
class NonInteractiveOAuth(DocsGPTOAuth):
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
Used during query execution to prevent the streaming response from blocking
while waiting for user authorization that will never come.
"""
def __init__(self, **kwargs):
kwargs.setdefault("task_id", None)
kwargs["skip_redirect_validation"] = True
super().__init__(**kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
async def callback_handler(self) -> tuple[str, str | None]:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
def __init__(
self,
server_url: str,
user_id: str,
db_client,
expected_redirect_uri: Optional[str] = None,
):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.expected_redirect_uri = expected_redirect_uri
self.collection = db_client["connector_sessions"]
@staticmethod
@@ -752,10 +863,9 @@ class DBTokenStorage(TokenStorage):
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
return OAuthToken.model_validate(doc["tokens"])
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
logger.error("Could not load tokens: %s", e)
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
@@ -765,28 +875,38 @@ class DBTokenStorage(TokenStorage):
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
logger.debug(
"No client_info in DB for %s", self.get_base_url(self.server_url)
)
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
]
expected_uri = self.expected_redirect_uri.rstrip("/")
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
self.get_base_url(self.server_url),
expected_uri,
stored_uris,
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": "", "tokens": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
logger.error("Could not load client info: %s", e)
return None
def _serialize_client_info(self, info: dict) -> dict:
@@ -802,17 +922,17 @@ class DBTokenStorage(TokenStorage):
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
logger.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
@@ -851,7 +971,7 @@ class MCPOAuthManager:
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:

View File

@@ -38,6 +38,8 @@ class NotesTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class NotesTool(Tool):
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "view":
return self._get_note()
@@ -125,6 +129,9 @@ class NotesTool(Tool):
"""Return configuration requirements (none for now)."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
@@ -132,17 +139,22 @@ class NotesTool(Tool):
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
upsert=True,
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
@@ -163,10 +175,13 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
@@ -188,12 +203,21 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
return "No note found to delete."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return "Note deleted."

View File

@@ -116,12 +116,13 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {"type": "string", "description": "Access token for authentication"},
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,6 +1,12 @@
import logging
import psycopg2
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -17,17 +23,15 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
@@ -35,7 +39,9 @@ class PostgresTool(Tool):
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = [desc[0] for desc in cur.description] if cur.description else []
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
results = []
rows = cur.fetchall()
for row in rows:
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
cur.close()
return {
@@ -54,26 +62,27 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL execute_sql error: %s", e)
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute("""
cur.execute(
"""
SELECT
table_name,
column_name,
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
""")
"""
)
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
cur.close()
return {
@@ -110,14 +122,14 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL get_schema error: %s", e)
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
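# schema_data maps each table name to its column descriptors, e.g. for a
# hypothetical "users" table:
#
# {
#     "users": [
#         {
#             "column_name": "id",
#             "data_type": "integer",
#             "column_default": "nextval('users_id_seq'::regclass)",
#             "is_nullable": "NO",
#         },
#     ],
# }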
def get_actions_metadata(self):
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
},
}
}

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
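# Intent of the swap above: validate_url vets the target before any request is
# made and raises SSRFError for unsafe destinations. Illustrative call (the
# blocked address is an assumption about typical SSRF filters, not a rule
# taken from application.core.url_validation):
#
# try:
#     url = validate_url("169.254.169.254/latest/meta-data")
# except SSRFError as e:
#     print(f"Error: URL validation failed - {e}")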

View File

@@ -0,0 +1,342 @@
"""
API Specification Parser
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
to API Tool action definitions for use in DocsGPT.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import yaml
logger = logging.getLogger(__name__)
SUPPORTED_METHODS = frozenset(
{"get", "post", "put", "delete", "patch", "head", "options"}
)
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Parse an API specification and convert operations to action definitions.
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
Args:
spec_content: Raw specification content as string
Returns:
Tuple of (metadata dict, list of action dicts)
Raises:
ValueError: If the spec is invalid or uses an unsupported format
"""
spec = _load_spec(spec_content)
_validate_spec(spec)
is_swagger = "swagger" in spec
metadata = _extract_metadata(spec, is_swagger)
actions = _extract_actions(spec, is_swagger)
return metadata, actions
def _load_spec(content: str) -> Dict[str, Any]:
"""Parse spec content from JSON or YAML string."""
content = content.strip()
if not content:
raise ValueError("Empty specification content")
try:
if content.startswith("{"):
return json.loads(content)
return yaml.safe_load(content)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format: {e.msg}") from e
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML format: {e}") from e
def _validate_spec(spec: Dict[str, Any]) -> None:
"""Validate spec version and required fields."""
if not isinstance(spec, dict):
raise ValueError("Specification must be a valid object")
openapi_version = spec.get("openapi", "")
swagger_version = spec.get("swagger", "")
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
raise ValueError(
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
)
if "paths" not in spec or not spec["paths"]:
raise ValueError("No API paths defined in the specification")
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
"""Extract API metadata from specification."""
info = spec.get("info", {})
base_url = _get_base_url(spec, is_swagger)
return {
"title": info.get("title", "Untitled API"),
"description": (info.get("description", "") or "")[:500],
"version": info.get("version", ""),
"base_url": base_url,
}
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
if is_swagger:
schemes = spec.get("schemes", ["https"])
host = spec.get("host", "")
base_path = spec.get("basePath", "")
if host:
scheme = schemes[0] if schemes else "https"
return f"{scheme}://{host}{base_path}".rstrip("/")
return ""
servers = spec.get("servers", [])
if servers and isinstance(servers, list) and servers[0].get("url"):
return servers[0]["url"].rstrip("/")
return ""
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
"""Extract all API operations as action definitions."""
actions = []
paths = spec.get("paths", {})
base_url = _get_base_url(spec, is_swagger)
components = spec.get("components", {})
definitions = spec.get("definitions", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
path_params = path_item.get("parameters", [])
for method in SUPPORTED_METHODS:
operation = path_item.get(method)
if not isinstance(operation, dict):
continue
try:
action = _build_action(
path=path,
method=method,
operation=operation,
path_params=path_params,
base_url=base_url,
components=components,
definitions=definitions,
is_swagger=is_swagger,
)
actions.append(action)
except Exception as e:
logger.warning(
f"Failed to parse operation {method.upper()} {path}: {e}"
)
continue
return actions
def _build_action(
path: str,
method: str,
operation: Dict[str, Any],
path_params: List[Dict],
base_url: str,
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Dict[str, Any]:
"""Build a single action from an API operation."""
action_name = _generate_action_name(operation, method, path)
full_url = f"{base_url}{path}" if base_url else path
all_params = path_params + operation.get("parameters", [])
query_params, headers = _categorize_parameters(all_params, components, definitions)
body, body_content_type = _extract_request_body(
operation, components, definitions, is_swagger
)
description = operation.get("summary", "") or operation.get("description", "")
return {
"name": action_name,
"url": full_url,
"method": method.upper(),
"description": (description or "")[:500],
"query_params": {"type": "object", "properties": query_params},
"headers": {"type": "object", "properties": headers},
"body": {"type": "object", "properties": body},
"body_content_type": body_content_type,
"active": True,
}
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
"""Generate a valid action name from operationId or method+path."""
if operation.get("operationId"):
name = operation["operationId"]
else:
path_slug = re.sub(r"[{}]", "", path)
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
name = f"{method}_{path_slug}"
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
return name[:64]
def _categorize_parameters(
parameters: List[Dict],
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
"""Categorize parameters into query params and headers."""
query_params = {}
headers = {}
for param in parameters:
resolved = _resolve_ref(param, components, definitions)
if not resolved or "name" not in resolved:
continue
location = resolved.get("in", "query")
prop = _param_to_property(resolved)
if location in ("query", "path"):
query_params[resolved["name"]] = prop
elif location == "header":
headers[resolved["name"]] = prop
return query_params, headers
def _param_to_property(param: Dict) -> Dict[str, Any]:
"""Convert an API parameter to an action property definition."""
schema = param.get("schema", {})
param_type = schema.get("type", param.get("type", "string"))
mapped_type = "integer" if param_type in ("integer", "number") else "string"
return {
"type": mapped_type,
"description": (param.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": param.get("required", False),
"required": param.get("required", False),
}
def _extract_request_body(
operation: Dict[str, Any],
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Tuple[Dict, str]:
"""Extract request body schema and content type."""
content_types = [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
"application/xml",
]
if is_swagger:
consumes = operation.get("consumes", [])
body_param = next(
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
)
if not body_param:
return {}, "application/json"
selected_type = consumes[0] if consumes else "application/json"
schema = body_param.get("schema", {})
else:
request_body = operation.get("requestBody", {})
if not request_body:
return {}, "application/json"
request_body = _resolve_ref(request_body, components, definitions)
content = request_body.get("content", {})
selected_type = "application/json"
schema = {}
for ct in content_types:
if ct in content:
selected_type = ct
schema = content[ct].get("schema", {})
break
if not schema and content:
first_type = next(iter(content))
selected_type = first_type
schema = content[first_type].get("schema", {})
properties = _schema_to_properties(schema, components, definitions)
return properties, selected_type
def _schema_to_properties(
schema: Dict,
components: Dict[str, Any],
definitions: Dict[str, Any],
depth: int = 0,
) -> Dict[str, Any]:
"""Convert schema to action body properties (limited depth to prevent recursion)."""
if depth > 3:
return {}
schema = _resolve_ref(schema, components, definitions)
if not schema or not isinstance(schema, dict):
return {}
properties = {}
schema_type = schema.get("type", "object")
if schema_type == "object":
required_fields = set(schema.get("required", []))
for prop_name, prop_schema in schema.get("properties", {}).items():
resolved = _resolve_ref(prop_schema, components, definitions)
if not isinstance(resolved, dict):
continue
prop_type = resolved.get("type", "string")
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
properties[prop_name] = {
"type": mapped_type,
"description": (resolved.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": prop_name in required_fields,
"required": prop_name in required_fields,
}
return properties
def _resolve_ref(
obj: Any,
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Optional[Dict]:
"""Resolve $ref references in the specification."""
if not isinstance(obj, dict):
        return None
if "$ref" not in obj:
return obj
ref_path = obj["$ref"]
if ref_path.startswith("#/components/"):
parts = ref_path.replace("#/components/", "").split("/")
return _traverse_path(components, parts)
elif ref_path.startswith("#/definitions/"):
parts = ref_path.replace("#/definitions/", "").split("/")
return _traverse_path(definitions, parts)
logger.debug(f"Unsupported ref path: {ref_path}")
return None
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
"""Traverse a nested dictionary using path parts."""
try:
for part in parts:
obj = obj[part]
return obj if isinstance(obj, dict) else None
except (KeyError, TypeError):
return None
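# End-to-end, the parser turns a spec into (metadata, actions). A minimal
# runnable sketch with a toy OpenAPI 3.x document:
if __name__ == "__main__":
    spec_yaml = """
openapi: 3.0.0
info:
  title: Petstore
  version: "1.0"
servers:
  - url: https://api.example.com/v1
paths:
  /pets:
    get:
      operationId: listPets
      summary: List all pets
      parameters:
        - name: limit
          in: query
          schema:
            type: integer
"""
    metadata, actions = parse_spec(spec_yaml)
    assert metadata["base_url"] == "https://api.example.com/v1"
    assert actions[0]["name"] == "listPets"
    assert actions[0]["method"] == "GET"
    assert "limit" in actions[0]["query_params"]["properties"]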

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -18,21 +23,19 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -38,6 +38,8 @@ class TodoListTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "list":
return self._list()
@@ -165,6 +169,9 @@ class TodoListTool(Tool):
"""Return configuration requirements."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers
# -----------------------------
@@ -190,11 +197,8 @@ class TodoListTool(Tool):
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find the maximum todo_id
max_id = 0
@@ -207,8 +211,8 @@ class TodoListTool(Tool):
def _list(self) -> str:
"""List all todos for the user."""
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
if not todos:
return "No todos found."
@@ -242,7 +246,10 @@ class TodoListTool(Tool):
"created_at": now,
"updated_at": now,
}
self.collection.insert_one(doc)
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -251,15 +258,15 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
@@ -277,14 +284,17 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
@@ -293,14 +303,17 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
@@ -309,13 +322,12 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -0,0 +1,231 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
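# _serialize_state_value reduces arbitrary state to JSON-safe primitives.
# For example (hypothetical state):
#
# {"ts": datetime(2025, 1, 1), "pair": (1, 2), "n": 3}
# serializes to
# {"ts": "2025-01-01T00:00:00", "pair": [1, 2], "n": 3}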

View File

@@ -0,0 +1,64 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value
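# Usage sketch, relying on the stock CEL built-ins that celpy ships (e.g. size):
if __name__ == "__main__":
    state = {"score": 7, "tags": ["a", "b"]}
    print(evaluate_cel("score > 5 && size(tags) == 2", state))   # True
    print(evaluate_cel('tags[0] == "a" ? "yes" : "no"', state))  # "yes"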

View File

@@ -0,0 +1,109 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)
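# Sketch: creating a node-scoped agent that only sees an allow-listed set of
# tools (endpoint, provider, and model values below are placeholders):
#
# agent = WorkflowNodeAgentFactory.create(
#     agent_type=AgentType.REACT,
#     endpoint="https://llm.example.com/v1",
#     llm_name="openai",
#     model_id="gpt-4o",
#     api_key="sk-placeholder",
#     tool_ids=["665f0c0000000000000000a1"],  # only these tool _ids survive filtering
# )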

View File

@@ -0,0 +1,235 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}
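# A minimal graph built from these models (ids are arbitrary):
if __name__ == "__main__":
    graph = WorkflowGraph(
        workflow=Workflow(name="Demo"),
        nodes=[
            WorkflowNode(id="n1", workflow_id="wf1", type=NodeType.START, title="Start"),
            WorkflowNode(id="n2", workflow_id="wf1", type=NodeType.END, title="End"),
        ],
        edges=[WorkflowEdge(id="e1", workflow_id="wf1", source="n1", target="n2")],
    )
    assert graph.get_start_node().id == "n1"
    assert graph.get_outgoing_edges("n1")[0].target_id == "n2"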

View File

@@ -0,0 +1,455 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": user_friendly_error,
}
yield {"type": "error", "error": user_friendly_error}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()  # no events to emit; keeps this handler a generator
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()  # no events to emit; keeps this handler a generator
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]
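For orientation, a minimal driver sketch (not part of the diff): it assumes a WorkflowGraph and BaseAgent have been built elsewhere, and the variable names and query text are illustrative placeholders.

# Hypothetical call site; `my_graph` and `my_agent` are assumed to exist.
engine = WorkflowEngine(graph=my_graph, agent=my_agent)
for event in engine.execute(initial_inputs={}, query="Summarize the onboarding docs"):
    if event.get("type") == "workflow_step":
        print(f"[{event['node_title']}] {event['status']}")
    elif event.get("type") == "error":
        print("Workflow failed:", event["error"])
    elif "answer" in event:
        print(event["answer"], end="")
for log in engine.get_execution_summary():
    print(log.node_id, log.status.value)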

View File

@@ -3,6 +3,7 @@ from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource
@@ -14,6 +15,7 @@ api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
api.add_resource(SearchResource, "/api/search")
init_answer_routes()

View File

@@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,10 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
@@ -133,5 +141,5 @@ class AnswerResource(Resource, BaseAnswerResource):
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response({"error": "An error occurred processing your request"}, 500)
return make_response(result, 200)

View File

@@ -7,11 +7,17 @@ from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -27,7 +33,7 @@ class BaseAnswerResource:
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.gpt_model = get_gpt_model()
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
@@ -41,6 +47,27 @@ class BaseAnswerResource:
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
@@ -54,7 +81,6 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +88,6 @@ class BaseAnswerResource:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +135,12 @@ class BaseAnswerResource:
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
@@ -138,7 +160,6 @@ class BaseAnswerResource:
),
429,
)
return None
def complete_stream(
@@ -155,6 +176,7 @@ class BaseAnswerResource:
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -173,6 +195,7 @@ class BaseAnswerResource:
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
@@ -218,9 +241,15 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
if line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -230,15 +259,23 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
if should_save_conversation:
@@ -250,7 +287,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -259,20 +296,46 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
@@ -280,12 +343,11 @@ class BaseAnswerResource:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# Truncate long text fields to at most 10,000 characters before logging
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
data = json.dumps({"type": "end"})
@@ -293,6 +355,7 @@ class BaseAnswerResource:
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -303,6 +366,7 @@ class BaseAnswerResource:
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.conversation_service.save_conversation(
conversation_id,
@@ -312,7 +376,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -321,6 +385,25 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
@@ -369,7 +452,7 @@ class BaseAnswerResource:
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"]
return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +460,7 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
@@ -390,7 +472,6 @@ class BaseAnswerResource:
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):

View File

@@ -0,0 +1,186 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
search_model = answer_ns.model(
"SearchModel",
{
"question": fields.String(
required=True, description="Search query"
),
"api_key": fields.String(
required=True, description="API key for authentication"
),
"chunks": fields.Integer(
required=False, default=5, description="Number of results to return"
),
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if isinstance(sources, list) and sources:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
question = data.get("question")
api_key = data.get("api_key")
chunks = data.get("chunks", 5)
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response({"error": "Search failed"}, 500)

View File

@@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
@@ -77,6 +81,12 @@ class StreamResource(Resource, BaseAnswerResource):
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
@@ -98,9 +108,10 @@ class StreamResource(Resource, BaseAnswerResource):
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
),
mimetype="text/event-stream",
)

View File

@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -0,0 +1,234 @@
"""Message reconstruction utilities for compression."""
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages
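An illustrative sketch of the pre-request path; the prompt and query data are placeholders shaped like the dicts this module consumes.

# Hypothetical inputs; real queries come from the conversation document.
msgs = MessageBuilder.build_from_compressed_context(
    system_prompt="You are a helpful assistant.",
    compressed_summary="User set up Docker; agent explained compose overrides.",
    recent_queries=[
        {"prompt": "And for Kubernetes?", "response": "Use the Helm chart."}
    ],
)
# msgs -> [system (summary appended after "---"), user, assistant]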

View File

@@ -0,0 +1,233 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
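A hedged call-site sketch for the facade above; the IDs, model name, and token contents are placeholders.

# Hypothetical usage; none of these identifiers come from the diff.
orchestrator = CompressionOrchestrator(ConversationService())
result = orchestrator.compress_if_needed(
    conversation_id="<conversation-id>",
    user_id="<user-id>",
    model_id="<model-id>",
    decoded_token={"sub": "<user-id>"},
)
if result.success:
    summary = result.compressed_summary  # None when no compression ran
    history = result.as_history()  # recent prompt/response pairs
else:
    print("Compression failed:", result.error)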

View File

@@ -0,0 +1,149 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -0,0 +1,306 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Compresses conversation history and, when a conversation_service is
supplied, persists compression metadata to the database.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)
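An illustrative read-back sketch; `my_llm` and `conversation_doc` are placeholders, and the default prompt builder still needs its template file on disk.

# Hypothetical usage of the service defined above.
svc = CompressionService(llm=my_llm, model_id="<model-id>")
summary, recent = svc.get_compressed_context(conversation_doc)
if summary is not None:
    # Everything up to the last compression point is carried by the summary;
    # `recent` holds only the queries recorded after it.
    print(f"{len(recent)} uncompressed queries after the summary")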

View File

@@ -0,0 +1,103 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict, Optional
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: Optional[float] = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False
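The threshold arithmetic in a worked sketch; the 128k window and 0.8 ratio are illustrative numbers, not project defaults.

# Worked example of the check above (values are illustrative).
context_limit = 128_000
checker = CompressionThresholdChecker(threshold_percentage=0.8)
threshold = int(context_limit * checker.threshold_percentage)  # 102_400
# should_compress() triggers once conversation tokens plus the current
# query estimate reach this threshold for the target model.
print(threshold)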

View File

@@ -0,0 +1,103 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0
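A short sketch of how these static helpers compose. The message and query shapes are inferred from the fields read above; actual counts depend on the tokenizer behind num_tokens_from_string.
messages = [
    {"content": "What is DocsGPT?"},
    {"content": [{"type": "text", "text": "structured part"}]},  # structured content path
]
msg_tokens = TokenCounter.count_message_tokens(messages)

queries = [
    {
        "prompt": "Summarize the repo",
        "response": "DocsGPT is ...",
        "tool_calls": [
            {"tool_name": "search", "action_name": "query",
             "arguments": {"q": "docs"}, "result": "3 hits"},
        ],
    }
]
query_tokens = TokenCounter.count_query_tokens(queries)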

View File

@@ -0,0 +1,83 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]
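A hedged sketch of the three factory paths and the history conversion; field values are illustrative, and meta stands for a CompressionMetadata built elsewhere.
recent = [{"prompt": "latest question", "response": "latest answer"}]

ok = CompressionResult.success_with_compression(
    summary="Earlier turns: the user configured compression ...",
    queries=recent,
    metadata=meta,
)
skipped = CompressionResult.success_no_compression(recent)
failed = CompressionResult.failure("token counting failed")

history = ok.as_history()  # [{"prompt": "latest question", "response": "latest answer"}]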

View File

@@ -52,7 +52,7 @@ class ConversationService:
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
gpt_model: str,
model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
@@ -62,11 +62,13 @@ class ConversationService:
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# Clean up the sources array so that at most 1k characters of each text part are saved
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -90,6 +92,7 @@ class ConversationService:
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
}
},
)
@@ -120,6 +123,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
}
},
@@ -146,9 +150,12 @@ class ConversationService:
]
completion = llm.gen(
model=gpt_model, messages=messages_summary, max_tokens=30
model=model_id, messages=messages_summary, max_tokens=500
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
conversation_data = {
"user": user_id,
"date": current_time,
@@ -162,6 +169,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}
@@ -177,3 +185,103 @@ class ConversationService:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None
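The $push / $each / $slice combination in update_compression_metadata is MongoDB's standard capped-array idiom: it appends the new point and trims the array to the last N entries in a single atomic update. A standalone sketch of the same idiom (collection name and cap value illustrative):
collection.update_one(
    {"_id": conversation_id},
    {
        "$push": {
            "compression_metadata.compression_points": {
                "$each": [new_point],
                "$slice": -5,  # keep only the 5 most recent points
            }
        }
    },
)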

View File

@@ -10,14 +10,21 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
get_gpt_model,
limit_chat_history,
)
@@ -83,18 +90,24 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.agent_id = self.data.get("agent_id")
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
)
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._validate_and_set_model()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
@@ -106,14 +119,69 @@ class StreamProcessor:
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
# Check if compression is enabled and needed
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
[{"content": result.compressed_summary}]
)
logger.info(
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
)
def _process_attachments(self):
"""Process any attachments in the request"""
@@ -143,6 +211,34 @@ class StreamProcessor:
)
return attachments
def _validate_and_set_model(self):
"""Validate and set model_id from request"""
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
if requested_model:
if not validate_model_id(requested_model):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (
f" and {len(available_models) - 5} more"
if len(available_models) > 5
else ""
)
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
else:
self.model_id = get_default_model_id()
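The effective precedence is: explicit request model_id (must validate, else error) > agent default_model_id (if it validates) > system default. A compact restatement under that reading, using the helpers imported above:
def resolve_model_id(requested, agent_default):
    # Mirrors _validate_and_set_model, minus the detailed error message.
    if requested:
        if not validate_model_id(requested):
            raise ValueError(f"Invalid model_id '{requested}'")
        return requested
    if agent_default and validate_model_id(agent_default):
        return agent_default
    return get_default_model_id()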
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
@@ -214,6 +310,10 @@ class StreamProcessor:
data["sources"] = sources_list
else:
data["sources"] = []
# Preserve model configuration from agent
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
@@ -250,28 +350,61 @@ class StreamProcessor:
self.source = {}
self.all_sources = []
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
agent_id = self._resolve_agent_id()
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -284,12 +417,15 @@ class StreamProcessor:
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.decoded_token = (
@@ -299,6 +435,9 @@ class StreamProcessor:
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -310,26 +449,32 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
):
agent_type = "workflow"
self.agent_config["workflow"] = self.data["workflow"]
if isinstance(self.decoded_token, dict):
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"agent_type": agent_type,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
}
)
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
gpt_model=self.gpt_model, history_token_limit=history_token_limit
)
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
"history_token_limit": history_token_limit,
}
api_key = self.data.get("api_key") or self.agent_key
@@ -344,14 +489,15 @@ class StreamProcessor:
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
gpt_model=self.gpt_model,
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
)
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False):
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
try:
@@ -626,17 +772,46 @@ class StreamProcessor:
tools_data=tools_data,
)
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
provider = (
get_provider_from_model_id(self.model_id)
if self.model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,
"retrieved_docs": self.retrieved_docs,
"decoded_token": self.decoded_token,
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
}
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config
elif isinstance(workflow_config, dict):
agent_kwargs["workflow"] = workflow_config
workflow_owner = self.agent_config.get("workflow_owner")
if workflow_owner:
agent_kwargs["workflow_owner"] = workflow_owner
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
return agent

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
@@ -35,6 +37,18 @@ connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
# Fixed callback status path to prevent open redirect
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
def build_callback_redirect(params: dict) -> str:
"""Build a safe redirect URL to the callback status page.
Uses a fixed path and properly URL-encodes all parameters
to prevent URL injection and open redirect vulnerabilities.
"""
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
@connectors_ns.route("/api/connectors/auth")
@@ -75,8 +89,8 @@ class ConnectorAuth(Resource):
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
@connectors_ns.route("/api/connectors/callback")
@@ -93,18 +107,37 @@ class ConnectorsCallback(Resource):
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
provider = state_dict.get("provider")
state_object_id = state_dict.get("object_id")
# Validate provider
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
return redirect(build_callback_redirect({
"status": "error",
"message": "Invalid provider"
}))
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
return redirect(build_callback_redirect({
"status": "cancelled",
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
"provider": provider
}))
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
try:
auth = ConnectorCreator.create_auth(provider)
@@ -113,20 +146,19 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sanitized_token_info = auth.sanitize_token_info(token_info)
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
@@ -141,26 +173,39 @@ class ConnectorsCallback(Resource):
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
return redirect(build_callback_redirect({
"status": "success",
"message": "Authentication successful",
"provider": provider,
"session_token": session_token,
"user_email": user_email
}))
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
}))
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
"search_query": fields.String(required=False),
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -168,11 +213,8 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -185,15 +227,12 @@ class ConnectorFiles(Resource):
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
k: v for k, v in data.items() if k not in generic_keys
}
if search_query:
input_config['search_query'] = search_query
input_config['list_only'] = True
documents = loader.load_data(input_config)
@@ -228,8 +267,8 @@ class ConnectorFiles(Resource):
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
@@ -260,12 +299,7 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
@@ -282,15 +316,21 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@connectors_ns.route("/api/connectors/disconnect")
@@ -311,8 +351,8 @@ class ConnectorDisconnect(Resource):
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
@connectors_ns.route("/api/connectors/sync")
@@ -418,8 +458,8 @@ class ConnectorSync(Resource):
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
"error": "Failed to sync connector source"
}),
400
)
@@ -430,17 +470,32 @@ class ConnectorCallbackStatus(Resource):
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
# Validate and sanitize status to a known value
status_raw = request.args.get('status', 'error')
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
# Escape all user-controlled values for HTML context
message = html.escape(request.args.get('message', ''))
provider_raw = request.args.get('provider', 'connector')
provider = html.escape(provider_raw.replace('_', ' ').title())
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
user_email = html.escape(request.args.get('user_email', ''))
def safe_js_string(value: str) -> str:
"""Safely encode a string for embedding in inline JavaScript."""
js_encoded = json.dumps(value)
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
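# Worked example (input illustrative): safe_js_string('</script><svg>')
# json.dumps gives '"</script><svg>"'; the first replace yields '"<\\/script><svg>"',
# so user input can no longer close the inline <script> block early.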
js_status = safe_js_string(status)
js_session_token = safe_js_string(session_token)
js_user_email = safe_js_string(user_email)
js_provider_type = safe_js_string(provider_raw)
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<title>{provider} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
@@ -450,13 +505,14 @@ class ConnectorCallbackStatus(Resource):
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
const status = {js_status};
const sessionToken = {js_session_token};
const userEmail = {js_user_email};
const providerType = {js_provider_type};
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
type: providerType + '_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
@@ -470,17 +526,17 @@ class ConnectorCallbackStatus(Resource):
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<h2>{provider} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")

View File

@@ -1,7 +1,7 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory
from flask import Blueprint, request, send_from_directory, jsonify
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
@@ -24,6 +24,16 @@ current_dir = os.path.dirname(
internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])
def download_file():
user = secure_filename(request.args.get("user"))
@@ -51,6 +61,7 @@ def upload_index_files():
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
@@ -60,6 +71,14 @@ def upload_index_files():
directory_structure = {}
else:
directory_structure = {}
if file_name_map:
try:
file_name_map = json.loads(file_name_map)
except Exception:
logger.error("Error parsing file_name_map")
file_name_map = None
else:
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -87,41 +106,43 @@ def upload_index_files():
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
{"$set": update_fields},
)
else:
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
return {"status": "ok"}

View File

@@ -3,5 +3,6 @@
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]

View File

@@ -0,0 +1,261 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Create a new folder")
@api.expect(
api.model(
"CreateFolder",
{
"name": fields.String(required=True, description="Folder name"),
"parent_id": fields.String(required=False, description="Parent folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
try:
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
@api.doc(description="Get a specific folder with its agents")
def get(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
return make_response(
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
200,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Update a folder")
def put(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
@api.doc(description="Move an agent to a folder or remove from folder")
@api.expect(
api.model(
"MoveAgent",
{
"agent_id": fields.String(required=True, description="Agent ID to move"),
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
@api.doc(description="Move multiple agents to a folder")
@api.expect(
api.model(
"BulkMoveAgents",
{
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
"folder_id": fields.String(required=False, description="Target folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_ids"):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id = data.get("folder_id")
try:
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
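A hedged client-side sketch of the move endpoint (host/port and auth scheme are assumptions; the path follows from the namespace above):
import requests

resp = requests.post(
    "http://localhost:7091/api/agents/folders/move_agent",  # assumed host/port
    json={"agent_id": "66f0c0ffee0000000000abcd", "folder_id": None},  # None clears the folder
    headers={"Authorization": "Bearer <jwt>"},  # auth mechanism assumed
)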

View File

@@ -11,6 +11,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
db,
ensure_user_doc,
@@ -18,6 +19,13 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.utils import (
@@ -30,6 +38,189 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
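Worked examples of the normalization paths (ids illustrative):
normalize_workflow_reference({"id": "6703ab12cd34ef56ab78cd90"})      # -> "6703ab12cd34ef56ab78cd90"
normalize_workflow_reference('{"_id": "6703ab12cd34ef56ab78cd90"}')   # JSON string payload -> same id
normalize_workflow_reference("  6703ab12cd34ef56ab78cd90  ")          # plain id string, stripped
normalize_workflow_reference(None)                                    # -> None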
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
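A brief sketch of the whitelist filtering: fields outside the agent type's schema are dropped, so a workflow agent never persists retriever or prompt settings (values illustrative):
doc = build_agent_document(
    {"name": "Docs helper", "status": "draft", "prompt_id": "default"},
    user="user-123",
    key="",
    agent_type="workflow",
)
assert doc["agent_type"] == "workflow"
assert "prompt_id" not in doc  # filtered out by the workflow field whitelist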
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -67,7 +258,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -95,6 +286,10 @@ class GetAgent(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -144,7 +339,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -172,9 +367,15 @@ class GetAgents(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent or "retriever" in agent
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -202,16 +403,22 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -230,6 +437,17 @@ class CreateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -258,42 +476,22 @@ class CreateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if "models" in data:
try:
data["models"] = json.loads(data["models"])
except json.JSONDecodeError:
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
# Validate and normalize JSON schema if provided
if "json_schema" in data:
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -306,18 +504,34 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not data.get("source") and not data.get("sources"):
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
return make_response(
jsonify(
{
@@ -327,10 +541,9 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -342,72 +555,50 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify({"success": False, "message": "Invalid folder ID format"}),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -436,16 +627,22 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -464,6 +661,17 @@ class UpdateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -487,7 +695,7 @@ class UpdateAgent(Resource):
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
json_fields = ["tools", "sources", "json_schema", "models"]
for field in json_fields:
if field in data and data[field]:
try:
@@ -502,6 +710,8 @@ class UpdateAgent(Resource):
),
400,
)
if data.get("json_schema") == "":
data["json_schema"] = None
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
@@ -555,6 +765,10 @@ class UpdateAgent(Resource):
"token_limit",
"limited_request_mode",
"request_limit",
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -658,17 +872,15 @@ class UpdateAgent(Resource):
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
try:
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
elif field == "limited_token_mode":
@@ -711,10 +923,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -739,6 +951,42 @@ class UpdateAgent(Resource):
),
400,
)
elif field == "folder_id":
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid folder ID format",
}
),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -767,46 +1015,82 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
            if agent_type == "workflow":
                required_published_fields = {
                    "name": "Agent name",
                }
                missing_published_fields = []
                for req_field, field_label in required_published_fields.items():
                    final_value = update_fields.get(
                        req_field, existing_agent.get(req_field)
                    )
                    if not final_value:
                        missing_published_fields.append(field_label)
                workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
                if not workflow_id:
                    missing_published_fields.append("Workflow")
                elif not ObjectId.is_valid(workflow_id):
                    missing_published_fields.append("Valid workflow")
                else:
                    workflow = workflows_collection.find_one(
                        {"_id": ObjectId(workflow_id), "user": user}
                    )
                    if not workflow:
                        missing_published_fields.append("Workflow access")
                if missing_published_fields:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
                            }
                        ),
                        400,
                    )
            else:
                required_published_fields = {
                    "name": "Agent name",
                    "description": "Agent description",
                    "chunks": "Chunks count",
                    "prompt_id": "Prompt",
                    "agent_type": "Agent type",
                }
                missing_published_fields = []
                for req_field, field_label in required_published_fields.items():
                    final_value = update_fields.get(
                        req_field, existing_agent.get(req_field)
                    )
                    if not final_value:
                        missing_published_fields.append(field_label)
                source_val = update_fields.get("source", existing_agent.get("source"))
                sources_val = update_fields.get(
                    "sources", existing_agent.get("sources", [])
                )
                has_valid_source = (
                    isinstance(source_val, DBRef)
                    or source_val == "default"
                    or (isinstance(sources_val, list) and len(sources_val) > 0)
                )
                if not has_valid_source:
                    missing_published_fields.append("Source")
                if missing_published_fields:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
                            }
                        ),
                        400,
                    )
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -878,6 +1162,29 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -986,19 +1293,16 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]
@@ -1108,4 +1412,4 @@ class RemoveSharedAgent(Resource):
current_app.logger.error(f"Error removing shared agent: {err}")
return make_response(
jsonify({"success": False, "message": "Server error"}), 500
)

View File

@@ -255,8 +255,8 @@ class ShareAgent(Resource):
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200

View File

@@ -99,11 +99,8 @@ class StoreAttachment(Resource):
})
if not tasks:
error_msg = "No valid files to upload"
if errors:
error_msg += f". Errors: {errors}"
return make_response(
jsonify({"status": "error", "message": error_msg, "errors": errors}),
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
)
@@ -135,7 +132,7 @@ class StoreAttachment(Resource):
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/images/<path:image_path>")

View File

@@ -31,12 +31,17 @@ sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -46,6 +51,25 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(

View File

@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
try:
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
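# Illustrative sketch, not part of the diff: assuming each ModelSettings
# object's to_dict() exposes id/name/capability fields (the exact fields
# depend on application.core.model_settings), a successful GET /api/models
# would return roughly:
#
#   {
#       "models": [{"id": "gpt-4o", "name": "GPT-4o", "...": "..."}],
#       "default_model_id": "gpt-4o",
#       "count": 1
#   }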

View File

@@ -5,15 +5,16 @@ Main user API routes - registers all namespace modules.
from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -27,10 +28,14 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks, folders)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
api.add_namespace(agents_webhooks_ns)
api.add_namespace(agents_folders_ns)
# Prompts
api.add_namespace(prompts_ns)
@@ -46,3 +51,6 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -220,8 +220,23 @@ class GetPubliclySharedConversations(Resource):
shared
and "conversation_id" in shared
):
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)

View File

@@ -55,9 +55,14 @@ class GetChunks(Resource):
if path:
chunk_source = metadata.get("source", "")
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
if not (source_match or file_path_match):
continue
# Filter by search term if provided

View File

@@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
@@ -20,6 +21,21 @@ sources_ns = Namespace(
)
def _get_provider_from_remote_data(remote_data):
if not remote_data:
return None
if isinstance(remote_data, dict):
return remote_data.get("provider")
if isinstance(remote_data, str):
try:
remote_data_obj = json.loads(remote_data)
except Exception:
return None
if isinstance(remote_data_obj, dict):
return remote_data_obj.get("provider")
return None
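# Illustrative sketch, not part of the diff: the helper accepts either a dict
# or a JSON-encoded string and falls back to None for anything malformed.
#
#   _get_provider_from_remote_data({"provider": "github"})   # -> "github"
#   _get_provider_from_remote_data('{"provider": "s3"}')     # -> "s3"
#   _get_provider_from_remote_data("not json")               # -> None
#   _get_provider_from_remote_data(None)                     # -> None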
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
@@ -41,6 +57,7 @@ class CombinedJson(Resource):
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["_id"]),
@@ -51,6 +68,7 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
@@ -107,6 +125,7 @@ class PaginatedSources(Resource):
paginated_docs = []
for doc in documents:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
@@ -116,6 +135,7 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
@@ -240,7 +260,7 @@ class ManageSync(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
@@ -269,6 +289,72 @@ class ManageSync(Resource):
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/sync_source")
class SyncSource(Resource):
sync_source_model = api.model(
"SyncSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(sync_source_model)
@api.doc(description="Trigger an immediate sync for a source")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
if not ObjectId.is_valid(source_id):
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type.startswith("connector"):
return make_response(
jsonify(
{
"success": False,
"message": "Connector sources must be synced via /api/connectors/sync",
}
),
400,
)
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
)
try:
task = sync_source.delay(
source_data=source_data,
job_name=doc.get("name", ""),
user=user,
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=source_id,
)
except Exception as err:
current_app.logger.error(
f"Error starting sync for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
@@ -320,4 +406,4 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)

View File

@@ -64,19 +64,27 @@ class UploadFile(Resource):
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
file_name_map = {}
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = os.path.basename(file.filename)
safe_file = safe_filename(original_filename)
if original_filename:
file_name_map[safe_file] = original_filename
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
is_office_format = safe_file.lower().endswith(
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
)
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
@@ -137,6 +145,7 @@ class UploadFile(Resource):
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -182,6 +191,8 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] == "s3":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
@@ -334,6 +345,14 @@ class ManageSourceFiles(Resource):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
file_name_map = source.get("file_name_map") or {}
if isinstance(file_name_map, str):
try:
file_name_map = json.loads(file_name_map)
except Exception:
file_name_map = {}
if not isinstance(file_name_map, dict):
file_name_map = {}
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
@@ -355,19 +374,35 @@ class ManageSourceFiles(Resource):
400,
)
added_files = []
map_updated = False
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
original_filename = os.path.basename(file.filename)
safe_filename_str = safe_filename(original_filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
if original_filename:
relative_key = (
f"{parent_dir}/{safe_filename_str}"
if parent_dir
else safe_filename_str
)
file_name_map[relative_key] = original_filename
map_updated = True
if map_updated:
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -414,6 +449,7 @@ class ManageSourceFiles(Resource):
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
@@ -422,6 +458,15 @@ class ManageSourceFiles(Resource):
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
if file_path in file_name_map:
file_name_map.pop(file_path, None)
map_updated = True
if map_updated and isinstance(file_name_map, dict):
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -504,6 +549,20 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
if directory_path and file_name_map:
prefix = f"{directory_path.rstrip('/')}/"
keys_to_remove = [
key
for key in file_name_map.keys()
if key == directory_path or key.startswith(prefix)
]
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
@@ -574,8 +633,9 @@ class TaskStatus(Resource):
):
task_meta = str(task_meta) # Convert to a string representation
except ConnectionError as err:
current_app.logger.error(f"Connection error getting task status: {err}")
return make_response(
jsonify({"success": False, "message": str(err)}), 503
jsonify({"success": False, "message": "Service unavailable"}), 503
)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)

View File

@@ -8,13 +8,25 @@ from application.worker import (
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
@celery.task(bind=True)
def ingest(
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
directory,
formats,
job_name,
file_path,
filename,
user,
file_name_map=file_name_map,
)
return resp
@@ -38,6 +50,30 @@ def schedule_syncs(self, frequency):
return resp
@celery.task(bind=True)
def sync_source(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
):
resp = sync(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
)
return resp
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)

View File

@@ -1,21 +1,67 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import Namespace, Resource, fields
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
def _sanitize_mcp_transport(config):
"""Normalise and validate the transport_type field.
Strips ``command`` / ``args`` keys that are only valid for local STDIO
transports and returns the cleaned transport type string.
"""
transport_type = (config.get("transport_type") or "auto").lower()
if transport_type not in _ALLOWED_TRANSPORTS:
raise ValueError(f"Unsupported transport_type: {transport_type}")
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
return transport_type
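# Illustrative sketch, not part of the diff: the helper lowercases the
# transport, rejects unknown values, and strips STDIO-only keys in place.
#
#   cfg = {"transport_type": "SSE", "command": "npx", "args": ["server"]}
#   _sanitize_mcp_transport(cfg)   # -> "sse"
#   # cfg is now {"transport_type": "sse"}; "command"/"args" were removed
#   _sanitize_mcp_transport({"transport_type": "stdio"})   # raises ValueError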
def _extract_auth_credentials(config):
"""Build an ``auth_credentials`` dict from the raw MCP config."""
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if config.get("api_key"):
auth_credentials["api_key"] = config["api_key"]
if config.get("api_key_header"):
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if config.get("bearer_token"):
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if config.get("username"):
auth_credentials["username"] = config["username"]
if config.get("password"):
auth_credentials["password"] = config["password"]
return auth_credentials
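# Illustrative sketch, not part of the diff: only keys relevant to the
# selected auth_type are copied out of the raw config.
#
#   _extract_auth_credentials({"auth_type": "bearer", "bearer_token": "t0k"})
#   # -> {"bearer_token": "t0k"}
#   _extract_auth_credentials({"auth_type": "none", "api_key": "ignored"})
#   # -> {}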
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@@ -43,34 +89,35 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
if not result.get("success") and "message" in result:
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
result["message"] = "Connection test failed"
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Connection test failed"}),
500,
)
@@ -110,22 +157,16 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
@@ -163,30 +204,39 @@ class MCPServerSave(Resource):
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
tool_id = data.get("id")
existing_encrypted = None
if tool_id:
existing_doc = user_tools_collection.find_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
)
if existing_doc:
existing_encrypted = existing_doc.get("config", {}).get(
"encrypted_credentials"
)
if auth_credentials:
if existing_encrypted:
existing_secrets = decrypt_credentials(existing_encrypted, user)
existing_secrets.update(auth_credentials)
auth_credentials = existing_secrets
storage_config["encrypted_credentials"] = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
elif existing_encrypted:
storage_config["encrypted_credentials"] = existing_encrypted
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
"redirect_uri",
]:
storage_config.pop(field, None)
transformed_actions = transform_actions(actions_metadata)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
@@ -198,7 +248,6 @@ class MCPServerSave(Resource):
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
@@ -233,9 +282,7 @@ class MCPServerSave(Resource):
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Failed to save MCP server"}),
500,
)
@@ -263,9 +310,12 @@ class MCPOAuthCallback(Resource):
error = request.args.get("error")
if error:
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool",
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
@@ -276,7 +326,6 @@ class MCPOAuthCallback(Resource):
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
@@ -292,17 +341,13 @@ class MCPOAuthCallback(Resource):
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
"/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
@@ -310,6 +355,14 @@ class MCPOAuthStatus(Resource):
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
@@ -317,17 +370,93 @@ class MCPOAuthStatus(Resource):
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
description="Batch check auth status for all MCP tools. "
"Lightweight DB-only check — no network calls to MCP servers."
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
mcp_tools = list(
user_tools_collection.find(
{"user": user, "name": "mcp_tool"},
{"_id": 1, "config": 1},
)
)
if not mcp_tools:
return make_response(jsonify({"success": True, "statuses": {}}), 200)
oauth_server_urls = {}
statuses = {}
for tool in mcp_tools:
tool_id = str(tool["_id"])
config = tool.get("config", {})
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
unique_urls = list(set(oauth_server_urls.values()))
sessions = list(
_connector_sessions.find(
{"user_id": user, "server_url": {"$in": unique_urls}},
{"server_url": 1, "tokens": 1},
)
)
url_has_tokens = {
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
for doc in sessions
}
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
current_app.logger.error(
"Error checking MCP auth status: %s", e, exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to check auth status"}),
500,
)
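# Illustrative sketch, not part of the diff: a successful response maps each
# MCP tool id to one of "configured", "connected", or "needs_auth", e.g.
#
#   {"success": true,
#    "statuses": {"66f0...a1": "connected", "66f0...b2": "configured"}}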

View File

@@ -4,6 +4,7 @@ from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
@@ -14,6 +15,114 @@ tool_config = {}
tool_manager = ToolManager(config=tool_config)
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret") and key in config and config[key]
]
if not secret_keys:
return config
storage_config = config.copy()
secret_values = {k: config[k] for k in secret_keys}
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
for key in secret_keys:
storage_config.pop(key, None)
return storage_config
def _validate_config(config, config_requirements, has_existing_secrets=False):
errors = {}
for key, spec in config_requirements.items():
depends_on = spec.get("depends_on")
if depends_on:
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
continue
if spec.get("required") and not config.get(key):
if has_existing_secrets and spec.get("secret"):
continue
errors[key] = f"{spec.get('label', key)} is required"
value = config.get(key)
if value is not None and value != "":
if spec.get("type") == "number":
try:
num = float(value)
if key == "timeout" and (num < 1 or num > 300):
errors[key] = "Timeout must be between 1 and 300"
except (ValueError, TypeError):
errors[key] = f"{spec.get('label', key)} must be a number"
if spec.get("enum") and value not in spec["enum"]:
errors[key] = f"Invalid value for {spec.get('label', key)}"
return errors
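# Illustrative sketch, not part of the diff, using a hypothetical requirements
# spec to show the required, dependent, numeric, and enum checks above.
#
#   reqs = {
#       "api_key": {"label": "API key", "required": True, "secret": True},
#       "timeout": {"label": "Timeout", "type": "number"},
#       "region": {"label": "Region", "enum": ["us", "eu"],
#                  "depends_on": {"mode": "cloud"}},
#   }
#   _validate_config({"timeout": "900"}, reqs)
#   # -> {"api_key": "API key is required",
#   #     "timeout": "Timeout must be between 1 and 300"}
#   _validate_config({}, reqs, has_existing_secrets=True)
#   # -> {}  (missing secret tolerated when secrets are already stored)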
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
"""Merge incoming config with existing encrypted secrets and re-encrypt.
For updates, the client may omit unchanged secret values. This helper
decrypts any previously stored secrets, overlays whatever the client *did*
send, strips plain-text secrets from the stored config, and re-encrypts
the merged result.
Returns the final ``config`` dict ready for persistence.
"""
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret")
]
if not secret_keys:
return new_config
existing_secrets = {}
if "encrypted_credentials" in existing_config:
existing_secrets = decrypt_credentials(
existing_config["encrypted_credentials"], user_id
)
merged_secrets = existing_secrets.copy()
for key in secret_keys:
if key in new_config and new_config[key]:
merged_secrets[key] = new_config[key]
# Start from existing non-secret values, then overlay incoming non-secrets
storage_config = {
k: v for k, v in existing_config.items()
if k not in secret_keys and k != "encrypted_credentials"
}
storage_config.update(
{k: v for k, v in new_config.items() if k not in secret_keys}
)
if merged_secrets:
storage_config["encrypted_credentials"] = encrypt_credentials(
merged_secrets, user_id
)
else:
storage_config.pop("encrypted_credentials", None)
storage_config.pop("has_encrypted_credentials", None)
return storage_config
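# Illustrative sketch, not part of the diff: updating a tool without
# re-sending its API key keeps the previously stored secret. The decrypted
# value shown is an assumption for the example.
#
#   reqs = {"api_key": {"secret": True}, "base_url": {}}
#   existing = {"base_url": "https://old.example", "encrypted_credentials": "..."}
#   # decrypt_credentials("...", user_id) is assumed to yield {"api_key": "old-key"}
#   merged = _merge_secrets_on_update({"base_url": "https://new.example"},
#                                     existing, reqs, user_id)
#   # merged == {"base_url": "https://new.example",
#   #            "encrypted_credentials": <re-encrypted {"api_key": "old-key"}>}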
def transform_actions(actions_metadata):
"""Set default flags on action metadata for storage.
Marks each action as active, sets ``filled_by_llm`` and ``value`` on every
parameter property. Used by both the generic create_tool and MCP save routes.
"""
transformed = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
props = action["parameters"].get("properties", {})
for param_details in props.values():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed.append(action)
return transformed
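# Illustrative sketch, not part of the diff: every action comes back active,
# with each parameter property marked as LLM-fillable.
#
#   transform_actions([{"name": "get_issue",
#                       "parameters": {"properties": {"id": {"type": "string"}}}}])
#   # -> [{"name": "get_issue", "active": True,
#   #      "parameters": {"properties": {"id": {"type": "string",
#   #                                           "filled_by_llm": True,
#   #                                           "value": ""}}}}]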
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@@ -28,12 +137,15 @@ class AvailableTools(Resource):
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
config_req = tool_instance.get_config_requirements()
actions = tool_instance.get_actions_metadata()
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"configRequirements": config_req,
"actions": actions,
}
)
except Exception as err:
@@ -59,6 +171,21 @@ class GetTools(Resource):
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
tool_instance = tool_manager.tools.get(tool_copy.get("name"))
if tool_instance:
config_req = tool_instance.get_config_requirements()
tool_copy["configRequirements"] = config_req
has_secrets = any(
spec.get("secret") for spec in config_req.values()
) if config_req else False
if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
tool_copy["config"]["has_encrypted_credentials"] = True
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
@@ -115,23 +242,32 @@ class CreateTool(Resource):
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = transform_actions(actions_metadata)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
config_requirements = tool_instance.get_config_requirements()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements
)
if validation_errors:
return make_response(
jsonify(
{
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}
),
400,
)
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
new_tool = {
"user": user,
"name": data["name"],
@@ -139,7 +275,8 @@ class CreateTool(Resource):
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"config": storage_config,
"configRequirements": config_requirements,
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -209,57 +346,37 @@ class UpdateTool(Resource):
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -297,9 +414,42 @@ class UpdateToolConfig(Resource):
if missing_fields:
return missing_fields
try:
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
{"$set": {"config": final_config}},
)
except Exception as err:
current_app.logger.error(
@@ -409,8 +559,142 @@ class DeleteTool(Resource):
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
@api.doc(
description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if "file" in request.files:
file = request.files["file"]
if not file.filename:
return make_response(
jsonify({"success": False, "message": "No file selected"}), 400
)
try:
spec_content = file.read().decode("utf-8")
except UnicodeDecodeError:
return make_response(
jsonify({"success": False, "message": "Invalid file encoding"}), 400
)
elif request.is_json:
data = request.get_json()
spec_content = data.get("spec_content", "")
else:
return make_response(
jsonify({"success": False, "message": "No spec provided"}), 400
)
if not spec_content or not spec_content.strip():
return make_response(
jsonify({"success": False, "message": "Empty spec content"}), 400
)
try:
metadata, actions = parse_spec(spec_content)
return make_response(
jsonify(
{
"success": True,
"metadata": metadata,
"actions": actions,
}
),
200,
)
except ValueError as e:
current_app.logger.error(f"Spec validation error: {e}")
return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
except Exception as err:
current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
@api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
def get(self, artifact_id: str):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404
)
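# Illustrative sketch, not part of the diff: a todo artifact response
# aggregates every todo for the owning tool, e.g.
#
#   {"success": true,
#    "artifact": {"artifact_type": "todo_list",
#                 "data": {"items": [{"todo_id": "t1", "title": "Ship it",
#                                     "status": "open",
#                                     "created_at": "...", "updated_at": "..."}],
#                          "total_count": 1, "open_count": 1,
#                          "completed_count": 0}}}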

View File

@@ -0,0 +1,378 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order
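
These helpers are meant to compose into very short route handlers. A minimal sketch of a paginated, authenticated endpoint built from them (`things_collection` and its ad-hoc serializer are hypothetical placeholders, not part of this diff):

```python
from flask_restx import Resource

class ThingList(Resource):
    @require_auth
    def get(self):
        user_id = get_user_id()
        limit, skip, error = validate_pagination()
        if error:
            return error
        # things_collection is illustrative; any pymongo Collection works here.
        return paginated_response(
            things_collection,
            {"user": user_id},
            lambda doc: serialize_object_id(dict(doc)),
            limit,
            skip,
            response_key="things",
        )
```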

View File

@@ -0,0 +1,3 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -0,0 +1,541 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Validate and normalize optional JSON schema payload for structured output."""
if json_schema is None:
return None, None
try:
return normalize_json_schema_payload(json_schema), None
except JsonSchemaValidationError as exc:
return None, str(exc)
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
"""Normalize agent-node JSON schema payloads before persistence."""
normalized_nodes: List[Dict] = []
for node in nodes:
if not isinstance(node, dict):
normalized_nodes.append(node)
continue
normalized_node = dict(node)
if normalized_node.get("type") != "agent":
normalized_nodes.append(normalized_node)
continue
raw_config = normalized_node.get("data")
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
normalized_nodes.append(normalized_node)
continue
normalized_config = dict(raw_config)
try:
normalized_config["json_schema"] = normalize_json_schema_payload(
raw_config.get("json_schema")
)
except JsonSchemaValidationError:
# Validation runs before normalization; keep original on unexpected shape.
normalized_config["json_schema"] = raw_config.get("json_schema")
normalized_node["data"] = normalized_config
normalized_nodes.append(normalized_node)
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
node_map = {n.get("id"): n for n in nodes}
end_ids = {n.get("id") for n in end_nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
for cnode in condition_nodes:
cnode_id = cnode.get("id")
cnode_title = cnode.get("title", cnode_id)
outgoing = [e for e in edges if e.get("source") == cnode_id]
if len(outgoing) < 2:
errors.append(
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
)
node_data = cnode.get("data", {}) or {}
cases = node_data.get("cases", [])
if not isinstance(cases, list):
cases = []
if not cases or not any(
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
):
errors.append(
f"Condition node '{cnode_title}' must have at least one case with an expression"
)
case_handles: Set[str] = set()
duplicate_case_handles: Set[str] = set()
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
errors.append(
f"Condition node '{cnode_title}' has a case without a branch handle"
)
continue
if handle in case_handles:
duplicate_case_handles.add(handle)
case_handles.add(handle)
for handle in duplicate_case_handles:
errors.append(
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
)
outgoing_by_handle: Dict[str, List[Dict]] = {}
for out_edge in outgoing:
raw_handle = out_edge.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
outgoing_by_handle.setdefault(handle, []).append(out_edge)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
errors.append(
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
)
continue
if handle != "else" and handle not in case_handles:
errors.append(
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
)
if len(handle_edges) > 1:
errors.append(
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
)
if "else" not in outgoing_by_handle:
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
continue
raw_expression = case.get("expression", "")
has_expression = isinstance(raw_expression, str) and bool(
raw_expression.strip()
)
has_outgoing = bool(outgoing_by_handle.get(handle))
if has_expression and not has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
)
if not has_expression and has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
continue
for out_edge in handle_edges:
target = out_edge.get("target")
if target and not _can_reach_end(target, edges, node_map, end_ids):
errors.append(
f"Branch '{handle}' of condition '{cnode_title}' "
f"must eventually reach an end node"
)
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
for agent_node in agent_nodes:
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
raw_config = agent_node.get("data", {}) or {}
if not isinstance(raw_config, dict):
errors.append(f"Agent node '{agent_title}' has invalid configuration")
continue
normalized_schema, schema_error = validate_json_schema_payload(
raw_config.get("json_schema")
)
has_json_schema = normalized_schema is not None
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
)
if schema_error:
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
def _can_reach_end(
    node_id: str,
    edges: List[Dict],
    node_map: Dict,
    end_ids: Set[str],
    visited: Optional[Set[str]] = None,
) -> bool:
    """Depth-first check that `node_id` can reach at least one end node."""
if visited is None:
visited = set()
if node_id in end_ids:
return True
if node_id in visited or node_id not in node_map:
return False
visited.add(node_id)
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()
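
For reference, a minimal graph that passes `validate_workflow_structure` (names from this module): exactly one start node, one end node, and a connected agent node. Removing the first edge trips the start-node rule:

```python
nodes = [
    {"id": "n1", "type": "start", "title": "Start"},
    {"id": "n2", "type": "agent", "title": "Answer", "data": {}},
    {"id": "n3", "type": "end", "title": "End"},
]
edges = [
    {"id": "e1", "source": "n1", "target": "n2"},
    {"id": "e2", "source": "n2", "target": "n3"},
]
assert validate_workflow_structure(nodes, edges) == []
# Without the start->agent edge, the start node has no outgoing edge:
assert "Start node must have at least one outgoing edge" in validate_workflow_structure(
    nodes, [edges[1]]
)
```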

View File

@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
options={"verify_exp": False},
)
return decoded_token
except Exception:
return {
"message": f"Authentication error: {str(e)}",
"message": "Authentication error: invalid token",
"error": "invalid_token",
}
else:

View File

@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional
class JsonSchemaValidationError(ValueError):
"""Raised when a JSON schema payload is invalid."""
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
"""
Normalize accepted JSON schema payload shapes to a plain schema object.
Accepted inputs:
- None
- A raw schema object with a top-level "type"
- A wrapped payload with a top-level "schema" object
"""
if json_schema is None:
return None
if not isinstance(json_schema, dict):
raise JsonSchemaValidationError("must be a valid JSON object")
wrapped_schema = json_schema.get("schema")
if wrapped_schema is not None:
if not isinstance(wrapped_schema, dict):
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
return wrapped_schema
if "type" not in json_schema:
raise JsonSchemaValidationError(
'must include either a "type" or "schema" field'
)
return json_schema
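
A quick sketch of the accepted shapes and the failure mode:

```python
raw = {"type": "object", "properties": {"answer": {"type": "string"}}}
assert normalize_json_schema_payload(None) is None
assert normalize_json_schema_payload(raw) == raw              # raw schema object
assert normalize_json_schema_payload({"schema": raw}) == raw  # wrapped payload
try:
    normalize_json_schema_payload({"properties": {}})         # neither "type" nor "schema"
except JsonSchemaValidationError as exc:
    assert "type" in str(exc)
```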

View File

@@ -0,0 +1,224 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# Where excluded, PDFs are handled synthetically by converting their pages to images in the attachment handler.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)
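
A sketch of wiring a local OpenAI-compatible server through this factory (the URL and model name are illustrative):

```python
model = create_custom_openai_model("llama3:8b", "http://localhost:11434/v1")
assert model.provider is ModelProvider.OPENAI
assert model.to_dict()["base_url"] == "http://localhost:11434/v1"
# Capabilities fall back to the dataclass defaults, so structured output is off:
assert model.capabilities.supports_structured_output is False
```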

View File

@@ -0,0 +1,295 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"
LLAMA_CPP = "llama.cpp"
DOCSGPT = "docsgpt"
PREMAI = "premai"
SAGEMAKER = "sagemaker"
NOVITA = "novita"
@dataclass
class ModelCapabilities:
supports_tools: bool = False
supports_structured_output: bool = False
supports_streaming: bool = True
supported_attachment_types: List[str] = field(default_factory=list)
context_window: int = 128000
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
@dataclass
class AvailableModel:
id: str
provider: ModelProvider
display_name: str
description: str = ""
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
"supports_tools": self.capabilities.supports_tools,
"supports_structured_output": self.capabilities.supports_structured_output,
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
}
if self.base_url:
result["base_url"] = self.base_url
return result
class ModelRegistry:
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
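
Because `__new__` plus the `_initialized` flag make this a process-wide singleton, repeated construction is cheap and every call site sees the same model table:

```python
registry = ModelRegistry.get_instance()
assert registry is ModelRegistry()  # same instance either way
for m in registry.get_enabled_models():
    print(m.id, m.provider.value, m.capabilities.context_window)
print("default:", registry.default_model_id)
```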

View File

@@ -0,0 +1,92 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
from application.core.settings import settings
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
"supports_tools": model.capabilities.supports_tools,
"supports_structured_output": model.capabilities.supports_structured_output,
"context_window": model.capabilities.context_window,
}
return None
def get_default_model_id() -> Optional[str]:
"""Get the system default model ID"""
registry = ModelRegistry.get_instance()
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
    Returns the model's context_window, or settings.DEFAULT_LLM_TOKEN_LIMIT when the model is not found in the registry.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.base_url
return None
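
A sketch of how the workflow validator consumes these helpers; whether "gpt-5.1" resolves depends on which API keys are configured:

```python
caps = get_model_capabilities("gpt-5.1")  # None unless the model is registered
if caps and not caps.get("supports_structured_output", False):
    print("selected model cannot honor a json_schema")
# Unknown models fall back to settings.DEFAULT_LLM_TOKEN_LIMIT (128000 by default):
print(get_token_limit("definitely-not-a-model"))
```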

View File

@@ -2,7 +2,8 @@ import os
from pathlib import Path
from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,27 +11,26 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
LLM_TOKEN_LIMITS: dict = {
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-4": 8192,
"gpt-3.5-turbo": 4096,
"claude-2": int(1e5),
"gemini-2.5-flash": int(1e6),
}
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
RESERVED_TOKENS: dict = {
"system_prompt": 500,
"current_query": 500,
@@ -43,8 +43,10 @@ class Settings(BaseSettings):
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
@@ -63,6 +65,12 @@ class Settings(BaseSettings):
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -70,11 +78,19 @@ class Settings(BaseSettings):
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
MCP_OAUTH_REDIRECT_URI: Optional[str] = None # public callback URL for MCP OAuth
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
# Provider-specific API keys (for multi-model support)
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
@@ -124,7 +140,7 @@ class Settings(BaseSettings):
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
@@ -138,11 +154,50 @@ class Settings(BaseSettings):
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
path = Path(__file__).parent.parent.absolute()
# Conversation Compression Settings
ENABLE_CONVERSATION_COMPRESSION: bool = True
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
@field_validator(
"API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",
"ELEVENLABS_API_KEY",
"INTERNAL_KEY",
mode="before",
)
@classmethod
def normalize_api_key(cls, v: Optional[str]) -> Optional[str]:
"""
Normalize API keys: convert 'None', 'none', empty strings,
and whitespace-only strings to actual None.
Handles Pydantic loading 'None' from .env as string "None".
"""
if v is None:
return None
if not isinstance(v, str):
return v
stripped = v.strip()
if stripped == "" or stripped.lower() == "none":
return None
return stripped
# Project root is one level above application/
path = Path(__file__).parent.parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
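
The validator means a `.env` containing the literal string `None` or a blank value no longer leaks into provider clients. A sketch of the normalization, assuming no conflicting environment variables:

```python
s = Settings(API_KEY="None", OPENAI_API_KEY="   ", GROQ_API_KEY=" gsk-demo ")
assert s.API_KEY is None             # literal "None" collapses to None
assert s.OPENAI_API_KEY is None      # whitespace-only collapses to None
assert s.GROQ_API_KEY == "gsk-demo"  # real keys are stripped
```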

View File

@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))
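
A sketch of the non-throwing entry point (the first call performs a real DNS lookup, so results depend on the environment):

```python
ok, url, err = validate_url_safe("example.com")
print(ok, url, err)  # True http://example.com None  (public host, scheme added)

ok, _, err = validate_url_safe("http://169.254.169.254/latest/meta-data/")
print(err)  # Access to cloud metadata services is not allowed.

ok, _, err = validate_url_safe("http://127.0.0.1:7091")
print(err)  # Access to private/internal IP addresses is not allowed.
```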

View File

@@ -13,3 +13,25 @@ def response_error(code_status, message=None):
def bad_request(status_code=400, message=''):
return response_error(code_status=status_code, message=message)
def sanitize_api_error(error) -> str:
"""
Convert technical API errors to user-friendly messages.
Works with both Exception objects and error message strings.
"""
error_str = str(error).lower()
if "503" in error_str or "unavailable" in error_str or "high demand" in error_str:
return "The AI service is temporarily unavailable due to high demand. Please try again in a moment."
if "429" in error_str or "rate limit" in error_str or "quota" in error_str:
return "Rate limit exceeded. Please wait a moment before trying again."
if "401" in error_str or "unauthorized" in error_str or "invalid api key" in error_str:
return "Authentication error. Please check your API configuration."
if "timeout" in error_str or "timed out" in error_str:
return "The request timed out. Please try again."
if "connection" in error_str or "network" in error_str:
return "Network error. Please check your connection and try again."
original = str(error)
if len(original) > 200 or "{" in original or "traceback" in error_str:
return "An error occurred while processing your request. Please try again later."
return original
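
The mapping is substring-based, so it works on provider SDK exceptions and plain strings alike:

```python
print(sanitize_api_error("Error code: 429 - {'error': {'message': 'Rate limit reached'}}"))
# Rate limit exceeded. Please wait a moment before trying again.
print(sanitize_api_error(TimeoutError("Read timed out. (read timeout=30)")))
# The request timed out. Please try again.
```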

View File

@@ -1,30 +1,48 @@
import base64
import logging
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.anthropic = Anthropic(api_key=self.api_key)
# Use custom base_url if provided
if base_url:
self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
else:
self.anthropic = Anthropic(api_key=self.api_key)
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
self.storage = StorageCreator.get_storage()
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
if stream:
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
completion = self.anthropic.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
@@ -34,7 +52,14 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -50,5 +75,117 @@ class AnthropicLLM(BaseLLM):
for completion in stream_response:
yield completion.completion
finally:
if hasattr(stream_response, "close"):
stream_response.close()
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Anthropic Claude for file uploads.
Claude supports images but not PDFs natively.
PDFs are synthetically supported via PDF-to-image conversion in the handler.
Returns:
list: List of supported MIME types
"""
return [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments for Anthropic Claude API.
Formats images using Claude's vision message format.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with image content for Claude API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the last user message to attach images to
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
# Convert content to list format if it's a string
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith("image/"):
try:
# Check if this is a pre-converted image (from PDF-to-image conversion)
# These have 'data' key with base64 already
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
# Claude uses a specific format for images
prepared_messages[user_message_index]["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
},
}
)
except Exception as e:
logger.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
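
A sketch of the attachment path using a pre-converted image (the `data` branch, so no storage read happens); the key is illustrative, and constructing the class requires the anthropic SDK and a configured storage backend:

```python
llm = AnthropicLLM(api_key="sk-ant-demo")
messages = [{"role": "user", "content": "Describe the chart"}]
attachments = [{"mime_type": "image/png", "data": "<base64-bytes>"}]
prepared = llm.prepare_messages_with_attachments(messages, attachments)
# The user turn now holds [{"type": "text", ...}, {"type": "image",
#  "source": {"type": "base64", "media_type": "image/png", "data": ...}}]
```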

View File

@@ -13,30 +13,35 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
agent_id=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
self.fallback_model_name = settings.FALLBACK_LLM_NAME
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM instance."""
if (
self._fallback_llm is None
and self.fallback_provider
and self.fallback_model_name
):
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
from application.llm.llm_creator import LLMCreator
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
@@ -54,7 +59,7 @@ class BaseLLM(ABC):
self, method_name: str, decorators: list, *args, **kwargs
):
"""
Execute method with fallback support.
Args:
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -73,10 +78,10 @@ class BaseLLM(ABC):
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
logger.warning(
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(
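A rough sketch of the retry flow this enables, assuming FALLBACK_LLM_PROVIDER and FALLBACK_LLM_NAME are set; the helper name and the gen signature here are illustrative, not the exact call sites:

def gen_with_fallback(primary_llm, model, messages):
    # Try the primary provider first; on failure, retry once on the
    # lazily built fallback LLM (constructed from FALLBACK_* settings).
    try:
        return primary_llm.gen(model=model, messages=messages)
    except Exception:
        if primary_llm.fallback_llm is None:
            raise  # No fallback configured: surface the original error.
        return primary_llm.fallback_llm.gen(
            model=primary_llm.fallback_llm.model_id, messages=messages
        )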

View File

@@ -1,76 +1,19 @@
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.openai import OpenAILLM
DOCSGPT_API_KEY = "sk-docsgpt-public"
DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
self.api_key = api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
class DocsGPTAPILLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,
user_api_key=user_api_key,
base_url=DOCSGPT_BASE_URL,
*args,
**kwargs,
)
def _raw_gen(
self,
@@ -80,23 +23,19 @@ class DocsGPTAPILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
return super()._raw_gen(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
def _raw_gen_stream(
self,
@@ -106,35 +45,16 @@ class DocsGPTAPILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, 'close'):
response.close()
def _supports_tools(self):
return True
return super()._raw_gen_stream(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
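The refactor above reduces the provider to a thin OpenAILLM subclass: only the credentials and base URL differ, and the two _raw_gen overrides simply pin the model name. Groq, Novita, and OpenRouter below follow the same shape; a minimal sketch with placeholder constants:

from application.core.settings import settings
from application.llm.openai import OpenAILLM

EXAMPLE_BASE_URL = "https://api.example.com/v1"  # hypothetical endpoint


class ExampleLLM(OpenAILLM):
    """OpenAI-compatible provider: only api_key and base_url change."""

    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
        super().__init__(
            api_key=api_key or settings.API_KEY,
            user_api_key=user_api_key,
            base_url=base_url or EXAMPLE_BASE_URL,
            *args,
            **kwargs,
        )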

View File

@@ -1,4 +1,3 @@
import json
import logging
from google import genai
@@ -11,10 +10,13 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(decoded_token=decoded_token, *args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -32,6 +34,12 @@ class GoogleLLM(BaseLLM):
"image/jpg",
"image/webp",
"image/gif",
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -47,21 +55,19 @@ class GoogleLLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -69,7 +75,6 @@ class GoogleLLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
files = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
@@ -92,11 +97,9 @@ class GoogleLLM(BaseLLM):
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
return prepared_messages
def _upload_file_to_google(self, attachment):
@@ -111,14 +114,11 @@ class GoogleLLM(BaseLLM):
"""
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_uri = self.storage.process_file(
file_path,
@@ -136,24 +136,52 @@ class GoogleLLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
Returns:
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
combined system instruction.
"""
cleaned_messages = []
system_instructions = []
def _extract_system_text(content):
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if (
isinstance(item, dict)
and "text" in item
and item["text"] is not None
):
parts.append(item["text"])
return "\n".join(parts)
return ""
for message in messages:
role = message.get("role")
content = message.get("content")
# Gemini only accepts user/model in the contents list.
if role == "system":
sys_text = _extract_system_text(content)
if sys_text:
system_instructions.append(sys_text)
continue
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
if isinstance(content, str):
@@ -164,15 +192,31 @@ class GoogleLLM(BaseLLM):
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Remove null values from args to avoid API errors
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
# Gemini 3 models require the thought_signature to be echoed
# back alongside the function call, so include it when present.
if "thought_signature" in item:
# Use the Part constructor so thoughtSignature sits next to functionCall
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=item["thought_signature"],
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
@@ -194,11 +238,12 @@ class GoogleLLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
system_instruction = (
"\n\n".join(system_instructions) if system_instructions else None
)
return cleaned_messages, system_instruction
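To illustrate the new return contract (values sketched, not exact SDK reprs): system turns are lifted out of the contents list and handed back separately, while assistant/tool roles are mapped to "model".

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
# After: cleaned, system_instruction = self._clean_messages_google(messages)
#   cleaned            -> [Content(role="user", ...), Content(role="model", ...)]
#   system_instruction -> "You are a helpful assistant."
# Multiple system turns would be joined with "\n\n".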
def _clean_schema(self, schema_obj):
"""
@@ -233,8 +278,8 @@ class GoogleLLM(BaseLLM):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
@@ -247,7 +292,6 @@ class GoogleLLM(BaseLLM):
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
@@ -263,7 +307,6 @@ class GoogleLLM(BaseLLM):
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
@@ -282,12 +325,84 @@ class GoogleLLM(BaseLLM):
name=function["name"],
description=function["description"],
)
genai_tool = types.Tool(function_declarations=[genai_function])
genai_tools.append(genai_tool)
return genai_tools
def _extract_preview_from_message(self, message):
"""Get a short, human-readable preview from the last message."""
try:
if hasattr(message, "parts"):
for part in reversed(message.parts):
if getattr(part, "text", None):
return part.text
function_call = getattr(part, "function_call", None)
if function_call:
name = getattr(function_call, "name", "") or "function_call"
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = (
getattr(function_response, "name", "")
or "function_response"
)
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for item in reversed(content):
if isinstance(item, str):
return item
if isinstance(item, dict):
if item.get("text"):
return item["text"]
if item.get("function_call"):
fn = item["function_call"]
if isinstance(fn, dict):
name = fn.get("name") or "function_call"
return f"function_call:{name}"
return "function_call"
if item.get("function_response"):
resp = item["function_response"]
if isinstance(resp, dict):
name = resp.get("name") or "function_response"
return f"function_response:{name}"
return "function_response"
if "text" in message and isinstance(message["text"], str):
return message["text"]
except Exception:
pass
return str(message)
def _summarize_messages_for_log(self, messages, preview_chars=20):
"""Return a compact summary for logging to avoid huge payloads."""
message_count = len(messages) if messages else 0
last_preview = ""
if messages:
last_preview = self._extract_preview_from_message(messages[-1]) or ""
last_preview = str(last_preview).replace("\n", " ")
if len(last_preview) > preview_chars:
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
@staticmethod
def _get_text_value(part):
"""Get text from both SDK objects and dict-shaped test doubles."""
if isinstance(part, dict):
value = part.get("text")
return value if isinstance(value, str) else ""
value = getattr(part, "text", None)
return value if isinstance(value, str) else ""
@staticmethod
def _is_thought_part(part):
"""Detect Gemini thinking parts when available."""
if isinstance(part, dict):
return bool(part.get("thought"))
return bool(getattr(part, "thought", False))
def _raw_gen(
self,
baseself,
@@ -301,22 +416,20 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
@@ -341,23 +454,21 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
@@ -366,9 +477,12 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
@@ -385,10 +499,26 @@ class GoogleLLM(BaseLLM):
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
continue
part_text = self._get_text_value(part)
if not part_text:
continue
if self._is_thought_part(part):
yield {"type": "thought", "thought": part_text}
else:
yield part_text
elif hasattr(chunk, "text"):
yield chunk.text
chunk_text = self._get_text_value(chunk)
if chunk_text:
if self._is_thought_part(chunk):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
except Exception as e:
logging.error(f"GoogleLLM: Stream error: {e}", exc_info=True)
raise
finally:
if hasattr(response, "close"):
response.close()
@@ -405,7 +535,6 @@ class GoogleLLM(BaseLLM):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
type_map = {
"object": "OBJECT",
"array": "ARRAY",
@@ -418,12 +547,10 @@ class GoogleLLM(BaseLLM):
def convert(schema):
if not isinstance(schema, dict):
return schema
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
for key in [
"description",
"nullable",
@@ -435,7 +562,6 @@ class GoogleLLM(BaseLLM):
]:
if key in schema:
result[key] = schema[key]
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
@@ -445,21 +571,17 @@ class GoogleLLM(BaseLLM):
result["format"] = format_value
else:
result["format"] = format_value
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
if "items" in schema:
result["items"] = convert(schema["items"])
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
return result
try:

View File

@@ -1,32 +1,15 @@
from application.llm.base import BaseLLM
from openai import OpenAI
from application.core.settings import settings
from application.llm.openai import OpenAILLM
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
class GroqLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or GROQ_BASE_URL,
*args,
**kwargs,
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -103,6 +105,7 @@ class LLMHandler(ABC):
"""
Prepare messages with attachments and provider-specific formatting.
Args:
agent: The agent instance
messages: Original messages
@@ -116,11 +119,40 @@ class LLMHandler(ABC):
logger.info(f"Preparing messages with {len(attachments)} attachments")
supported_types = agent.llm.get_supported_attachment_types()
# Check if provider supports images but not PDF (synthetic PDF support)
supports_images = any(t.startswith("image/") for t in supported_types)
supports_pdf = "application/pdf" in supported_types
# Process attachments, converting PDFs to images if needed
processed_attachments = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
# Synthetic PDF support: convert PDF to images if LLM supports images but not PDF
if mime_type == "application/pdf" and supports_images and not supports_pdf:
logger.info(
f"Converting PDF to images for synthetic PDF support: {attachment.get('path', 'unknown')}"
)
try:
converted_images = self._convert_pdf_to_images(attachment)
processed_attachments.extend(converted_images)
logger.info(
f"Converted PDF to {len(converted_images)} images"
)
except Exception as e:
logger.error(
f"Failed to convert PDF to images, falling back to text: {e}"
)
# Fall back to treating as unsupported (text extraction)
processed_attachments.append(attachment)
else:
processed_attachments.append(attachment)
supported_attachments = [
a for a in attachments if a.get("mime_type") in supported_types
a for a in processed_attachments if a.get("mime_type") in supported_types
]
unsupported_attachments = [
a for a in attachments if a.get("mime_type") not in supported_types
a for a in processed_attachments if a.get("mime_type") not in supported_types
]
# Process supported attachments with the LLM's custom method
@@ -143,6 +175,37 @@ class LLMHandler(ABC):
)
return messages
def _convert_pdf_to_images(self, attachment: Dict) -> List[Dict]:
"""
Convert a PDF attachment to a list of image attachments.
This enables synthetic PDF support for LLMs that support images but not PDFs.
Args:
attachment: PDF attachment dictionary with 'path' and optional 'content'
Returns:
List of image attachment dictionaries with 'data', 'mime_type', and 'page'
"""
from application.utils import convert_pdf_to_images
from application.storage.storage_creator import StorageCreator
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in PDF attachment")
storage = StorageCreator.get_storage()
# Convert PDF to images
images_data = convert_pdf_to_images(
file_path=file_path,
storage=storage,
max_pages=20,
dpi=150,
)
return images_data
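Each converted page is assumed to come back as a synthetic image attachment roughly shaped like the sketch below (keys per the docstring above; values illustrative). Downstream provider code branches on the presence of 'data' versus 'path' — the "pre-converted image" checks in the attachment loops.

converted_page = {
    "mime_type": "image/png",
    "data": "<base64-encoded PNG of page 1>",  # already base64; no 'path' key
    "page": 1,
}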
def _append_unsupported_attachments(
self, messages: List[Dict], attachments: List[Dict]
) -> List[Dict]:
@@ -178,6 +241,407 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: system prompt + latest user message only.
Drops all tool/function messages to shrink context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
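A worked example of the minimal prune, assuming a tool-heavy conversation that has overrun the context limit:

messages = [
    {"role": "system", "content": "You are an agent."},
    {"role": "user", "content": "Summarize the repo."},
    {"role": "assistant", "content": [{"function_call": {"name": "read", "args": {}, "call_id": "1"}}]},
    {"role": "tool", "content": [{"function_response": {"call_id": "1", "response": {"result": "..."}}}]},
    {"role": "user", "content": "Now list the TODOs."},
]
# _prune_messages_minimal keeps only the system prompt and the latest user turn:
# [{"role": "system", ...}, {"role": "user", "content": "Now list the TODOs."}]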
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
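For orientation, the synthetic conversation emitted here mirrors the shape the compression service expects from a persisted conversation — a sketch with one in-flight tool call:

conversation = {
    "queries": [
        {
            "prompt": "What changed in v2?",
            "response": "",  # committed with an empty response mid-execution
            "tool_calls": [
                {
                    "tool_name": "unknown_tool",  # tool identity is not recoverable here
                    "action_name": "read_changelog",
                    "arguments": {"version": "2"},
                    "result": None,
                    "status": "called",
                    "call_id": "abc-123",
                }
            ],
        }
    ],
    "compression_metadata": {"is_compressed": False, "compression_points": []},
}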
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Merge current in-flight messages (including tool calls)
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
if result.metadata:
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0 or queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if (
metadata.compressed_token_count
>= metadata.original_token_count
):
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -195,7 +659,110 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for call in tool_calls:
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,21 +772,26 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models.
# It must sit at the same level as function_call, not nested inside it.
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
"content": [function_call_content],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -282,7 +854,7 @@ class LLMHandler(ABC):
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -307,6 +879,9 @@ class LLMHandler(ABC):
tool_calls = {}
for chunk in self._iterate_stream(response):
if isinstance(chunk, dict) and chunk.get("type") == "thought":
yield chunk
continue
if isinstance(chunk, str):
yield chunk
continue
@@ -323,7 +898,13 @@ class LLMHandler(ABC):
if call.name:
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
if existing.arguments is None:
existing.arguments = call.arguments
else:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -336,8 +917,21 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler):
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
tool_calls = []
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
index=idx,
thought_signature=thought_sig,
)
)
content = " ".join(
part.text
@@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler):
raw_response=response,
)
else:
# This branch handles individual Part objects from streaming responses
tool_calls = []
if hasattr(response, "function_call"):
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = response.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
thought_signature=thought_sig,
)
)
return LLMResponse(

View File

@@ -1,68 +0,0 @@
from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name="Arc53/DocsGPT-7B",
q=False,
*args,
**kwargs,
):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
)
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
llm_name, quantization_config=bnb_config
)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=2000,
device_map="auto",
eos_token_id=tokenizer.eos_token_id,
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = hf(prompt)
return result.content
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

View File

@@ -1,13 +1,17 @@
from application.llm.groq import GroqLLM
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
import logging
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.open_router import OpenRouterLLM
logger = logging.getLogger(__name__)
class LLMCreator:
@@ -15,7 +19,6 @@ class LLMCreator:
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"huggingface": HuggingFaceLLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
@@ -23,13 +26,39 @@ class LLMCreator:
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"openrouter": OpenRouterLLM,
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
def create_llm(
cls,
type,
api_key,
user_api_key,
decoded_token,
model_id=None,
agent_id=None,
*args,
**kwargs,
):
from application.core.model_utils import get_base_url_for_model
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
# Extract base_url from model configuration if model_id is provided
base_url = None
if model_id:
base_url = get_base_url_for_model(model_id)
return llm_class(
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
*args,
**kwargs,
)
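A hedged usage sketch of the extended factory — model_id now also drives the optional per-model base_url lookup; the key and model name below are placeholders:

from application.llm.llm_creator import LLMCreator

llm = LLMCreator.create_llm(
    "openai",                     # provider key from LLMCreator.llms
    api_key="sk-...",             # placeholder
    user_api_key=None,
    decoded_token={"sub": "user-1"},
    model_id="gpt-4o",            # resolved to a base_url via get_base_url_for_model
    agent_id=None,
)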

View File

@@ -1,32 +1,15 @@
from application.llm.base import BaseLLM
from openai import OpenAI
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
class NovitaLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,
**kwargs,
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -0,0 +1,15 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or OPEN_ROUTER_BASE_URL,
*args,
**kwargs,
)
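OpenRouter rides the same OpenAI-compatible path as the providers above; a minimal sketch, assuming OPEN_ROUTER_API_KEY is configured in settings:

from application.llm.open_router import OpenRouterLLM

llm = OpenRouterLLM(user_api_key=None)
# Requests go to https://openrouter.ai/api/v1 with the configured key;
# message cleaning, streaming, and tool handling are inherited from OpenAILLM.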

View File

@@ -2,27 +2,85 @@ import base64
import json
import logging
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
def _truncate_base64_for_logging(messages):
"""
Create a copy of messages with base64 data truncated for readable logging.
Args:
messages: List of message dicts
Returns:
Copy of messages with truncated base64 content
"""
import copy
def truncate_content(content):
if isinstance(content, str):
# Check if it looks like a data URL with base64
if content.startswith("data:") and ";base64," in content:
prefix_end = content.index(";base64,") + len(";base64,")
prefix = content[:prefix_end]
return f"{prefix}[BASE64_DATA_TRUNCATED, length={len(content) - prefix_end}]"
return content
elif isinstance(content, list):
return [truncate_item(item) for item in content]
elif isinstance(content, dict):
return {k: truncate_content(v) for k, v in content.items()}
return content
def truncate_item(item):
if isinstance(item, dict):
result = {}
for k, v in item.items():
if k == "url" and isinstance(v, str) and ";base64," in v:
prefix_end = v.index(";base64,") + len(";base64,")
prefix = v[:prefix_end]
result[k] = f"{prefix}[BASE64_DATA_TRUNCATED, length={len(v) - prefix_end}]"
elif k == "data" and isinstance(v, str) and len(v) > 100:
result[k] = f"[BASE64_DATA_TRUNCATED, length={len(v)}]"
else:
result[k] = truncate_content(v)
return result
return truncate_content(item)
truncated = []
for msg in messages:
msg_copy = copy.copy(msg)
if "content" in msg_copy:
msg_copy["content"] = truncate_content(msg_copy["content"])
truncated.append(msg_copy)
return truncated
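A runnable check of the log sanitizer with a fabricated data URL — the original messages are left untouched because new lists and dicts are built rather than mutated in place:

msgs = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64," + "A" * 5000}},
            {"type": "text", "text": "describe this"},
        ],
    }
]
safe = _truncate_base64_for_logging(msgs)
assert safe[0]["content"][0]["image_url"]["url"] == (
    "data:image/png;base64,[BASE64_DATA_TRUNCATED, length=5000]"
)
assert msgs[0]["content"][0]["image_url"]["url"].endswith("A")  # untouched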
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if (
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
# Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
effective_base_url = None
if base_url and isinstance(base_url, str) and base_url.strip():
effective_base_url = base_url
elif (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
effective_base_url = settings.OPENAI_BASE_URL
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
self.api_key = api_key
self.user_api_key = user_api_key
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -33,17 +91,16 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
if "function_call" in item:
# Function calls need their own message
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
@@ -63,6 +120,7 @@ class OpenAILLM(BaseLLM):
}
)
elif "function_response" in item:
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
@@ -75,41 +133,69 @@ class OpenAILLM(BaseLLM):
}
)
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
elif "type" in item and item["type"] == "file" and "file" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
@staticmethod
def _normalize_reasoning_value(value):
"""Normalize reasoning payloads from OpenAI-compatible stream chunks."""
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
return "".join(
OpenAILLM._normalize_reasoning_value(item) for item in value
)
if isinstance(value, dict):
for key in ("text", "content", "value", "reasoning_content", "reasoning"):
normalized = OpenAILLM._normalize_reasoning_value(value.get(key))
if normalized:
return normalized
return ""
for attr in ("text", "content", "value"):
if hasattr(value, attr):
normalized = OpenAILLM._normalize_reasoning_value(getattr(value, attr))
if normalized:
return normalized
return ""
@classmethod
def _extract_reasoning_text(cls, delta):
"""Extract reasoning/thinking tokens from OpenAI-compatible delta chunks."""
if delta is None:
return ""
for key in (
"reasoning_content",
"reasoning",
"thinking",
"thinking_content",
):
value = getattr(delta, key, None)
if value is None and isinstance(delta, dict):
value = delta.get(key)
normalized = cls._normalize_reasoning_value(value)
if normalized:
return normalized
return ""
def _raw_gen(
self,
baseself,
@@ -122,6 +208,11 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
@@ -132,12 +223,10 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
logging.info(f"OpenAI response: {response}")
if tools:
return response.choices[0]
else:
@@ -155,6 +244,11 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
@@ -165,22 +259,33 @@ class OpenAILLM(BaseLLM):
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
logging.debug(f"OpenAI stream line: {line}")
if not getattr(line, "choices", None):
continue
choice = line.choices[0]
delta = getattr(choice, "delta", None)
reasoning_text = self._extract_reasoning_text(delta)
if reasoning_text:
yield {"type": "thought", "thought": reasoning_text}
content = getattr(delta, "content", None)
if isinstance(content, str) and content:
yield content
continue
has_tool_calls = bool(getattr(delta, "tool_calls", None))
finish_reason = getattr(choice, "finish_reason", None)
# Yield non-content chunks only when needed for tool-call handling.
if has_tool_calls or finish_reason == "tool_calls":
yield choice
finally:
if hasattr(response, "close"):
response.close()
@@ -194,7 +299,6 @@ class OpenAILLM(BaseLLM):
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
try:
def add_additional_properties_false(schema_obj):
@@ -204,11 +308,11 @@ class OpenAILLM(BaseLLM):
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
@@ -224,7 +328,6 @@ class OpenAILLM(BaseLLM):
add_additional_properties_false(sub_schema)
for sub_schema in value
]
return schema_copy
return schema_obj
@@ -243,7 +346,6 @@ class OpenAILLM(BaseLLM):
}
return result
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
@@ -252,17 +354,14 @@ class OpenAILLM(BaseLLM):
"""
Return a list of MIME types supported by OpenAI for file uploads.
This reads from the model config to ensure consistency.
If no model config found, falls back to images only (safest default).
Returns:
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
from application.core.model_configs import OPENAI_ATTACHMENTS
return OPENAI_ATTACHMENTS
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
@@ -277,21 +376,19 @@ class OpenAILLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -299,13 +396,18 @@ class OpenAILLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
logging.info(f"Processing attachment with mime_type: {mime_type}, has_data: {'data' in attachment}, has_path: {'path' in attachment}")
if mime_type and mime_type.startswith("image/"):
try:
base64_image = self._get_base64_image(attachment)
# Check if this is a pre-converted image (from PDF-to-image conversion)
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
@@ -314,6 +416,7 @@ class OpenAILLM(BaseLLM):
},
}
)
except Exception as e:
logging.error(
f"Error processing image attachment: {e}", exc_info=True
@@ -326,7 +429,9 @@ class OpenAILLM(BaseLLM):
}
)
# Handle PDFs using the file API
elif mime_type == "application/pdf":
logging.info(f"Attempting to upload PDF to OpenAI: {attachment.get('path', 'unknown')}")
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append(
@@ -341,7 +446,8 @@ class OpenAILLM(BaseLLM):
"text": f"File content:\n\n{attachment['content']}",
}
)
else:
logging.warning(f"Unsupported attachment type in OpenAI provider: {mime_type}")
return prepared_messages
def _get_base64_image(self, attachment):
@@ -357,7 +463,6 @@ class OpenAILLM(BaseLLM):
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +486,10 @@ class OpenAILLM(BaseLLM):
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_id = self.storage.process_file(
file_path,
@@ -404,7 +507,6 @@ class OpenAILLM(BaseLLM):
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)

View File

@@ -62,15 +62,26 @@ class BaseConnectorAuth(ABC):
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
"""Extract the fields safe to persist in the session store.
"""
return {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry"),
**extra_fields,
}
class BaseConnectorLoader(ABC):
"""

View File

@@ -1,5 +1,7 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
class ConnectorCreator:
@@ -12,10 +14,12 @@ class ConnectorCreator:
connectors = {
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}
@classmethod

View File

@@ -232,10 +232,6 @@ class GoogleDriveAuth(BaseConnectorAuth):
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'

View File

@@ -327,15 +327,10 @@ class GoogleDriveLoader(BaseConnectorLoader):
content_bytes = file_io.getvalue()
try:
content = content_bytes.decode('utf-8')
return content_bytes.decode('utf-8')
except UnicodeDecodeError:
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
logging.error(f"Could not decode file {file_id} as text")
return None
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")

View File

@@ -0,0 +1,10 @@
"""
SharePoint connector package for DocsGPT.
This module provides authentication and document loading capabilities for SharePoint.
"""
from .auth import SharePointAuth
from .loader import SharePointLoader
__all__ = ['SharePointAuth', 'SharePointLoader']

View File

@@ -0,0 +1,152 @@
import datetime
import logging
from typing import Optional, Dict, Any
from msal import ConfidentialClientApplication
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
logger = logging.getLogger(__name__)
class SharePointAuth(BaseConnectorAuth):
"""
Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.
Note: the Files.Read scope allows reading files the signed-in user can access,
similar in spirit to Google Drive's drive.file scope.
"""
SCOPES = [
"Files.Read",
"Sites.Read.All",
"User.Read",
]
def __init__(self):
self.client_id = settings.MICROSOFT_CLIENT_ID
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
if not self.client_id:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
)
if not self.client_secret:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
)
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
self.tenant_id = settings.MICROSOFT_TENANT_ID
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")
self.auth_app = ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority
)
def get_authorization_url(self, state: Optional[str] = None) -> str:
return self.auth_app.get_authorization_request_url(
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
)
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_authorization_code(
code=authorization_code,
scopes=self.SCOPES,
redirect_uri=self.redirect_uri
)
if "error" in result:
logger.error("Token exchange failed: %s", result.get("error_description"))
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
return self.map_token_response(result)
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
if "error" in result:
logger.error("Token refresh failed: %s", result.get("error_description"))
raise ValueError(f"Error refreshing token: {result.get('error_description')}")
return self.map_token_response(result)
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sessions_collection = db["connector_sessions"]
session = sessions_collection.find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
if "token_info" not in session:
raise ValueError("Session missing token information")
token_info = session["token_info"]
if not token_info:
raise ValueError("Invalid token information")
required_fields = ["access_token", "refresh_token"]
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'token_uri' not in token_info:
token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"
return token_info
except Exception as e:
logger.error("Failed to retrieve token from session: %s", e)
raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
if not token_info:
return True
expiry_timestamp = token_info.get("expiry")
if expiry_timestamp is None:
return True
current_timestamp = int(datetime.datetime.now().timestamp())
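# Refresh 60 seconds early so a token cannot expire mid-request.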
return (expiry_timestamp - current_timestamp) < 60
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
return super().sanitize_token_info(
token_info,
allows_shared_content=token_info.get("allows_shared_content", False),
**extra_fields,
)
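# Well-known tenant ID reported in the id token's "tid" claim for personal
# (consumer/MSA) Microsoft accounts.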
PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"
def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
"""Return True when the account is a work/school tenant that can access SharePoint shared content."""
tid = id_token_claims.get("tid", "")
return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID
def map_token_response(self, result) -> Dict[str, Any]:
claims = result.get("id_token_claims", {})
return {
"access_token": result.get("access_token"),
"refresh_token": result.get("refresh_token"),
"token_uri": claims.get("iss"),
"scopes": result.get("scope"),
"expiry": claims.get("exp"),
"allows_shared_content": self._allows_shared_content(claims),
"user_info": {
"name": claims.get("name"),
"email": claims.get("preferred_username"),
},
}
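# A minimal sketch of the flow above (the redirect/callback plumbing and the
# `code` variable are hypothetical; the SharePointAuth calls are from this class):
#
#   auth = SharePointAuth()
#   login_url = auth.get_authorization_url(state="opaque-csrf-state")
#   # ...user signs in; Microsoft redirects back with ?code=<authorization code>
#   token_info = auth.exchange_code_for_tokens(authorization_code=code)
#   persistable = auth.sanitize_token_info(token_info)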

View File

@@ -0,0 +1,649 @@
"""
SharePoint/OneDrive loader for DocsGPT.
Loads documents from SharePoint/OneDrive using the Microsoft Graph API.
"""
import functools
import logging
import os
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote
import requests
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.schema.base import Document
def _retry_on_auth_failure(func):
"""Retry once after refreshing the access token on 401/403 responses."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in (401, 403):
logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception as refresh_error:
raise ValueError(
f"Authentication failed and could not be refreshed: {refresh_error}"
) from e
return func(self, *args, **kwargs)
raise
return wrapper
class SharePointLoader(BaseConnectorLoader):
SUPPORTED_MIME_TYPES = {
'application/pdf': '.pdf',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/msword': '.doc',
'application/vnd.ms-powerpoint': '.ppt',
'application/vnd.ms-excel': '.xls',
'text/plain': '.txt',
'text/csv': '.csv',
'text/html': '.html',
'text/markdown': '.md',
'text/x-rst': '.rst',
'application/json': '.json',
'application/epub+zip': '.epub',
'application/rtf': '.rtf',
'image/jpeg': '.jpg',
'image/png': '.png',
}
EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
def __init__(self, session_token: str):
self.auth = SharePointAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.access_token = token_info.get('access_token')
self.refresh_token = token_info.get('refresh_token')
self.allows_shared_content = token_info.get('allows_shared_content', False)
if not self.access_token:
raise ValueError("No access token found in session")
self.next_page_token = None
def _get_headers(self) -> Dict[str, str]:
return {
'Authorization': f'Bearer {self.access_token}',
'Accept': 'application/json'
}
def _ensure_valid_token(self):
if not self.access_token:
raise ValueError("No access token available")
token_info = {'access_token': self.access_token, 'expiry': None}
if self.auth.is_token_expired(token_info):
logging.info("Token expired, attempting refresh")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception as e:
raise ValueError("Failed to refresh access token") from e
def _get_item_url(self, item_ref: str) -> str:
if ':' in item_ref:
drive_id, item_id = item_ref.split(':', 1)
return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
try:
drive_item_id = file_metadata.get('id')
file_name = file_metadata.get('name', 'Unknown')
file_data = file_metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
return None
doc_metadata = {
'file_name': file_name,
'mime_type': mime_type,
'size': file_metadata.get('size'),
'created_time': file_metadata.get('createdDateTime'),
'modified_time': file_metadata.get('lastModifiedDateTime'),
'source': 'share_point'
}
if not load_content:
return Document(
text="",
doc_id=drive_item_id,
extra_info=doc_metadata
)
content = self._download_file_content(drive_item_id)
if content is None:
logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
return None
return Document(
text=content,
doc_id=drive_item_id,
extra_info=doc_metadata
)
except Exception as e:
logging.error(f"Error processing file: {e}")
return None
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
try:
documents: List[Document] = []
folder_id = inputs.get('folder_id')
file_ids = inputs.get('file_ids', [])
limit = inputs.get('limit', 100)
list_only = inputs.get('list_only', False)
load_content = not list_only
page_token = inputs.get('page_token')
search_query = inputs.get('search_query')
self.next_page_token = None
shared = inputs.get('shared', False)
if file_ids:
for file_id in file_ids:
try:
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc:
if not search_query or (
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
continue
elif shared:
if not self.allows_shared_content:
logging.warning("Shared content is only available for work/school Microsoft accounts")
return []
documents = self._list_shared_items(
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
else:
parent_id = folder_id if folder_id else 'root'
documents = self._list_items_in_parent(
parent_id,
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
return documents
except Exception as e:
logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
raise
@_retry_on_auth_failure
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
self._ensure_valid_token()
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
file_metadata = response.json()
return self._process_file(file_metadata, load_content=load_content)
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
return None
@_retry_on_auth_failure
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
self._ensure_valid_token()
documents: List[Document] = []
try:
url = f"{self._get_item_url(parent_id)}/children"
params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
if page_token:
params['$skipToken'] = page_token
if search_query:
encoded_query = quote(search_query, safe='')
if ':' in parent_id:
drive_id = parent_id.split(':', 1)[0]
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
else:
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
response = requests.get(search_url, headers=self._get_headers(), params=params)
else:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
for item in items:
if 'folder' in item:
doc_metadata = {
'file_name': item.get('name', 'Unknown'),
'mime_type': 'folder',
'size': item.get('size'),
'created_time': item.get('createdDateTime'),
'modified_time': item.get('lastModifiedDateTime'),
'source': 'share_point',
'is_folder': True
}
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
else:
doc = self._process_file(item, load_content=load_content)
if doc:
documents.append(doc)
if limit and len(documents) >= limit:
break
next_link = results.get('@odata.nextLink')
if next_link:
from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link)
query_params = parse_qs(parsed.query)
skiptoken_list = query_params.get('$skiptoken')
if skiptoken_list:
self.next_page_token = skiptoken_list[0]
else:
self.next_page_token = None
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing items under parent {parent_id}: {e}")
return documents
def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
"""Resolve mime type from resource, falling back to file extension."""
file_data = resource.get('file', {})
mime_type = file_data.get('mimeType') if file_data else None
if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
return mime_type, True
name = resource.get('name', '')
ext = os.path.splitext(name)[1].lower()
if ext in self.EXTENSION_TO_MIME:
return self.EXTENSION_TO_MIME[ext], True
return mime_type or 'application/octet-stream', False
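# e.g. a shared item named "notes.md" with no mimeType in its file facet maps
# through EXTENSION_TO_MIME to ("text/markdown", True); anything unresolvable
# returns (<reported type or octet-stream>, False) and callers skip it.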
def _get_user_drive_web_url(self) -> Optional[str]:
"""Fetch the current user's OneDrive web URL for KQL path exclusion."""
try:
response = requests.get(
f"{self.GRAPH_API_BASE}/me/drive",
headers=self._get_headers(),
params={'$select': 'webUrl'}
)
response.raise_for_status()
return response.json().get('webUrl')
except Exception as e:
logging.warning(f"Could not fetch user drive web URL: {e}")
return None
def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
"""Build KQL query string that excludes the user's own drive items."""
base_query = search_query if search_query else "*"
if user_drive_url:
return f'{base_query} AND -path:"{user_drive_url}"'
return base_query
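# e.g. search_query="report" and a (hypothetical) user_drive_url of
# "https://contoso-my.sharepoint.com/personal/alice" produce:
#   report AND -path:"https://contoso-my.sharepoint.com/personal/alice"
# which keeps shared hits while excluding the user's own OneDrive items.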
def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
"""Fetch shared drive items using Microsoft Graph Search API with local offset paging.
We always fetch up to a fixed maximum number of hits from Graph (single request),
then page through that array locally using `page_token` as a simple integer offset.
This avoids relying on buggy or inconsistent remote `from`/`size` semantics.
"""
self._ensure_valid_token()
documents: List[Document] = []
try:
user_drive_url = self._get_user_drive_web_url()
query_text = self._build_shared_kql_query(search_query, user_drive_url)
url = f"{self.GRAPH_API_BASE}/search/query"
page_size = 500 # maximum number of hits we care about for selection
body = {
"requests": [
{
"entityTypes": ["driveItem"],
"query": {"queryString": query_text},
"from": 0,
"size": page_size,
}
]
}
headers = self._get_headers()
headers["Content-Type"] = "application/json"
response = requests.post(url, headers=headers, json=body)
response.raise_for_status()
results = response.json()
search_response = results.get("value", [])
if not search_response:
logging.warning("Search API returned empty value array")
self.next_page_token = None
return documents
hits_containers = search_response[0].get("hitsContainers", [])
if not hits_containers:
logging.warning("Search API returned no hitsContainers")
self.next_page_token = None
return documents
container = hits_containers[0]
total = container.get("total", 0)
raw_hits = container.get("hits", [])
# Deduplicate by effective item ID (driveId:itemId) to avoid the same
# resource appearing multiple times across the result set.
deduped_hits = []
seen_ids = set()
for hit in raw_hits:
resource = hit.get("resource", {})
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
if not effective_id or effective_id in seen_ids:
continue
seen_ids.add(effective_id)
deduped_hits.append(hit)
hits = deduped_hits
logging.info(
f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
)
try:
offset = int(page_token) if page_token is not None else 0
except (TypeError, ValueError):
logging.warning(
f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
)
offset = 0
if offset < 0:
offset = 0
if offset >= len(hits):
self.next_page_token = None
return documents
end_index = offset + limit if limit else len(hits)
end_index = min(end_index, len(hits))
for hit in hits[offset:end_index]:
resource = hit.get("resource", {})
item_name = resource.get("name", "Unknown")
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
is_folder = "folder" in resource
if is_folder:
doc_metadata = {
"file_name": item_name,
"mime_type": "folder",
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
"is_folder": True,
}
documents.append(
Document(text="", doc_id=effective_id, extra_info=doc_metadata)
)
else:
mime_type, supported = self._resolve_mime_type(resource)
if not supported:
logging.info(
f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
)
continue
doc_metadata = {
"file_name": item_name,
"mime_type": mime_type,
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
}
content = ""
if load_content:
content = self._download_file_content(effective_id) or ""
documents.append(
Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
)
if limit and end_index < len(hits):
self.next_page_token = str(end_index)
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
return documents
@_retry_on_auth_failure
def _download_file_content(self, file_id: str) -> Optional[str]:
self._ensure_valid_token()
try:
url = f"{self._get_item_url(file_id)}/content"
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
try:
return response.content.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}")
return None
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
metadata = response.json()
file_name = metadata.get('name', 'unknown')
file_data = metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type}")
return False
os.makedirs(local_dir, exist_ok=True)
full_path = os.path.join(local_dir, file_name)
download_url = f"{self._get_item_url(file_id)}/content"
download_response = requests.get(download_url, headers=self._get_headers())
download_response.raise_for_status()
with open(full_path, 'wb') as f:
f.write(download_response.content)
return True
except Exception as e:
logging.error(f"Error in _download_single_file: {e}")
return False
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
files_downloaded = 0
try:
os.makedirs(local_dir, exist_ok=True)
url = f"{self._get_item_url(folder_id)}/children"
params = {'$top': 1000}
while url:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
logging.info(f"Found {len(items)} items in folder {folder_id}")
for item in items:
item_name = item.get('name', 'unknown')
item_id = item.get('id')
if 'folder' in item:
if recursive:
subfolder_path = os.path.join(local_dir, item_name)
os.makedirs(subfolder_path, exist_ok=True)
subfolder_files = self._download_folder_recursive(
item_id,
subfolder_path,
recursive
)
files_downloaded += subfolder_files
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
else:
success = self._download_single_file(item_id, local_dir)
if success:
files_downloaded += 1
logging.info(f"Downloaded file: {item_name}")
else:
logging.warning(f"Failed to download file: {item_name}")
url = results.get('@odata.nextLink')
return files_downloaded
except Exception as e:
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
return files_downloaded
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
try:
self._ensure_valid_token()
return self._download_folder_recursive(folder_id, local_dir, recursive)
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
return 0
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
try:
self._ensure_valid_token()
return self._download_single_file(file_id, local_dir)
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
return False
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
if source_config is None:
source_config = {}
config = source_config if source_config else getattr(self, 'config', {})
files_downloaded = 0
try:
folder_ids = config.get('folder_ids', [])
file_ids = config.get('file_ids', [])
recursive = config.get('recursive', True)
if file_ids:
if isinstance(file_ids, str):
file_ids = [file_ids]
for file_id in file_ids:
if self._download_file_to_directory(file_id, local_dir):
files_downloaded += 1
if folder_ids:
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for folder_id in folder_ids:
try:
url = self._get_item_url(folder_id)
params = {'$select': 'id,name'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
folder_metadata = response.json()
folder_name = folder_metadata.get('name', '')
folder_path = os.path.join(local_dir, folder_name)
os.makedirs(folder_path, exist_ok=True)
folder_files = self._download_folder_recursive(
folder_id,
folder_path,
recursive
)
files_downloaded += folder_files
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
if not file_ids and not folder_ids:
raise ValueError("No folder_ids or file_ids provided for download")
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "share_point",
"config_used": config
}
except Exception as e:
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": True,
"source_type": "share_point",
"config_used": config,
"error": str(e)
}

View File

@@ -65,6 +65,10 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
if not os.path.exists(folder_name):
os.makedirs(folder_name)
# Validate docs is not empty
if not docs:
raise ValueError("No documents to embed - check file format and extension")
# Initialize vector store
if settings.VECTOR_STORE == "faiss":
docs_init = [docs.pop(0)]

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, List
from langchain.docstore.document import Document as LCDocument
from langchain_core.documents import Document as LCDocument
from application.parser.schema.base import Document

View File

@@ -10,29 +10,97 @@ from application.parser.file.epub_parser import EpubParser
from application.parser.file.html_parser import HTMLParser
from application.parser.file.markdown_parser import MarkdownParser
from application.parser.file.rst_parser import RstParser
from application.parser.file.tabular_parser import PandasCSVParser,ExcelParser
from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
from application.utils import num_tokens_from_string
from application.core.settings import settings
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx":ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json":JSONParser(),
".pptx":PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
def get_default_file_extractor(
ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
"""Get the default file extractor.
Uses docling parsers by default for advanced document processing.
Falls back to standard parsers if docling is not installed.
"""
try:
from application.parser.file.docling_parser import (
DoclingPDFParser,
DoclingDocxParser,
DoclingPPTXParser,
DoclingXLSXParser,
DoclingHTMLParser,
DoclingImageParser,
DoclingCSVParser,
DoclingAsciiDocParser,
DoclingVTTParser,
DoclingXMLParser,
)
if ocr_enabled is None:
ocr_enabled = settings.DOCLING_OCR_ENABLED
return {
# Documents
".pdf": DoclingPDFParser(ocr_enabled=ocr_enabled),
".docx": DoclingDocxParser(),
".pptx": DoclingPPTXParser(),
".xlsx": DoclingXLSXParser(),
# Web formats
".html": DoclingHTMLParser(),
".xhtml": DoclingHTMLParser(),
# Data formats
".csv": DoclingCSVParser(),
".json": JSONParser(), # Keep JSON parser (specialized handling)
# Text/markup formats
".md": MarkdownParser(), # Keep markdown parser (specialized handling)
".mdx": MarkdownParser(),
".rst": RstParser(),
".adoc": DoclingAsciiDocParser(),
".asciidoc": DoclingAsciiDocParser(),
# Images (with OCR) - only use Docling when OCR is enabled
".png": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpeg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tiff": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tif": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".bmp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".webp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
# Media/subtitles
".vtt": DoclingVTTParser(),
# Specialized XML formats
".xml": DoclingXMLParser(),
# Formats docling doesn't support - use standard parsers
".epub": EpubParser(),
}
except ImportError:
logging.warning(
"docling is not installed. Using standard parsers. "
"For advanced document parsing, install with: pip install docling"
)
# Fallback to standard parsers
return {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx": ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json": JSONParser(),
".pptx": PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
# For backwards compatibility
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = get_default_file_extractor()
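# A minimal sketch (keyword names are taken from the constructor below; the
# load method name is assumed):
#
#   reader = SimpleDirectoryReader(
#       input_files=["./docs/Spec.PDF"],   # uppercase suffix now matches
#       required_exts=[".pdf"],
#   )
#   docs = reader.load_data()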
class SimpleDirectoryReader(BaseReader):
@@ -83,7 +151,10 @@ class SimpleDirectoryReader(BaseReader):
self.recursive = recursive
self.exclude_hidden = exclude_hidden
self.required_exts = required_exts
# Normalize extensions to lowercase for case-insensitive matching
self.required_exts = (
[ext.lower() for ext in required_exts] if required_exts else None
)
self.num_files_limit = num_files_limit
if input_files:
@@ -112,7 +183,7 @@ class SimpleDirectoryReader(BaseReader):
continue
elif (
self.required_exts is not None
and input_file.suffix not in self.required_exts
and input_file.suffix.lower() not in self.required_exts
):
continue
else:
@@ -149,8 +220,9 @@ class SimpleDirectoryReader(BaseReader):
self.file_token_counts = {}
for input_file in self.input_files:
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
suffix_lower = input_file.suffix.lower()
if suffix_lower in self.file_extractor:
parser = self.file_extractor[suffix_lower]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
@@ -232,7 +304,7 @@ class SimpleDirectoryReader(BaseReader):
if subtree:
result[item.name] = subtree
else:
if self.required_exts is not None and item.suffix not in self.required_exts:
if self.required_exts is not None and item.suffix.lower() not in self.required_exts:
continue
full_path = str(item.resolve())
@@ -251,4 +323,4 @@ class SimpleDirectoryReader(BaseReader):
return result
return build_tree(Path(base_path))
return build_tree(Path(base_path))

View File

@@ -0,0 +1,330 @@
"""Docling parser.
Uses the docling library for advanced document parsing with layout detection,
table structure recognition, and unified document representation.
Supports: PDF, DOCX, PPTX, XLSX, HTML, XHTML, CSV, Markdown, AsciiDoc,
images (PNG, JPEG, TIFF, BMP, WEBP), WebVTT, and specialized XML formats.
"""
import importlib.util
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
Docling provides:
- Advanced PDF layout analysis
- Table structure recognition
- Reading order detection
- OCR for scanned documents (supports RapidOCR)
- Unified DoclingDocument format
- Export to Markdown
Uses hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
export_format: str = "markdown",
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
"""Initialize DoclingParser.
Args:
ocr_enabled: Enable OCR for bitmap/image regions in documents
table_structure: Enable table structure recognition
export_format: Output format ('markdown', 'text', 'html')
use_rapidocr: Use RapidOCR engine (default True, works well in Docker)
ocr_languages: List of OCR languages (default: ['english'])
force_full_page_ocr: Force OCR on entire page (False = smart hybrid OCR)
"""
super().__init__()
self.ocr_enabled = ocr_enabled
self.table_structure = table_structure
self.export_format = export_format
self.use_rapidocr = use_rapidocr
self.ocr_languages = ocr_languages or ["english"]
self.force_full_page_ocr = force_full_page_ocr
self._converter = None
def _create_converter(self):
"""Create a docling converter with hybrid OCR configuration.
Uses smart OCR approach:
- When ocr_enabled=True and force_full_page_ocr=False (default):
Layout model detects text vs bitmap regions, OCR only runs on bitmaps
- When ocr_enabled=True and force_full_page_ocr=True:
OCR runs on entire page (for scanned documents/images)
- When ocr_enabled=False:
No OCR, only native text extraction
Returns:
DocumentConverter instance
"""
from docling.document_converter import (
DocumentConverter,
ImageFormatOption,
InputFormat,
PdfFormatOption,
)
from docling.datamodel.pipeline_options import PdfPipelineOptions
pipeline_options = PdfPipelineOptions(
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()
if ocr_options is not None:
pipeline_options.ocr_options = ocr_options
return DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options,
),
InputFormat.IMAGE: ImageFormatOption(
pipeline_options=pipeline_options,
),
}
)
def _init_parser(self) -> Dict:
"""Initialize the docling converter with hybrid OCR."""
logger.info("Initializing DoclingParser...")
logger.info(f" ocr_enabled={self.ocr_enabled}")
logger.info(f" force_full_page_ocr={self.force_full_page_ocr}")
logger.info(f" use_rapidocr={self.use_rapidocr}")
if importlib.util.find_spec("docling.document_converter") is None:
raise ImportError(
"docling is required for DoclingParser. "
"Install it with: pip install docling"
)
# Create converter with hybrid OCR (smart: text direct, bitmaps OCR'd)
self._converter = self._create_converter()
logger.info("DoclingParser initialized successfully")
return {
"ocr_enabled": self.ocr_enabled,
"table_structure": self.table_structure,
"export_format": self.export_format,
"use_rapidocr": self.use_rapidocr,
"ocr_languages": self.ocr_languages,
"force_full_page_ocr": self.force_full_page_ocr,
}
def _get_ocr_options(self):
"""Get OCR options based on configuration.
Returns RapidOcrOptions if use_rapidocr is True and available,
otherwise returns None to use docling defaults.
"""
if not self.use_rapidocr:
return None
try:
from docling.datamodel.pipeline_options import RapidOcrOptions
return RapidOcrOptions(
lang=self.ocr_languages,
force_full_page_ocr=self.force_full_page_ocr,
)
except ImportError as e:
logger.warning(f"Failed to import RapidOcrOptions: {e}")
return None
except Exception as e:
logger.error(f"Error creating RapidOcrOptions: {e}")
return None
def _export_content(self, document) -> str:
"""Export document content in the configured format.
Handles edge case where text is nested under picture elements (e.g., OCR'd
images). If the standard export returns minimal content but document.texts
contains extracted text, falls back to direct text extraction.
"""
if self.export_format == "markdown":
content = document.export_to_markdown()
elif self.export_format == "html":
content = document.export_to_html()
else:
content = document.export_to_text()
# Handle case where text is nested under pictures (common with OCR'd images)
# Standard exports may return just "<!-- image -->" while actual text exists
stripped_content = content.strip()
is_minimal = len(stripped_content) < 50 or stripped_content == "<!-- image -->"
if is_minimal and hasattr(document, "texts") and document.texts:
# Extract text directly from document.texts
extracted_texts = [t.text for t in document.texts if t.text]
if extracted_texts:
logger.info(
f"Standard export minimal ({len(stripped_content)} chars), "
f"extracting {len(extracted_texts)} texts directly"
)
return "\n\n".join(extracted_texts)
return content
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file using docling with hybrid OCR.
Uses smart OCR approach where the layout model detects text vs bitmap
regions. Text is extracted directly, bitmaps are OCR'd only when needed.
Args:
file: Path to the file to parse
errors: Error handling mode (ignored, docling handles internally)
Returns:
Parsed document content as markdown string
"""
logger.info(f"parse_file called for: {file}")
if self._converter is None:
self._init_parser()
try:
logger.info(f"Converting file with hybrid OCR: {file}")
result = self._converter.convert(str(file))
content = self._export_content(result.document)
logger.info(f"Parse complete, content length: {len(content)} chars")
return content
except Exception as e:
logger.error(f"Error parsing file with docling: {e}", exc_info=True)
if errors == "ignore":
return f"[Error parsing file with docling: {str(e)}]"
raise
class DoclingPDFParser(DoclingParser):
"""Docling-based PDF parser with advanced features and RapidOCR support.
Uses hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
Set force_full_page_ocr=True only for fully scanned documents.
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
super().__init__(
ocr_enabled=ocr_enabled,
table_structure=table_structure,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingDocxParser(DoclingParser):
"""Docling-based DOCX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingPPTXParser(DoclingParser):
"""Docling-based PPTX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXLSXParser(DoclingParser):
"""Docling-based XLSX parser with table structure."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingHTMLParser(DoclingParser):
"""Docling-based HTML parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingImageParser(DoclingParser):
"""Docling-based image parser with OCR and RapidOCR support.
For images, force_full_page_ocr=True is used since images are entirely
visual and require full OCR to extract any text.
"""
def __init__(
self,
ocr_enabled: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = True,
):
super().__init__(
ocr_enabled=ocr_enabled,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingCSVParser(DoclingParser):
"""Docling-based CSV parser."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingMarkdownParser(DoclingParser):
"""Docling-based Markdown parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingAsciiDocParser(DoclingParser):
"""Docling-based AsciiDoc parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingVTTParser(DoclingParser):
"""Docling-based WebVTT (video text tracks) parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXMLParser(DoclingParser):
"""Docling-based XML parser (USPTO, JATS)."""
def __init__(self):
super().__init__(export_format="markdown")

View File

@@ -7,8 +7,8 @@ import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import tiktoken
from application.parser.file.base_parser import BaseParser
from application.utils import num_tokens_from_string
class MarkdownParser(BaseParser):
@@ -38,7 +38,7 @@ class MarkdownParser(BaseParser):
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str],
current_text: str):
"""Append to tups chunk."""
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
num_tokens = num_tokens_from_string(current_text)
if num_tokens > self._max_tokens:
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
for chunk in chunks:

Some files were not shown because too many files have changed in this diff