Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
835182461e chore(deps): bump actions/checkout from 4 to 6
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-20 20:12:53 +00:00
271 changed files with 3541 additions and 35809 deletions

View File

@@ -1,28 +1,9 @@
API_KEY=<LLM api key (for example, open ai key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5

View File

@@ -21,7 +21,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -21,7 +21,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -23,7 +23,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -23,7 +23,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -7,14 +7,11 @@ on:
pull_request:
types: [ opened, synchronize ]
permissions:
contents: read
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Lint with Ruff
uses: chartboost/ruff-action@v1

View File

@@ -1,9 +1,5 @@
name: Run python tests with pytest
on: [push, pull_request]
permissions:
contents: read
jobs:
pytest_and_coverage:
name: Run tests and count coverage
@@ -12,7 +8,7 @@ jobs:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:

View File

@@ -17,7 +17,7 @@ jobs:
steps:
# Step 1: run a standard checkout action
- name: Checkout target repo
uses: actions/checkout@v4
uses: actions/checkout@v6
# Step 2: run the sync action
- name: Sync upstream changes

View File

@@ -9,16 +9,12 @@ on:
- '.vale.ini'
- '.github/styles/**'
permissions:
contents: read
pull-requests: write
jobs:
vale:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Vale linter
uses: errata-ai/vale-action@v2

5
.gitignore vendored
View File

@@ -2,7 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
experiments/
experiments
# C extensions
@@ -147,10 +146,6 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**
frontend/node_modules
frontend/dist
frontend/dist-ssr

View File

@@ -1,6 +1,2 @@
# Allow lines to be as long as 120 characters.
line-length = 120
[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]
line-length = 120

View File

@@ -26,6 +26,13 @@
</div>
<div align="center">
<br>
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
<br>
<br>
</div>
<div align="center">
<br>
@@ -46,11 +53,24 @@
</ul>
## Roadmap
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full api tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] SharePoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -145,17 +165,9 @@ We as members, contributors, and leaders, pledge to make participation in our co
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
## This project is supported by:
<p>This project is supported by:</p>
<p>
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
</a>
</p>
<p>
<a href="https://get.neon.com/docsgpt">
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
</a>
</p>

11
application/.env_sample Normal file
View File

@@ -0,0 +1,11 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -7,7 +7,7 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
@@ -48,12 +48,7 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends \
python3.12 \
libgl1 \
libglib2.0-0 \
poppler-utils \
&& \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*

View File

@@ -1,8 +1,6 @@
import logging
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflow_agent import WorkflowAgent
import logging
logger = logging.getLogger(__name__)
@@ -11,7 +9,6 @@ class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}
@classmethod
@@ -19,4 +16,5 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -34,7 +34,6 @@ class BaseAgent(ABC):
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -65,9 +64,6 @@ class BaseAgent(ABC):
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
@@ -120,10 +116,10 @@ class BaseAgent(ABC):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
if key != "filled_by_llm" and key != "value"
}
if v.get("required", False):
params["required"].append(k)
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
@@ -219,11 +215,7 @@ class BaseAgent(ABC):
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
if param not in call_args and "value" in details:
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
@@ -236,62 +228,31 @@ class BaseAgent(ABC):
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
logger.debug(
print(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
)
@@ -315,176 +276,15 @@ class BaseAgent(ABC):
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Args:
messages: List of message dicts
Returns:
Total token count
"""
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
"""
Truncate text by removing content from the middle, preserving start and end.
Args:
text: Text to truncate
max_tokens: Maximum tokens allowed
Returns:
Truncated text with middle removed if needed
"""
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
# Estimate chars per token (roughly 4 chars per token for English)
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin
if target_chars <= 0:
return ""
# Split: keep 40% from start, 40% from end, remove middle
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
def _build_messages(
self,
system_prompt: str,
query: str,
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
)
messages = [{"role": "system", "content": system_prompt}]
for i in working_history:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
@@ -516,65 +316,7 @@ class BaseAgent(ABC):
messages.append({"role": "user", "content": query})
return messages
def _truncate_history_to_fit(
self,
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
return []
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
if "prompt" in message and "response" in message:
message_tokens += num_tokens_from_string(message["prompt"])
message_tokens += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_str = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
message_tokens += num_tokens_from_string(tool_str)
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message) # Maintain chronological order
else:
break
if len(truncated) < len(history):
logger.info(
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
f"to fit within {max_tokens:,} token budget"
)
return truncated
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
# Pre-flight context validation - fail fast if over limit
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if (

View File

@@ -235,4 +235,4 @@ class ReActAgent(BaseAgent):
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)
return "".join(collected)

View File

@@ -1,323 +0,0 @@
import base64
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, Union
from urllib.parse import quote, urlencode
logger = logging.getLogger(__name__)
class ContentType(str, Enum):
"""Supported content types for request bodies."""
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
TEXT_PLAIN = "text/plain"
XML = "application/xml"
OCTET_STREAM = "application/octet-stream"
class RequestBodySerializer:
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
@staticmethod
def serialize(
body_data: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Union[str, bytes], Dict[str, str]]:
"""
Serialize body data to appropriate format.
Args:
body_data: Dictionary of body parameters
content_type: Content-Type header value
encoding_rules: OpenAPI Encoding Object rules per field
Returns:
Tuple of (serialized_body, updated_headers_dict)
Raises:
ValueError: If serialization fails
"""
if not body_data:
return None, {}
try:
content_type_lower = content_type.lower().split(";")[0].strip()
if content_type_lower == ContentType.JSON:
return RequestBodySerializer._serialize_json(body_data)
elif content_type_lower == ContentType.FORM_URLENCODED:
return RequestBodySerializer._serialize_form_urlencoded(
body_data, encoding_rules
)
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
return RequestBodySerializer._serialize_multipart_form_data(
body_data, encoding_rules
)
elif content_type_lower == ContentType.TEXT_PLAIN:
return RequestBodySerializer._serialize_text_plain(body_data)
elif content_type_lower == ContentType.XML:
return RequestBodySerializer._serialize_xml(body_data)
elif content_type_lower == ContentType.OCTET_STREAM:
return RequestBodySerializer._serialize_octet_stream(body_data)
else:
logger.warning(
f"Unknown content type: {content_type}, treating as JSON"
)
return RequestBodySerializer._serialize_json(body_data)
except Exception as e:
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
raise ValueError(f"Failed to serialize request body: {str(e)}")
@staticmethod
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as JSON per OpenAPI spec."""
try:
serialized = json.dumps(
body_data, separators=(",", ":"), ensure_ascii=False
)
headers = {"Content-Type": ContentType.JSON.value}
return serialized, headers
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
@staticmethod
def _serialize_form_urlencoded(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
encoding_rules = encoding_rules or {}
params = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
style = rule.get("style", "form")
explode = rule.get("explode", style == "form")
content_type = rule.get("contentType", "text/plain")
serialized_value = RequestBodySerializer._serialize_form_value(
value, style, explode, content_type, key
)
if isinstance(serialized_value, list):
for sv in serialized_value:
params.append((key, sv))
else:
params.append((key, serialized_value))
# Use standard urlencode (replaces space with +)
serialized = urlencode(params, safe="")
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
return serialized, headers
@staticmethod
def _serialize_form_value(
value: Any, style: str, explode: bool, content_type: str, key: str
) -> Union[str, list]:
"""Serialize individual form value with encoding rules."""
if isinstance(value, dict):
if content_type == "application/json":
return json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
return RequestBodySerializer._dict_to_xml(value)
else:
if style == "deepObject" and explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
elif explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
else:
pairs = [f"{k},{v}" for k, v in value.items()]
return RequestBodySerializer._percent_encode(",".join(pairs))
elif isinstance(value, (list, tuple)):
if explode:
return [
RequestBodySerializer._percent_encode(str(item)) for item in value
]
else:
return RequestBodySerializer._percent_encode(
",".join(str(v) for v in value)
)
else:
return RequestBodySerializer._percent_encode(str(value))
@staticmethod
def _serialize_multipart_form_data(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[bytes, Dict[str, str]]:
"""
Serialize body as multipart/form-data per RFC7578.
Supports file uploads and encoding rules.
"""
import secrets
encoding_rules = encoding_rules or {}
boundary = f"----DocsGPT{secrets.token_hex(16)}"
parts = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
content_type = rule.get("contentType", "text/plain")
headers_rule = rule.get("headers", {})
part = RequestBodySerializer._create_multipart_part(
key, value, content_type, headers_rule
)
parts.append(part)
body_bytes = f"--{boundary}\r\n".encode("utf-8")
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
}
return body_bytes, headers
@staticmethod
def _create_multipart_part(
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
) -> str:
"""Create a single multipart/form-data part."""
headers = [
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
]
if isinstance(value, bytes):
if content_type == "application/octet-stream":
value_encoded = base64.b64encode(value).decode("utf-8")
else:
value_encoded = value.decode("utf-8", errors="replace")
headers.append(f"Content-Type: {content_type}")
headers.append("Content-Transfer-Encoding: base64")
elif isinstance(value, dict):
if content_type == "application/json":
value_encoded = json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
value_encoded = RequestBodySerializer._dict_to_xml(value)
else:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
elif isinstance(value, str) and content_type != "text/plain":
try:
if content_type == "application/json":
json.loads(value)
value_encoded = value
elif content_type == "application/xml":
value_encoded = value
else:
value_encoded = str(value)
except json.JSONDecodeError:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
else:
value_encoded = str(value)
if content_type != "text/plain":
headers.append(f"Content-Type: {content_type}")
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
return part
@staticmethod
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as plain text."""
if len(body_data) == 1:
value = list(body_data.values())[0]
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
else:
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
@staticmethod
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as XML."""
xml_str = RequestBodySerializer._dict_to_xml(body_data)
return xml_str, {"Content-Type": ContentType.XML.value}
@staticmethod
def _serialize_octet_stream(
body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
"""Serialize body as binary octet stream."""
if isinstance(body_data, bytes):
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
elif isinstance(body_data, str):
return body_data.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
else:
serialized = json.dumps(body_data)
return serialized.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
@staticmethod
def _percent_encode(value: str, safe_chars: str = "") -> str:
"""
Percent-encode per RFC3986.
Args:
value: String to encode
safe_chars: Additional characters to not encode
"""
return quote(value, safe=safe_chars)
@staticmethod
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
"""
Convert dict to simple XML format.
"""
def build_xml(obj: Any, name: str) -> str:
if isinstance(obj, dict):
inner = "".join(build_xml(v, k) for k, v in obj.items())
return f"<{name}>{inner}</{name}>"
elif isinstance(obj, (list, tuple)):
items = "".join(
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
for item in obj
)
return items
else:
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
root = build_xml(data, root_name)
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
@staticmethod
def _escape_xml(value: str) -> str:
"""Escape XML special characters."""
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)

View File

@@ -1,280 +1,72 @@
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 90 # seconds
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {})
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.query_params = config.get("query_params", {})
self.body_content_type = config.get("body_content_type", ContentType.JSON)
self.body_encoding_rules = config.get("body_encoding_rules", {})
def execute_action(self, action_name, **kwargs):
"""Execute an API action with the given arguments."""
return self._make_api_call(
self.url,
self.method,
self.headers,
self.query_params,
kwargs,
self.body_content_type,
self.body_encoding_rules,
self.url, self.method, self.headers, self.query_params, kwargs
)
def _make_api_call(
self,
url: str,
method: str,
headers: Dict[str, str],
query_params: Dict[str, Any],
body: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
Make an API call with proper body serialization and error handling.
Args:
url: API endpoint URL
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
headers: Request headers dict
query_params: URL query parameters
body: Request body as dict
content_type: Content-Type for serialization
encoding_rules: OpenAPI encoding rules
Returns:
Dict with status_code, data, and message
"""
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
k: v for k, v in query_params.items() if k not in path_params_used
}
if remaining_params:
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
body, content_type, encoding_rules
)
request_headers.update(body_headers)
except ValueError as e:
logger.error(f"Body serialization failed: {str(e)}")
return {
"status_code": None,
"message": f"Body serialization error: {str(e)}",
"data": None,
}
else:
serialized_body = None
if "Content-Type" not in request_headers and method not in [
"GET",
"HEAD",
"DELETE",
]:
request_headers["Content-Type"] = ContentType.JSON
logger.debug(
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
data = self._parse_response(response)
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {
"status_code": None,
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
"data": None,
}
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {str(e)}")
return {
"status_code": None,
"message": f"Connection error: {str(e)}",
"data": None,
}
except requests.exceptions.HTTPError as e:
logger.error(f"HTTP error {response.status_code}: {str(e)}")
try:
error_data = response.json()
except (json.JSONDecodeError, ValueError):
error_data = response.text
return {
"status_code": response.status_code,
"message": f"HTTP Error {response.status_code}",
"data": error_data,
}
except requests.exceptions.RequestException as e:
logger.error(f"Request failed: {str(e)}")
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
"data": None,
}
except Exception as e:
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
return {
"status_code": None,
"message": f"Unexpected error: {str(e)}",
"data": None,
}
def _parse_response(self, response: requests.Response) -> Any:
"""
Parse response based on Content-Type header.
Supports: JSON, XML, plain text, binary data.
"""
content_type = response.headers.get("Content-Type", "").lower()
if not response.content:
return None
# JSON response
if "application/json" in content_type:
try:
return response.json()
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON response: {str(e)}")
return response.text
# XML response
elif "application/xml" in content_type or "text/xml" in content_type:
return response.text
# Plain text response
elif "text/plain" in content_type or "text/html" in content_type:
return response.text
# Binary/unknown response
else:
# Try to decode as text first, fall back to base64
try:
return response.text
except (UnicodeDecodeError, AttributeError):
import base64
return base64.b64encode(response.content).decode("utf-8")
def get_actions_metadata(self):
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
return []
def get_config_requirements(self):
"""Return configuration requirements for the tool."""
return {}

View File

@@ -169,8 +169,6 @@ class MCPTool(Tool):
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "stdio":
raise ValueError("STDIO transport is disabled")
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
@@ -456,12 +454,6 @@ class MCPTool(Tool):
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
transport_enum = ["auto", "sse", "http"]
transport_help = {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
}
return {
"server_url": {
"type": "string",
@@ -471,11 +463,14 @@ class MCPTool(Tool):
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": transport_enum,
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
**transport_help,
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
},
"auth_type": {

View File

@@ -38,8 +38,6 @@ class NotesTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -56,8 +54,6 @@ class NotesTool(Tool):
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "view":
return self._get_note()
@@ -129,9 +125,6 @@ class NotesTool(Tool):
"""Return configuration requirements (none for now)."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
@@ -139,22 +132,17 @@ class NotesTool(Tool):
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
result = self.collection.find_one_and_update(
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True,
return_document=True,
upsert=True, # ✅ create if missing
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
@@ -175,13 +163,10 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
result = self.collection.find_one_and_update(
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
@@ -203,21 +188,12 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
result = self.collection.find_one_and_update(
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Text inserted."
def _delete_note(self) -> str:
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
return "No note found to delete."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return "Note deleted."
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from urllib.parse import urlparse
class ReadWebpageTool(Tool):
"""
@@ -31,12 +31,11 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)

View File

@@ -1,342 +0,0 @@
"""
API Specification Parser
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
to API Tool action definitions for use in DocsGPT.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import yaml
logger = logging.getLogger(__name__)
SUPPORTED_METHODS = frozenset(
{"get", "post", "put", "delete", "patch", "head", "options"}
)
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Parse an API specification and convert operations to action definitions.
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
Args:
spec_content: Raw specification content as string
Returns:
Tuple of (metadata dict, list of action dicts)
Raises:
ValueError: If the spec is invalid or uses an unsupported format
"""
spec = _load_spec(spec_content)
_validate_spec(spec)
is_swagger = "swagger" in spec
metadata = _extract_metadata(spec, is_swagger)
actions = _extract_actions(spec, is_swagger)
return metadata, actions
def _load_spec(content: str) -> Dict[str, Any]:
"""Parse spec content from JSON or YAML string."""
content = content.strip()
if not content:
raise ValueError("Empty specification content")
try:
if content.startswith("{"):
return json.loads(content)
return yaml.safe_load(content)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e.msg}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML format: {e}")
def _validate_spec(spec: Dict[str, Any]) -> None:
"""Validate spec version and required fields."""
if not isinstance(spec, dict):
raise ValueError("Specification must be a valid object")
openapi_version = spec.get("openapi", "")
swagger_version = spec.get("swagger", "")
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
raise ValueError(
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
)
if "paths" not in spec or not spec["paths"]:
raise ValueError("No API paths defined in the specification")
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
"""Extract API metadata from specification."""
info = spec.get("info", {})
base_url = _get_base_url(spec, is_swagger)
return {
"title": info.get("title", "Untitled API"),
"description": (info.get("description", "") or "")[:500],
"version": info.get("version", ""),
"base_url": base_url,
}
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
if is_swagger:
schemes = spec.get("schemes", ["https"])
host = spec.get("host", "")
base_path = spec.get("basePath", "")
if host:
scheme = schemes[0] if schemes else "https"
return f"{scheme}://{host}{base_path}".rstrip("/")
return ""
servers = spec.get("servers", [])
if servers and isinstance(servers, list) and servers[0].get("url"):
return servers[0]["url"].rstrip("/")
return ""
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
"""Extract all API operations as action definitions."""
actions = []
paths = spec.get("paths", {})
base_url = _get_base_url(spec, is_swagger)
components = spec.get("components", {})
definitions = spec.get("definitions", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
path_params = path_item.get("parameters", [])
for method in SUPPORTED_METHODS:
operation = path_item.get(method)
if not isinstance(operation, dict):
continue
try:
action = _build_action(
path=path,
method=method,
operation=operation,
path_params=path_params,
base_url=base_url,
components=components,
definitions=definitions,
is_swagger=is_swagger,
)
actions.append(action)
except Exception as e:
logger.warning(
f"Failed to parse operation {method.upper()} {path}: {e}"
)
continue
return actions
def _build_action(
path: str,
method: str,
operation: Dict[str, Any],
path_params: List[Dict],
base_url: str,
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Dict[str, Any]:
"""Build a single action from an API operation."""
action_name = _generate_action_name(operation, method, path)
full_url = f"{base_url}{path}" if base_url else path
all_params = path_params + operation.get("parameters", [])
query_params, headers = _categorize_parameters(all_params, components, definitions)
body, body_content_type = _extract_request_body(
operation, components, definitions, is_swagger
)
description = operation.get("summary", "") or operation.get("description", "")
return {
"name": action_name,
"url": full_url,
"method": method.upper(),
"description": (description or "")[:500],
"query_params": {"type": "object", "properties": query_params},
"headers": {"type": "object", "properties": headers},
"body": {"type": "object", "properties": body},
"body_content_type": body_content_type,
"active": True,
}
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
"""Generate a valid action name from operationId or method+path."""
if operation.get("operationId"):
name = operation["operationId"]
else:
path_slug = re.sub(r"[{}]", "", path)
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
name = f"{method}_{path_slug}"
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
return name[:64]
def _categorize_parameters(
parameters: List[Dict],
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
"""Categorize parameters into query params and headers."""
query_params = {}
headers = {}
for param in parameters:
resolved = _resolve_ref(param, components, definitions)
if not resolved or "name" not in resolved:
continue
location = resolved.get("in", "query")
prop = _param_to_property(resolved)
if location in ("query", "path"):
query_params[resolved["name"]] = prop
elif location == "header":
headers[resolved["name"]] = prop
return query_params, headers
def _param_to_property(param: Dict) -> Dict[str, Any]:
"""Convert an API parameter to an action property definition."""
schema = param.get("schema", {})
param_type = schema.get("type", param.get("type", "string"))
mapped_type = "integer" if param_type in ("integer", "number") else "string"
return {
"type": mapped_type,
"description": (param.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": param.get("required", False),
"required": param.get("required", False),
}
def _extract_request_body(
operation: Dict[str, Any],
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Tuple[Dict, str]:
"""Extract request body schema and content type."""
content_types = [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
"application/xml",
]
if is_swagger:
consumes = operation.get("consumes", [])
body_param = next(
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
)
if not body_param:
return {}, "application/json"
selected_type = consumes[0] if consumes else "application/json"
schema = body_param.get("schema", {})
else:
request_body = operation.get("requestBody", {})
if not request_body:
return {}, "application/json"
request_body = _resolve_ref(request_body, components, definitions)
content = request_body.get("content", {})
selected_type = "application/json"
schema = {}
for ct in content_types:
if ct in content:
selected_type = ct
schema = content[ct].get("schema", {})
break
if not schema and content:
first_type = next(iter(content))
selected_type = first_type
schema = content[first_type].get("schema", {})
properties = _schema_to_properties(schema, components, definitions)
return properties, selected_type
def _schema_to_properties(
schema: Dict,
components: Dict[str, Any],
definitions: Dict[str, Any],
depth: int = 0,
) -> Dict[str, Any]:
"""Convert schema to action body properties (limited depth to prevent recursion)."""
if depth > 3:
return {}
schema = _resolve_ref(schema, components, definitions)
if not schema or not isinstance(schema, dict):
return {}
properties = {}
schema_type = schema.get("type", "object")
if schema_type == "object":
required_fields = set(schema.get("required", []))
for prop_name, prop_schema in schema.get("properties", {}).items():
resolved = _resolve_ref(prop_schema, components, definitions)
if not isinstance(resolved, dict):
continue
prop_type = resolved.get("type", "string")
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
properties[prop_name] = {
"type": mapped_type,
"description": (resolved.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": prop_name in required_fields,
"required": prop_name in required_fields,
}
return properties
def _resolve_ref(
obj: Any,
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Optional[Dict]:
"""Resolve $ref references in the specification."""
if not isinstance(obj, dict):
return obj if isinstance(obj, dict) else None
if "$ref" not in obj:
return obj
ref_path = obj["$ref"]
if ref_path.startswith("#/components/"):
parts = ref_path.replace("#/components/", "").split("/")
return _traverse_path(components, parts)
elif ref_path.startswith("#/definitions/"):
parts = ref_path.replace("#/definitions/", "").split("/")
return _traverse_path(definitions, parts)
logger.debug(f"Unsupported ref path: {ref_path}")
return None
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
"""Traverse a nested dictionary using path parts."""
try:
for part in parts:
obj = obj[part]
return obj if isinstance(obj, dict) else None
except (KeyError, TypeError):
return None

View File

@@ -38,8 +38,6 @@ class TodoListTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -56,8 +54,6 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "list":
return self._list()
@@ -169,9 +165,6 @@ class TodoListTool(Tool):
"""Return configuration requirements."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers
# -----------------------------
@@ -197,8 +190,11 @@ class TodoListTool(Tool):
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
# Find the maximum todo_id
max_id = 0
@@ -211,8 +207,8 @@ class TodoListTool(Tool):
def _list(self) -> str:
"""List all todos for the user."""
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
if not todos:
return "No todos found."
@@ -246,10 +242,7 @@ class TodoListTool(Tool):
"created_at": now,
"updated_at": now,
}
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
self.collection.insert_one(doc)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -258,15 +251,15 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
@@ -284,16 +277,13 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} updated to: {title}"
@@ -303,16 +293,13 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} marked as completed."
@@ -322,12 +309,13 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -1,218 +0,0 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
else:
serialized[key] = str(value)
return serialized

View File

@@ -1,109 +0,0 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)

View File

@@ -1,215 +0,0 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}

View File

@@ -1,276 +0,0 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": str(e),
}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
edges = self.graph.get_outgoing_edges(current_node_id)
if edges:
return edges[0].target_id
return None
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
node_config = AgentNodeConfig(**node.config)
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_llm_name = node_config.llm_name or self.agent.llm_name
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_config.model_id or self.agent.model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_config.json_schema,
)
full_response = ""
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
full_response += event["answer"]
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
operations = config.get("operations", [])
if operations:
for op in operations:
key = op.get("key")
operation = op.get("operation", "set")
value = op.get("value")
if not key:
continue
if operation == "set":
formatted_value = (
self._format_template(str(value))
if isinstance(value, str)
else value
)
self.state[key] = formatted_value
elif operation == "increment":
current = self.state.get(key, 0)
try:
self.state[key] = int(current) + int(value or 1)
except (ValueError, TypeError):
self.state[key] = 1
elif operation == "append":
if key not in self.state:
self.state[key] = []
if isinstance(self.state[key], list):
self.state[key].append(value)
else:
updates = config.get("updates", {})
if not updates:
var_name = config.get("variable")
var_value = config.get("value")
if var_name and isinstance(var_name, str):
updates = {var_name: var_value or ""}
if isinstance(updates, dict):
for key, value in updates.items():
if isinstance(value, str):
self.state[key] = self._format_template(value)
else:
self.state[key] = value
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _format_template(self, template: str) -> str:
formatted = template
for key, value in self.state.items():
placeholder = f"{{{{{key}}}}}"
if placeholder in formatted and value is not None:
formatted = formatted.replace(placeholder, str(value))
return formatted
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]

View File

@@ -3,7 +3,6 @@ from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource
@@ -15,7 +14,6 @@ api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
api.add_resource(SearchResource, "/api/search")
init_answer_routes()

View File

@@ -40,6 +40,7 @@ class AnswerResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
@@ -137,5 +138,5 @@ class AnswerResource(Resource, BaseAnswerResource):
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": "An error occurred processing your request"}, 500)
return make_response({"error": str(e)}, 500)
return make_response(result, 200)

View File

@@ -266,26 +266,6 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
@@ -348,25 +328,6 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True

View File

@@ -1,186 +0,0 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
search_model = answer_ns.model(
"SearchModel",
{
"question": fields.String(
required=True, description="Search query"
),
"api_key": fields.String(
required=True, description="API key for authentication"
),
"chunks": fields.Integer(
required=False, default=5, description="Number of results to return"
),
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if sources and isinstance(sources, list) and len(sources) > 0:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
question = data.get("question")
api_key = data.get("api_key")
chunks = data.get("chunks", 5)
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response({"error": "Search failed"}, 500)

View File

@@ -40,6 +40,7 @@ class StreamResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
@@ -80,12 +81,6 @@ class StreamResource(Resource, BaseAnswerResource):
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()

View File

@@ -1,20 +0,0 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -1,234 +0,0 @@
"""Message reconstruction utilities for compression."""
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages

View File

@@ -1,232 +0,0 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))

View File

@@ -1,149 +0,0 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -1,306 +0,0 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Service for compressing conversation history.
Handles DB updates.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)

View File

@@ -1,103 +0,0 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: float = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False

View File

@@ -1,103 +0,0 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0

View File

@@ -1,83 +0,0 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]

View File

@@ -62,8 +62,6 @@ class ConversationService:
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
@@ -150,12 +148,9 @@ class ConversationService:
]
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=500
model=model_id, messages=messages_summary, max_tokens=30
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
conversation_data = {
"user": user_id,
"date": current_time,
@@ -185,103 +180,3 @@ class ConversationService:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None

View File

@@ -10,8 +10,6 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
@@ -92,21 +90,17 @@ class StreamProcessor:
self.shared_token = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
)
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._validate_and_set_model()
self._configure_agent()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
@@ -118,69 +112,14 @@ class StreamProcessor:
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
# Check if compression is enabled and needed
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
[{"content": result.compressed_summary}]
)
logger.info(
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _process_attachments(self):
"""Process any attachments in the request"""
@@ -223,20 +162,11 @@ class StreamProcessor:
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (
f" and {len(available_models) - 5} more"
if len(available_models) > 5
else ""
)
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
else:
self.model_id = get_default_model_id()
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
@@ -309,10 +239,6 @@ class StreamProcessor:
data["sources"] = sources_list
else:
data["sources"] = []
# Preserve model configuration from agent
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
@@ -365,16 +291,12 @@ class StreamProcessor:
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -393,7 +315,6 @@ class StreamProcessor:
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.decoded_token = (
@@ -403,9 +324,6 @@ class StreamProcessor:
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -417,32 +335,26 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
):
agent_type = "workflow"
self.agent_config["workflow"] = self.data["workflow"]
if isinstance(self.decoded_token, dict):
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": agent_type,
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
}
)
def _configure_retriever(self):
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, history_token_limit=history_token_limit
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
"history_token_limit": history_token_limit,
}
api_key = self.data.get("api_key") or self.agent_key
@@ -746,38 +658,17 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,
"retrieved_docs": self.retrieved_docs,
"decoded_token": self.decoded_token,
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
}
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config
elif isinstance(workflow_config, dict):
agent_kwargs["workflow"] = workflow_config
workflow_owner = self.agent_config.get("workflow_owner")
if workflow_owner:
agent_kwargs["workflow_owner"] = workflow_owner
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
return agent
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
)

View File

@@ -1,9 +1,7 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
@@ -37,18 +35,6 @@ connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
# Fixed callback status path to prevent open redirect
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
def build_callback_redirect(params: dict) -> str:
"""Build a safe redirect URL to the callback status page.
Uses a fixed path and properly URL-encodes all parameters
to prevent URL injection and open redirect vulnerabilities.
"""
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
@connectors_ns.route("/api/connectors/auth")
@@ -89,8 +75,8 @@ class ConnectorAuth(Resource):
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/callback")
@@ -107,37 +93,18 @@ class ConnectorsCallback(Resource):
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict.get("provider")
state_object_id = state_dict.get("object_id")
# Validate provider
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
return redirect(build_callback_redirect({
"status": "error",
"message": "Invalid provider"
}))
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
if error:
if error == "access_denied":
return redirect(build_callback_redirect({
"status": "cancelled",
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
"provider": provider
}))
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
if not authorization_code:
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
try:
auth = ConnectorCreator.create_auth(provider)
@@ -146,19 +113,20 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = auth.sanitize_token_info(token_info)
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
@@ -173,39 +141,26 @@ class ConnectorsCallback(Resource):
)
# Redirect to success page with session token and user email
return redirect(build_callback_redirect({
"status": "success",
"message": "Authentication successful",
"provider": provider,
"session_token": session_token,
"user_email": user_email
}))
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
}))
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -213,8 +168,11 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -227,12 +185,15 @@ class ConnectorFiles(Resource):
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
k: v for k, v in data.items() if k not in generic_keys
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
input_config['list_only'] = True
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
@@ -267,8 +228,8 @@ class ConnectorFiles(Resource):
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
@@ -299,7 +260,12 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
@@ -316,21 +282,15 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/disconnect")
@@ -351,8 +311,8 @@ class ConnectorDisconnect(Resource):
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/sync")
@@ -458,8 +418,8 @@ class ConnectorSync(Resource):
return make_response(
jsonify({
"success": False,
"error": "Failed to sync connector source"
}),
"error": str(err)
}),
400
)
@@ -470,32 +430,17 @@ class ConnectorCallbackStatus(Resource):
def get(self):
"""Return HTML page with connector authentication status"""
try:
# Validate and sanitize status to a known value
status_raw = request.args.get('status', 'error')
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
# Escape all user-controlled values for HTML context
message = html.escape(request.args.get('message', ''))
provider_raw = request.args.get('provider', 'connector')
provider = html.escape(provider_raw.replace('_', ' ').title())
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
session_token = request.args.get('session_token', '')
user_email = html.escape(request.args.get('user_email', ''))
def safe_js_string(value: str) -> str:
"""Safely encode a string for embedding in inline JavaScript."""
js_encoded = json.dumps(value)
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
js_status = safe_js_string(status)
js_session_token = safe_js_string(session_token)
js_user_email = safe_js_string(user_email)
js_provider_type = safe_js_string(provider_raw)
user_email = request.args.get('user_email', '')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider} Authentication</title>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
@@ -505,14 +450,13 @@ class ConnectorCallbackStatus(Resource):
</style>
<script>
window.onload = function() {{
const status = {js_status};
const sessionToken = {js_session_token};
const userEmail = {js_user_email};
const providerType = {js_provider_type};
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: providerType + '_auth_success',
type: '{provider}_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
@@ -526,17 +470,17 @@ class ConnectorCallbackStatus(Resource):
</head>
<body>
<div class="container">
<h2>{provider} Authentication</h2>
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")

View File

@@ -1,7 +1,7 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory, jsonify
from flask import Blueprint, request, send_from_directory
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
@@ -24,16 +24,6 @@ current_dir = os.path.dirname(
internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])
def download_file():
user = secure_filename(request.args.get("user"))
@@ -61,7 +51,6 @@ def upload_index_files():
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
@@ -71,14 +60,6 @@ def upload_index_files():
directory_structure = {}
else:
directory_structure = {}
if file_name_map:
try:
file_name_map = json.loads(file_name_map)
except Exception:
logger.error("Error parsing file_name_map")
file_name_map = None
else:
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -106,43 +87,41 @@ def upload_index_files():
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{"$set": update_fields},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
)
else:
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
return {"status": "ok"}

View File

@@ -3,6 +3,5 @@
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]

View File

@@ -1,261 +0,0 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Create a new folder")
@api.expect(
api.model(
"CreateFolder",
{
"name": fields.String(required=True, description="Folder name"),
"parent_id": fields.String(required=False, description="Parent folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
try:
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
@api.doc(description="Get a specific folder with its agents")
def get(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
return make_response(
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
200,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Update a folder")
def put(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
@api.doc(description="Move an agent to a folder or remove from folder")
@api.expect(
api.model(
"MoveAgent",
{
"agent_id": fields.String(required=True, description="Agent ID to move"),
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
@api.doc(description="Move multiple agents to a folder")
@api.expect(
api.model(
"BulkMoveAgents",
{
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
"folder_id": fields.String(required=False, description="Target folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_ids"):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id = data.get("folder_id")
try:
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)

View File

@@ -11,7 +11,6 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
db,
ensure_user_doc,
@@ -19,9 +18,6 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.settings import settings
from application.utils import (
@@ -34,189 +30,6 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -254,7 +67,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent.get("chunks", "2"),
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -284,8 +97,6 @@ class GetAgent(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -335,7 +146,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent.get("chunks", "2"),
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -365,13 +176,9 @@ class GetAgents(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -399,22 +206,16 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -441,9 +242,6 @@ class CreateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -509,10 +307,9 @@ class CreateAgent(Resource):
400,
)
except Exception as e:
current_app.logger.error(f"Invalid JSON schema: {e}")
return make_response(
jsonify(
{"success": False, "message": "Invalid JSON schema format"}
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
400,
)
@@ -526,34 +323,18 @@ class CreateAgent(Resource):
),
400,
)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
if not data.get("source") and not data.get("sources"):
return make_response(
jsonify(
{
@@ -563,9 +344,10 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
required_fields = ["name"]
validate_fields = []
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -577,50 +359,74 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify({"success": False, "message": "Invalid folder ID format"}),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"limited_token_mode": (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
),
"token_limit": int(
data.get(
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
),
"limited_request_mode": (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
),
"request_limit": int(
data.get(
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -649,22 +455,16 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -691,9 +491,6 @@ class UpdateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -787,8 +584,6 @@ class UpdateAgent(Resource):
"request_limit",
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -945,10 +740,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -973,42 +768,6 @@ class UpdateAgent(Resource):
),
400,
)
elif field == "folder_id":
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid folder ID format",
}
),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -1037,82 +796,46 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
if agent_type == "workflow":
required_published_fields = {
"name": "Agent name",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
if not workflow_id:
missing_published_fields.append("Workflow")
elif not ObjectId.is_valid(workflow_id):
missing_published_fields.append("Valid workflow")
else:
workflow = workflows_collection.find_one(
{"_id": ObjectId(workflow_id), "user": user}
)
if not workflow:
missing_published_fields.append("Workflow access")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
else:
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -1184,29 +907,6 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1315,16 +1015,19 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]
@@ -1434,4 +1137,4 @@ class RemoveSharedAgent(Resource):
current_app.logger.error(f"Error removing shared agent: {err}")
return make_response(
jsonify({"success": False, "message": "Server error"}), 500
)
)

View File

@@ -255,8 +255,8 @@ class ShareAgent(Resource):
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200

View File

@@ -99,8 +99,11 @@ class StoreAttachment(Resource):
})
if not tasks:
error_msg = "No valid files to upload"
if errors:
error_msg += f". Errors: {errors}"
return make_response(
jsonify({"status": "error", "message": "No valid files to upload"}),
jsonify({"status": "error", "message": error_msg, "errors": errors}),
400,
)
@@ -132,7 +135,7 @@ class StoreAttachment(Resource):
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@attachments_ns.route("/images/<path:image_path>")

View File

@@ -31,17 +31,12 @@ sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -51,25 +46,6 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(

View File

@@ -5,7 +5,8 @@ Main user API routes - registers all namespace modules.
from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
@@ -14,7 +15,6 @@ from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -31,11 +31,10 @@ api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks, folders)
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
api.add_namespace(agents_webhooks_ns)
api.add_namespace(agents_folders_ns)
# Prompts
api.add_namespace(prompts_ns)
@@ -51,6 +50,3 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -220,23 +220,8 @@ class GetPubliclySharedConversations(Resource):
shared
and "conversation_id" in shared
):
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
# conversation_id is now stored as an ObjectId, not a DBRef
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)

View File

@@ -55,14 +55,9 @@ class GetChunks(Resource):
if path:
chunk_source = metadata.get("source", "")
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
# Check if the chunk's source matches the requested path
if not (source_match or file_path_match):
if not chunk_source or not chunk_source.endswith(path):
continue
# Filter by search term if provided

View File

@@ -9,7 +9,6 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
@@ -21,21 +20,6 @@ sources_ns = Namespace(
)
def _get_provider_from_remote_data(remote_data):
if not remote_data:
return None
if isinstance(remote_data, dict):
return remote_data.get("provider")
if isinstance(remote_data, str):
try:
remote_data_obj = json.loads(remote_data)
except Exception:
return None
if isinstance(remote_data_obj, dict):
return remote_data_obj.get("provider")
return None
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
@@ -57,7 +41,6 @@ class CombinedJson(Resource):
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["_id"]),
@@ -68,7 +51,6 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
@@ -125,7 +107,6 @@ class PaginatedSources(Resource):
paginated_docs = []
for doc in documents:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
@@ -135,7 +116,6 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
@@ -260,7 +240,7 @@ class ManageSync(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
data = request.get_json()
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
@@ -289,72 +269,6 @@ class ManageSync(Resource):
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/sync_source")
class SyncSource(Resource):
sync_source_model = api.model(
"SyncSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(sync_source_model)
@api.doc(description="Trigger an immediate sync for a source")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
if not ObjectId.is_valid(source_id):
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type.startswith("connector"):
return make_response(
jsonify(
{
"success": False,
"message": "Connector sources must be synced via /api/connectors/sync",
}
),
400,
)
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
)
try:
task = sync_source.delay(
source_data=source_data,
job_name=doc.get("name", ""),
user=user,
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=source_id,
)
except Exception as err:
current_app.logger.error(
f"Error starting sync for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
@@ -406,4 +320,4 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
return make_response(jsonify({"success": False, "error": str(e)}), 500)

View File

@@ -64,27 +64,19 @@ class UploadFile(Resource):
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
file_name_map = {}
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = os.path.basename(file.filename)
original_filename = file.filename
safe_file = safe_filename(original_filename)
if original_filename:
file_name_map[safe_file] = original_filename
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
is_office_format = safe_file.lower().endswith(
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
)
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
if zipfile.is_zipfile(temp_file_path):
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
@@ -145,7 +137,6 @@ class UploadFile(Resource):
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -191,8 +182,6 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] == "s3":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
@@ -345,14 +334,6 @@ class ManageSourceFiles(Resource):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
file_name_map = source.get("file_name_map") or {}
if isinstance(file_name_map, str):
try:
file_name_map = json.loads(file_name_map)
except Exception:
file_name_map = {}
if not isinstance(file_name_map, dict):
file_name_map = {}
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
@@ -374,35 +355,19 @@ class ManageSourceFiles(Resource):
400,
)
added_files = []
map_updated = False
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
original_filename = os.path.basename(file.filename)
safe_filename_str = safe_filename(original_filename)
safe_filename_str = safe_filename(file.filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
if original_filename:
relative_key = (
f"{parent_dir}/{safe_filename_str}"
if parent_dir
else safe_filename_str
)
file_name_map[relative_key] = original_filename
map_updated = True
if map_updated:
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -449,7 +414,6 @@ class ManageSourceFiles(Resource):
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
@@ -458,15 +422,6 @@ class ManageSourceFiles(Resource):
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
if file_path in file_name_map:
file_name_map.pop(file_path, None)
map_updated = True
if map_updated and isinstance(file_name_map, dict):
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -549,20 +504,6 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
if directory_path and file_name_map:
prefix = f"{directory_path.rstrip('/')}/"
keys_to_remove = [
key
for key in file_name_map.keys()
if key == directory_path or key.startswith(prefix)
]
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
@@ -633,9 +574,8 @@ class TaskStatus(Resource):
):
task_meta = str(task_meta) # Convert to a string representation
except ConnectionError as err:
current_app.logger.error(f"Connection error getting task status: {err}")
return make_response(
jsonify({"success": False, "message": "Service unavailable"}), 503
jsonify({"success": False, "message": str(err)}), 503
)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)

View File

@@ -8,25 +8,13 @@ from application.worker import (
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
@celery.task(bind=True)
def ingest(
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
directory,
formats,
job_name,
file_path,
filename,
user,
file_name_map=file_name_map,
)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
return resp
@@ -50,30 +38,6 @@ def schedule_syncs(self, frequency):
return resp
@celery.task(bind=True)
def sync_source(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
):
resp = sync(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
)
return resp
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)

View File

@@ -1,7 +1,7 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import unquote, urlencode
from email.quoprimime import unquote
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
@@ -43,16 +43,6 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
auth_type = config.get("auth_type", "none")
@@ -74,17 +64,12 @@ class TestMCPServerConfig(Resource):
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
# Sanitize the response to avoid exposing internal error details
if not result.get("success") and "message" in result:
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
result["message"] = "Connection test failed"
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": "Connection test failed"}
{"success": False, "error": f"Connection test failed: {str(e)}"}
),
500,
)
@@ -125,16 +110,6 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
auth_type = config.get("auth_type", "none")
@@ -259,7 +234,7 @@ class MCPServerSave(Resource):
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": "Failed to save MCP server"}
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
),
500,
)
@@ -288,12 +263,9 @@ class MCPOAuthCallback(Resource):
error = request.args.get("error")
if error:
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool"
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
@@ -320,7 +292,7 @@ class MCPOAuthCallback(Resource):
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@@ -354,8 +326,8 @@ class MCPOAuthStatus(Resource):
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)

View File

@@ -4,7 +4,6 @@ from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
@@ -415,136 +414,3 @@ class DeleteTool(Resource):
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
@api.doc(
description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if "file" in request.files:
file = request.files["file"]
if not file.filename:
return make_response(
jsonify({"success": False, "message": "No file selected"}), 400
)
try:
spec_content = file.read().decode("utf-8")
except UnicodeDecodeError:
return make_response(
jsonify({"success": False, "message": "Invalid file encoding"}), 400
)
elif request.is_json:
data = request.get_json()
spec_content = data.get("spec_content", "")
else:
return make_response(
jsonify({"success": False, "message": "No spec provided"}), 400
)
if not spec_content or not spec_content.strip():
return make_response(
jsonify({"success": False, "message": "Empty spec content"}), 400
)
try:
metadata, actions = parse_spec(spec_content)
return make_response(
jsonify(
{
"success": True,
"metadata": metadata,
"actions": actions,
}
),
200,
)
except ValueError as e:
current_app.logger.error(f"Spec validation error: {e}")
return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
except Exception as err:
current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
@api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
def get(self, artifact_id: str):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
# Return all todos for the tool
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404
)

View File

@@ -1,378 +0,0 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order

View File

@@ -1,3 +0,0 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -1,353 +0,0 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()

View File

@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
options={"verify_exp": False},
)
return decoded_token
except Exception:
except Exception as e:
return {
"message": "Authentication error: invalid token",
"message": f"Authentication error: {str(e)}",
"error": "invalid_token",
}
else:

View File

@@ -21,4 +21,3 @@ def config_loggers(*args, **kwargs):
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -6,6 +6,3 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)

View File

@@ -8,8 +8,8 @@ from application.core.model_settings import (
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
OPENAI_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
@@ -17,42 +17,75 @@ IMAGE_ATTACHMENTS = [
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
id="gpt-4o",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
display_name="GPT-4 Omni",
description="Latest and most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
context_window=128000,
),
),
AvailableModel(
id="gpt-5-mini",
id="gpt-4o-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
display_name="GPT-4 Omni Mini",
description="Fast and efficient",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
context_window=128000,
),
)
),
AvailableModel(
id="gpt-4-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Turbo",
description="Fast GPT-4 with 128k context",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="Most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
AvailableModel(
id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-3.5 Turbo",
description="Fast and cost-effective",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=4096,
),
),
]
@@ -64,7 +97,6 @@ ANTHROPIC_MODELS = [
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -75,7 +107,6 @@ ANTHROPIC_MODELS = [
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -86,7 +117,6 @@ ANTHROPIC_MODELS = [
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -97,7 +127,6 @@ ANTHROPIC_MODELS = [
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -130,9 +159,9 @@ GOOGLE_MODELS = [
),
),
AvailableModel(
id="gemini-3-pro-preview",
id="gemini-2.5-pro",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
display_name="Gemini 2.5 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
@@ -156,43 +185,28 @@ GROQ_MODELS = [
),
),
AvailableModel(
id="openai/gpt-oss-120b",
id="llama-3.1-8b-instant",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
display_name="Llama 3.1 8B",
description="Ultra-fast inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
display_name="Mixtral 8x7B",
description="High-speed inference with tools",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=32768,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
@@ -207,18 +221,3 @@ AZURE_OPENAI_MODELS = [
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)

View File

@@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
@@ -85,13 +84,9 @@ class ModelRegistry:
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
@@ -110,69 +105,39 @@ class ModelRegistry:
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
if settings.LLM_NAME and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
elif settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
else:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
from application.core.model_configs import OPENAI_MODELS
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
@@ -229,21 +194,6 @@ class ModelRegistry:
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
@@ -273,15 +223,6 @@ class ModelRegistry:
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)

View File

@@ -9,7 +9,6 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -2,8 +2,7 @@ import os
from pathlib import Path
from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings import BaseSettings
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -11,19 +10,12 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -43,10 +35,8 @@ class Settings(BaseSettings):
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
@@ -65,20 +55,13 @@ class Settings(BaseSettings):
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
@@ -88,8 +71,10 @@ class Settings(BaseSettings):
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
@@ -139,7 +124,7 @@ class Settings(BaseSettings):
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
@@ -159,44 +144,6 @@ class Settings(BaseSettings):
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
# Conversation Compression Settings
ENABLE_CONVERSATION_COMPRESSION: bool = True
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
@field_validator(
"API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",
"ELEVENLABS_API_KEY",
"INTERNAL_KEY",
mode="before",
)
@classmethod
def normalize_api_key(cls, v: Optional[str]) -> Optional[str]:
"""
Normalize API keys: convert 'None', 'none', empty strings,
and whitespace-only strings to actual None.
Handles Pydantic loading 'None' from .env as string "None".
"""
if v is None:
return None
if not isinstance(v, str):
return v
stripped = v.strip()
if stripped == "" or stripped.lower() == "none":
return None
return stripped
# Project root is one level above application/
path = Path(__file__).parent.parent.parent.absolute()
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -1,181 +0,0 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))

View File

@@ -1,13 +1,7 @@
import base64
import logging
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
@@ -26,7 +20,6 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
self.storage = StorageCreator.get_storage()
def _raw_gen(
self,
@@ -77,115 +70,3 @@ class AnthropicLLM(BaseLLM):
finally:
if hasattr(stream_response, "close"):
stream_response.close()
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Anthropic Claude for file uploads.
Claude supports images but not PDFs natively.
PDFs are synthetically supported via PDF-to-image conversion in the handler.
Returns:
list: List of supported MIME types
"""
return [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments for Anthropic Claude API.
Formats images using Claude's vision message format.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with image content for Claude API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the last user message to attach images to
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
# Convert content to list format if it's a string
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith("image/"):
try:
# Check if this is a pre-converted image (from PDF-to-image conversion)
# These have 'data' key with base64 already
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
# Claude uses a specific format for images
prepared_messages[user_message_index]["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
},
}
)
except Exception as e:
logger.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")

View File

@@ -1,19 +1,75 @@
import json
from openai import OpenAI
from application.core.settings import settings
from application.llm.openai import OpenAILLM
from application.llm.base import BaseLLM
DOCSGPT_API_KEY = "sk-docsgpt-public"
DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,
user_api_key=user_api_key,
base_url=DOCSGPT_BASE_URL,
*args,
**kwargs,
)
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = "sk-docsgpt-public"
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
self,
@@ -23,19 +79,23 @@ class DocsGPTAPILLM(OpenAILLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
return super()._raw_gen(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
@@ -45,16 +105,34 @@ class DocsGPTAPILLM(OpenAILLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
return super()._raw_gen_stream(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True

View File

@@ -1,3 +1,4 @@
import json
import logging
from google import genai
@@ -10,13 +11,11 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -34,12 +33,6 @@ class GoogleLLM(BaseLLM):
"image/jpg",
"image/webp",
"image/gif",
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -142,38 +135,12 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
Returns:
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
combined system instruction.
"""
"""Convert OpenAI format messages to Google AI format."""
cleaned_messages = []
system_instructions = []
def _extract_system_text(content):
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and "text" in item and item["text"] is not None:
parts.append(item["text"])
return "\n".join(parts)
return ""
for message in messages:
role = message.get("role")
content = message.get("content")
# Gemini only accepts user/model in the contents list.
if role == "system":
sys_text = _extract_system_text(content)
if sys_text:
system_instructions.append(sys_text)
continue
if role == "assistant":
role = "model"
elif role == "tool":
@@ -192,27 +159,12 @@ class GoogleLLM(BaseLLM):
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=item["thought_signature"],
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
@@ -236,8 +188,7 @@ class GoogleLLM(BaseLLM):
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
return cleaned_messages, system_instruction
return cleaned_messages
def _clean_schema(self, schema_obj):
"""
@@ -323,77 +274,6 @@ class GoogleLLM(BaseLLM):
genai_tools.append(genai_tool)
return genai_tools
def _extract_preview_from_message(self, message):
"""Get a short, human-readable preview from the last message."""
try:
if hasattr(message, "parts"):
for part in reversed(message.parts):
if getattr(part, "text", None):
return part.text
function_call = getattr(part, "function_call", None)
if function_call:
name = getattr(function_call, "name", "") or "function_call"
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = getattr(function_response, "name", "") or "function_response"
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for item in reversed(content):
if isinstance(item, str):
return item
if isinstance(item, dict):
if item.get("text"):
return item["text"]
if item.get("function_call"):
fn = item["function_call"]
if isinstance(fn, dict):
name = fn.get("name") or "function_call"
return f"function_call:{name}"
return "function_call"
if item.get("function_response"):
resp = item["function_response"]
if isinstance(resp, dict):
name = resp.get("name") or "function_response"
return f"function_response:{name}"
return "function_response"
if "text" in message and isinstance(message["text"], str):
return message["text"]
except Exception:
pass
return str(message)
def _summarize_messages_for_log(self, messages, preview_chars=20):
"""Return a compact summary for logging to avoid huge payloads."""
message_count = len(messages) if messages else 0
last_preview = ""
if messages:
last_preview = self._extract_preview_from_message(messages[-1]) or ""
last_preview = str(last_preview).replace("\n", " ")
if len(last_preview) > preview_chars:
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
@staticmethod
def _get_text_value(part):
"""Get text from both SDK objects and dict-shaped test doubles."""
if isinstance(part, dict):
value = part.get("text")
return value if isinstance(value, str) else ""
value = getattr(part, "text", None)
return value if isinstance(value, str) else ""
@staticmethod
def _is_thought_part(part):
"""Detect Gemini thinking parts when available."""
if isinstance(part, dict):
return bool(part.get("thought"))
return bool(getattr(part, "thought", False))
def _raw_gen(
self,
baseself,
@@ -407,12 +287,12 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages, system_instruction = self._clean_messages_google(messages)
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if system_instruction:
config.system_instruction = system_instruction
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
@@ -445,15 +325,16 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages, system_instruction = self._clean_messages_google(messages)
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if system_instruction:
config.system_instruction = system_instruction
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
@@ -468,12 +349,8 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
response = client.models.generate_content_stream(
@@ -490,23 +367,10 @@ class GoogleLLM(BaseLLM):
for part in candidate.content.parts:
if part.function_call:
yield part
continue
part_text = self._get_text_value(part)
if not part_text:
continue
if self._is_thought_part(part):
yield {"type": "thought", "thought": part_text}
else:
yield part_text
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
chunk_text = self._get_text_value(chunk)
if chunk_text:
if self._is_thought_part(chunk):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
yield chunk.text
finally:
if hasattr(response, "close"):
response.close()

View File

@@ -1,15 +1,37 @@
from openai import OpenAI
from application.core.settings import settings
from application.llm.openai import OpenAILLM
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
from application.llm.base import BaseLLM
class GroqLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or GROQ_BASE_URL,
*args,
**kwargs,
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = OpenAI(
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,5 +1,4 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -17,7 +16,6 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -105,7 +103,6 @@ class LLMHandler(ABC):
"""
Prepare messages with attachments and provider-specific formatting.
Args:
agent: The agent instance
messages: Original messages
@@ -119,40 +116,11 @@ class LLMHandler(ABC):
logger.info(f"Preparing messages with {len(attachments)} attachments")
supported_types = agent.llm.get_supported_attachment_types()
# Check if provider supports images but not PDF (synthetic PDF support)
supports_images = any(t.startswith("image/") for t in supported_types)
supports_pdf = "application/pdf" in supported_types
# Process attachments, converting PDFs to images if needed
processed_attachments = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
# Synthetic PDF support: convert PDF to images if LLM supports images but not PDF
if mime_type == "application/pdf" and supports_images and not supports_pdf:
logger.info(
f"Converting PDF to images for synthetic PDF support: {attachment.get('path', 'unknown')}"
)
try:
converted_images = self._convert_pdf_to_images(attachment)
processed_attachments.extend(converted_images)
logger.info(
f"Converted PDF to {len(converted_images)} images"
)
except Exception as e:
logger.error(
f"Failed to convert PDF to images, falling back to text: {e}"
)
# Fall back to treating as unsupported (text extraction)
processed_attachments.append(attachment)
else:
processed_attachments.append(attachment)
supported_attachments = [
a for a in processed_attachments if a.get("mime_type") in supported_types
a for a in attachments if a.get("mime_type") in supported_types
]
unsupported_attachments = [
a for a in processed_attachments if a.get("mime_type") not in supported_types
a for a in attachments if a.get("mime_type") not in supported_types
]
# Process supported attachments with the LLM's custom method
@@ -175,37 +143,6 @@ class LLMHandler(ABC):
)
return messages
def _convert_pdf_to_images(self, attachment: Dict) -> List[Dict]:
"""
Convert a PDF attachment to a list of image attachments.
This enables synthetic PDF support for LLMs that support images but not PDFs.
Args:
attachment: PDF attachment dictionary with 'path' and optional 'content'
Returns:
List of image attachment dictionaries with 'data', 'mime_type', and 'page'
"""
from application.utils import convert_pdf_to_images
from application.storage.storage_creator import StorageCreator
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in PDF attachment")
storage = StorageCreator.get_storage()
# Convert PDF to images
images_data = convert_pdf_to_images(
file_path=file_path,
storage=storage,
max_pages=20,
dpi=150,
)
return images_data
def _append_unsupported_attachments(
self, messages: List[Dict], attachments: List[Dict]
) -> List[Dict]:
@@ -241,406 +178,6 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: system prompt + latest user message only.
Drops all tool/function messages to shrink context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Merge current in-flight messages (including tool calls)
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
if result.metadata:
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0 or queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if (
metadata.compressed_token_count
>= metadata.original_token_count
):
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -658,110 +195,7 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -771,26 +205,21 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [function_call_content],
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -878,9 +307,6 @@ class LLMHandler(ABC):
tool_calls = {}
for chunk in self._iterate_stream(response):
if isinstance(chunk, dict) and chunk.get("type") == "thought":
yield chunk
continue
if isinstance(chunk, str):
yield chunk
continue
@@ -897,13 +323,7 @@ class LLMHandler(ABC):
if call.name:
existing.name = call.name
if call.arguments:
if existing.arguments is None:
existing.arguments = call.arguments
else:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
existing.arguments += call.arguments
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -916,21 +336,8 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
model=agent.model_id, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -19,20 +19,15 @@ class GoogleLLMHandler(LLMHandler):
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = []
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
index=idx,
thought_signature=thought_sig,
)
)
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
content = " ".join(
part.text
@@ -46,17 +41,13 @@ class GoogleLLMHandler(LLMHandler):
raw_response=response,
)
else:
# This branch handles individual Part objects from streaming responses
tool_calls = []
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = response.thought_signature if has_sig else None
if hasattr(response, "function_call"):
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
thought_signature=thought_sig,
)
)
return LLMResponse(

View File

@@ -0,0 +1,68 @@
from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name="Arc53/DocsGPT-7B",
q=False,
*args,
**kwargs,
):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
)
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
llm_name, quantization_config=bnb_config
)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=2000,
device_map="auto",
eos_token_id=tokenizer.eos_token_id,
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = hf(prompt)
return result.content
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

View File

@@ -4,12 +4,12 @@ from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.open_router import OpenRouterLLM
logger = logging.getLogger(__name__)
@@ -19,6 +19,7 @@ class LLMCreator:
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"huggingface": HuggingFaceLLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
@@ -26,7 +27,6 @@ class LLMCreator:
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"openrouter": OpenRouterLLM,
}
@classmethod

View File

@@ -1,15 +1,32 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
from application.llm.base import BaseLLM
from openai import OpenAI
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,
**kwargs,
class NovitaLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,15 +0,0 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or OPEN_ROUTER_BASE_URL,
*args,
**kwargs,
)

View File

@@ -9,57 +9,6 @@ from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
def _truncate_base64_for_logging(messages):
"""
Create a copy of messages with base64 data truncated for readable logging.
Args:
messages: List of message dicts
Returns:
Copy of messages with truncated base64 content
"""
import copy
def truncate_content(content):
if isinstance(content, str):
# Check if it looks like a data URL with base64
if content.startswith("data:") and ";base64," in content:
prefix_end = content.index(";base64,") + len(";base64,")
prefix = content[:prefix_end]
return f"{prefix}[BASE64_DATA_TRUNCATED, length={len(content) - prefix_end}]"
return content
elif isinstance(content, list):
return [truncate_item(item) for item in content]
elif isinstance(content, dict):
return {k: truncate_content(v) for k, v in content.items()}
return content
def truncate_item(item):
if isinstance(item, dict):
result = {}
for k, v in item.items():
if k == "url" and isinstance(v, str) and ";base64," in v:
prefix_end = v.index(";base64,") + len(";base64,")
prefix = v[:prefix_end]
result[k] = f"{prefix}[BASE64_DATA_TRUNCATED, length={len(v) - prefix_end}]"
elif k == "data" and isinstance(v, str) and len(v) > 100:
result[k] = f"[BASE64_DATA_TRUNCATED, length={len(v)}]"
else:
result[k] = truncate_content(v)
return result
return truncate_content(item)
truncated = []
for msg in messages:
msg_copy = copy.copy(msg)
if "content" in msg_copy:
msg_copy["content"] = truncate_content(msg_copy["content"])
truncated.append(msg_copy)
return truncated
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
@@ -95,12 +44,12 @@ class OpenAILLM(BaseLLM):
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
if "function_call" in item:
# Function calls need their own message
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
@@ -120,7 +69,6 @@ class OpenAILLM(BaseLLM):
}
)
elif "function_response" in item:
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
@@ -133,69 +81,40 @@ class OpenAILLM(BaseLLM):
}
)
elif isinstance(item, dict):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts = []
if "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
content_parts.append(item)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
@staticmethod
def _normalize_reasoning_value(value):
"""Normalize reasoning payloads from OpenAI-compatible stream chunks."""
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
return "".join(
OpenAILLM._normalize_reasoning_value(item) for item in value
)
if isinstance(value, dict):
for key in ("text", "content", "value", "reasoning_content", "reasoning"):
normalized = OpenAILLM._normalize_reasoning_value(value.get(key))
if normalized:
return normalized
return ""
for attr in ("text", "content", "value"):
if hasattr(value, attr):
normalized = OpenAILLM._normalize_reasoning_value(getattr(value, attr))
if normalized:
return normalized
return ""
@classmethod
def _extract_reasoning_text(cls, delta):
"""Extract reasoning/thinking tokens from OpenAI-compatible delta chunks."""
if delta is None:
return ""
for key in (
"reasoning_content",
"reasoning",
"thinking",
"thinking_content",
):
value = getattr(delta, key, None)
if value is None and isinstance(delta, dict):
value = delta.get(key)
normalized = cls._normalize_reasoning_value(value)
if normalized:
return normalized
return ""
def _raw_gen(
self,
baseself,
@@ -208,11 +127,6 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
@@ -226,7 +140,7 @@ class OpenAILLM(BaseLLM):
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
logging.info(f"OpenAI response: {response}")
if tools:
return response.choices[0]
else:
@@ -244,11 +158,6 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
@@ -265,27 +174,14 @@ class OpenAILLM(BaseLLM):
try:
for line in response:
logging.debug(f"OpenAI stream line: {line}")
if not getattr(line, "choices", None):
continue
choice = line.choices[0]
delta = getattr(choice, "delta", None)
reasoning_text = self._extract_reasoning_text(delta)
if reasoning_text:
yield {"type": "thought", "thought": reasoning_text}
content = getattr(delta, "content", None)
if isinstance(content, str) and content:
yield content
continue
has_tool_calls = bool(getattr(delta, "tool_calls", None))
finish_reason = getattr(choice, "finish_reason", None)
# Yield non-content chunks only when needed for tool-call handling.
if has_tool_calls or finish_reason == "tool_calls":
yield choice
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
@@ -354,14 +250,17 @@ class OpenAILLM(BaseLLM):
"""
Return a list of MIME types supported by OpenAI for file uploads.
This reads from the model config to ensure consistency.
If no model config found, falls back to images only (safest default).
Returns:
list: List of supported MIME types
"""
from application.core.model_configs import OPENAI_ATTACHMENTS
return OPENAI_ATTACHMENTS
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
@@ -398,16 +297,10 @@ class OpenAILLM(BaseLLM):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
logging.info(f"Processing attachment with mime_type: {mime_type}, has_data: {'data' in attachment}, has_path: {'path' in attachment}")
if mime_type and mime_type.startswith("image/"):
try:
# Check if this is a pre-converted image (from PDF-to-image conversion)
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
@@ -416,7 +309,6 @@ class OpenAILLM(BaseLLM):
},
}
)
except Exception as e:
logging.error(
f"Error processing image attachment: {e}", exc_info=True
@@ -431,7 +323,6 @@ class OpenAILLM(BaseLLM):
# Handle PDFs using the file API
elif mime_type == "application/pdf":
logging.info(f"Attempting to upload PDF to OpenAI: {attachment.get('path', 'unknown')}")
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append(
@@ -446,8 +337,6 @@ class OpenAILLM(BaseLLM):
"text": f"File content:\n\n{attachment['content']}",
}
)
else:
logging.warning(f"Unsupported attachment type in OpenAI provider: {mime_type}")
return prepared_messages
def _get_base64_image(self, attachment):

View File

@@ -62,26 +62,15 @@ class BaseConnectorAuth(ABC):
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
"""Extract the fields safe to persist in the session store.
"""
return {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry"),
**extra_fields,
}
class BaseConnectorLoader(ABC):
"""

View File

@@ -1,7 +1,5 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
class ConnectorCreator:
@@ -14,12 +12,10 @@ class ConnectorCreator:
connectors = {
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}
@classmethod

View File

@@ -232,6 +232,10 @@ class GoogleDriveAuth(BaseConnectorAuth):
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'

View File

@@ -327,10 +327,15 @@ class GoogleDriveLoader(BaseConnectorLoader):
content_bytes = file_io.getvalue()
try:
return content_bytes.decode('utf-8')
content = content_bytes.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")

View File

@@ -1,10 +0,0 @@
"""
Share Point connector package for DocsGPT.
This module provides authentication and document loading capabilities for Share Point.
"""
from .auth import SharePointAuth
from .loader import SharePointLoader
__all__ = ['SharePointAuth', 'SharePointLoader']

View File

@@ -1,152 +0,0 @@
import datetime
import logging
from typing import Optional, Dict, Any
from msal import ConfidentialClientApplication
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
logger = logging.getLogger(__name__)
class SharePointAuth(BaseConnectorAuth):
"""
Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.
Note: Files.Read scope allows access to files the user has granted access to,
similar to Google Drive's drive.file scope.
"""
SCOPES = [
"Files.Read",
"Sites.Read.All",
"User.Read",
]
def __init__(self):
self.client_id = settings.MICROSOFT_CLIENT_ID
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
if not self.client_id:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
)
if not self.client_secret:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
)
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
self.tenant_id = settings.MICROSOFT_TENANT_ID
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")
self.auth_app = ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority
)
def get_authorization_url(self, state: Optional[str] = None) -> str:
return self.auth_app.get_authorization_request_url(
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
)
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_authorization_code(
code=authorization_code,
scopes=self.SCOPES,
redirect_uri=self.redirect_uri
)
if "error" in result:
logger.error("Token exchange failed: %s", result.get("error_description"))
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
return self.map_token_response(result)
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
if "error" in result:
logger.error("Token refresh failed: %s", result.get("error_description"))
raise ValueError(f"Error refreshing token: {result.get('error_description')}")
return self.map_token_response(result)
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sessions_collection = db["connector_sessions"]
session = sessions_collection.find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
if "token_info" not in session:
raise ValueError("Session missing token information")
token_info = session["token_info"]
if not token_info:
raise ValueError("Invalid token information")
required_fields = ["access_token", "refresh_token"]
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'token_uri' not in token_info:
token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"
return token_info
except Exception as e:
logger.error("Failed to retrieve token from session: %s", e)
raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
if not token_info:
return True
expiry_timestamp = token_info.get("expiry")
if expiry_timestamp is None:
return True
current_timestamp = int(datetime.datetime.now().timestamp())
return (expiry_timestamp - current_timestamp) < 60
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
return super().sanitize_token_info(
token_info,
allows_shared_content=token_info.get("allows_shared_content", False),
**extra_fields,
)
PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"
def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
"""Return True when the account is a work/school tenant that can access SharePoint shared content."""
tid = id_token_claims.get("tid", "")
return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID
def map_token_response(self, result) -> Dict[str, Any]:
claims = result.get("id_token_claims", {})
return {
"access_token": result.get("access_token"),
"refresh_token": result.get("refresh_token"),
"token_uri": claims.get("iss"),
"scopes": result.get("scope"),
"expiry": claims.get("exp"),
"allows_shared_content": self._allows_shared_content(claims),
"user_info": {
"name": claims.get("name"),
"email": claims.get("preferred_username"),
},
}

View File

@@ -1,649 +0,0 @@
"""
SharePoint/OneDrive loader for DocsGPT.
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
"""
import functools
import logging
import os
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote
import requests
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.schema.base import Document
def _retry_on_auth_failure(func):
"""Retry once after refreshing the access token on 401/403 responses."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in (401, 403):
logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception as refresh_error:
raise ValueError(
f"Authentication failed and could not be refreshed: {refresh_error}"
) from e
return func(self, *args, **kwargs)
raise
return wrapper
class SharePointLoader(BaseConnectorLoader):
SUPPORTED_MIME_TYPES = {
'application/pdf': '.pdf',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/msword': '.doc',
'application/vnd.ms-powerpoint': '.ppt',
'application/vnd.ms-excel': '.xls',
'text/plain': '.txt',
'text/csv': '.csv',
'text/html': '.html',
'text/markdown': '.md',
'text/x-rst': '.rst',
'application/json': '.json',
'application/epub+zip': '.epub',
'application/rtf': '.rtf',
'image/jpeg': '.jpg',
'image/png': '.png',
}
EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
def __init__(self, session_token: str):
self.auth = SharePointAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.access_token = token_info.get('access_token')
self.refresh_token = token_info.get('refresh_token')
self.allows_shared_content = token_info.get('allows_shared_content', False)
if not self.access_token:
raise ValueError("No access token found in session")
self.next_page_token = None
def _get_headers(self) -> Dict[str, str]:
return {
'Authorization': f'Bearer {self.access_token}',
'Accept': 'application/json'
}
def _ensure_valid_token(self):
if not self.access_token:
raise ValueError("No access token available")
token_info = {'access_token': self.access_token, 'expiry': None}
if self.auth.is_token_expired(token_info):
logging.info("Token expired, attempting refresh")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception:
raise ValueError("Failed to refresh access token")
def _get_item_url(self, item_ref: str) -> str:
if ':' in item_ref:
drive_id, item_id = item_ref.split(':', 1)
return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
try:
drive_item_id = file_metadata.get('id')
file_name = file_metadata.get('name', 'Unknown')
file_data = file_metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
return None
doc_metadata = {
'file_name': file_name,
'mime_type': mime_type,
'size': file_metadata.get('size'),
'created_time': file_metadata.get('createdDateTime'),
'modified_time': file_metadata.get('lastModifiedDateTime'),
'source': 'share_point'
}
if not load_content:
return Document(
text="",
doc_id=drive_item_id,
extra_info=doc_metadata
)
content = self._download_file_content(drive_item_id)
if content is None:
logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
return None
return Document(
text=content,
doc_id=drive_item_id,
extra_info=doc_metadata
)
except Exception as e:
logging.error(f"Error processing file: {e}")
return None
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
try:
documents: List[Document] = []
folder_id = inputs.get('folder_id')
file_ids = inputs.get('file_ids', [])
limit = inputs.get('limit', 100)
list_only = inputs.get('list_only', False)
load_content = not list_only
page_token = inputs.get('page_token')
search_query = inputs.get('search_query')
self.next_page_token = None
shared = inputs.get('shared', False)
if file_ids:
for file_id in file_ids:
try:
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc:
if not search_query or (
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
continue
elif shared:
if not self.allows_shared_content:
logging.warning("Shared content is only available for work/school Microsoft accounts")
return []
documents = self._list_shared_items(
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
else:
parent_id = folder_id if folder_id else 'root'
documents = self._list_items_in_parent(
parent_id,
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
return documents
except Exception as e:
logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
raise
@_retry_on_auth_failure
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
self._ensure_valid_token()
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
file_metadata = response.json()
return self._process_file(file_metadata, load_content=load_content)
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
return None
@_retry_on_auth_failure
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
self._ensure_valid_token()
documents: List[Document] = []
try:
url = f"{self._get_item_url(parent_id)}/children"
params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
if page_token:
params['$skipToken'] = page_token
if search_query:
encoded_query = quote(search_query, safe='')
if ':' in parent_id:
drive_id = parent_id.split(':', 1)[0]
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
else:
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
response = requests.get(search_url, headers=self._get_headers(), params=params)
else:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
for item in items:
if 'folder' in item:
doc_metadata = {
'file_name': item.get('name', 'Unknown'),
'mime_type': 'folder',
'size': item.get('size'),
'created_time': item.get('createdDateTime'),
'modified_time': item.get('lastModifiedDateTime'),
'source': 'share_point',
'is_folder': True
}
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
else:
doc = self._process_file(item, load_content=load_content)
if doc:
documents.append(doc)
if limit and len(documents) >= limit:
break
next_link = results.get('@odata.nextLink')
if next_link:
from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link)
query_params = parse_qs(parsed.query)
skiptoken_list = query_params.get('$skiptoken')
if skiptoken_list:
self.next_page_token = skiptoken_list[0]
else:
self.next_page_token = None
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing items under parent {parent_id}: {e}")
return documents
def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
"""Resolve mime type from resource, falling back to file extension."""
file_data = resource.get('file', {})
mime_type = file_data.get('mimeType') if file_data else None
if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
return mime_type, True
name = resource.get('name', '')
ext = os.path.splitext(name)[1].lower()
if ext in self.EXTENSION_TO_MIME:
return self.EXTENSION_TO_MIME[ext], True
return mime_type or 'application/octet-stream', False
def _get_user_drive_web_url(self) -> Optional[str]:
"""Fetch the current user's OneDrive web URL for KQL path exclusion."""
try:
response = requests.get(
f"{self.GRAPH_API_BASE}/me/drive",
headers=self._get_headers(),
params={'$select': 'webUrl'}
)
response.raise_for_status()
return response.json().get('webUrl')
except Exception as e:
logging.warning(f"Could not fetch user drive web URL: {e}")
return None
def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
"""Build KQL query string that excludes the user's own drive items."""
base_query = search_query if search_query else "*"
if user_drive_url:
return f'{base_query} AND -path:"{user_drive_url}"'
return base_query
def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
"""Fetch shared drive items using Microsoft Graph Search API with local offset paging.
We always fetch up to a fixed maximum number of hits from Graph (single request),
then page through that array locally using `page_token` as a simple integer offset.
This avoids relying on buggy or inconsistent remote `from`/`size` semantics.
"""
self._ensure_valid_token()
documents: List[Document] = []
try:
user_drive_url = self._get_user_drive_web_url()
query_text = self._build_shared_kql_query(search_query, user_drive_url)
url = f"{self.GRAPH_API_BASE}/search/query"
page_size = 500 # maximum number of hits we care about for selection
body = {
"requests": [
{
"entityTypes": ["driveItem"],
"query": {"queryString": query_text},
"from": 0,
"size": page_size,
}
]
}
headers = self._get_headers()
headers["Content-Type"] = "application/json"
response = requests.post(url, headers=headers, json=body)
response.raise_for_status()
results = response.json()
search_response = results.get("value", [])
if not search_response:
logging.warning("Search API returned empty value array")
self.next_page_token = None
return documents
hits_containers = search_response[0].get("hitsContainers", [])
if not hits_containers:
logging.warning("Search API returned no hitsContainers")
self.next_page_token = None
return documents
container = hits_containers[0]
total = container.get("total", 0)
raw_hits = container.get("hits", [])
# Deduplicate by effective item ID (driveId:itemId) to avoid the same
# resource appearing multiple times across the result set.
deduped_hits = []
seen_ids = set()
for hit in raw_hits:
resource = hit.get("resource", {})
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
if not effective_id or effective_id in seen_ids:
continue
seen_ids.add(effective_id)
deduped_hits.append(hit)
hits = deduped_hits
logging.info(
f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
)
try:
offset = int(page_token) if page_token is not None else 0
except (TypeError, ValueError):
logging.warning(
f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
)
offset = 0
if offset < 0:
offset = 0
if offset >= len(hits):
self.next_page_token = None
return documents
end_index = offset + limit if limit else len(hits)
end_index = min(end_index, len(hits))
for hit in hits[offset:end_index]:
resource = hit.get("resource", {})
item_name = resource.get("name", "Unknown")
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
is_folder = "folder" in resource
if is_folder:
doc_metadata = {
"file_name": item_name,
"mime_type": "folder",
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
"is_folder": True,
}
documents.append(
Document(text="", doc_id=effective_id, extra_info=doc_metadata)
)
else:
mime_type, supported = self._resolve_mime_type(resource)
if not supported:
logging.info(
f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
)
continue
doc_metadata = {
"file_name": item_name,
"mime_type": mime_type,
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
}
content = ""
if load_content:
content = self._download_file_content(effective_id) or ""
documents.append(
Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
)
if limit and end_index < len(hits):
self.next_page_token = str(end_index)
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
return documents
@_retry_on_auth_failure
def _download_file_content(self, file_id: str) -> Optional[str]:
self._ensure_valid_token()
try:
url = f"{self._get_item_url(file_id)}/content"
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
try:
return response.content.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}")
return None
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
metadata = response.json()
file_name = metadata.get('name', 'unknown')
file_data = metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type}")
return False
os.makedirs(local_dir, exist_ok=True)
full_path = os.path.join(local_dir, file_name)
download_url = f"{self._get_item_url(file_id)}/content"
download_response = requests.get(download_url, headers=self._get_headers())
download_response.raise_for_status()
with open(full_path, 'wb') as f:
f.write(download_response.content)
return True
except Exception as e:
logging.error(f"Error in _download_single_file: {e}")
return False
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
files_downloaded = 0
try:
os.makedirs(local_dir, exist_ok=True)
url = f"{self._get_item_url(folder_id)}/children"
params = {'$top': 1000}
while url:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
logging.info(f"Found {len(items)} items in folder {folder_id}")
for item in items:
item_name = item.get('name', 'unknown')
item_id = item.get('id')
if 'folder' in item:
if recursive:
subfolder_path = os.path.join(local_dir, item_name)
os.makedirs(subfolder_path, exist_ok=True)
subfolder_files = self._download_folder_recursive(
item_id,
subfolder_path,
recursive
)
files_downloaded += subfolder_files
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
else:
success = self._download_single_file(item_id, local_dir)
if success:
files_downloaded += 1
logging.info(f"Downloaded file: {item_name}")
else:
logging.warning(f"Failed to download file: {item_name}")
url = results.get('@odata.nextLink')
return files_downloaded
except Exception as e:
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
return files_downloaded
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
try:
self._ensure_valid_token()
return self._download_folder_recursive(folder_id, local_dir, recursive)
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
return 0
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
try:
self._ensure_valid_token()
return self._download_single_file(file_id, local_dir)
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
return False
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
if source_config is None:
source_config = {}
config = source_config if source_config else getattr(self, 'config', {})
files_downloaded = 0
try:
folder_ids = config.get('folder_ids', [])
file_ids = config.get('file_ids', [])
recursive = config.get('recursive', True)
if file_ids:
if isinstance(file_ids, str):
file_ids = [file_ids]
for file_id in file_ids:
if self._download_file_to_directory(file_id, local_dir):
files_downloaded += 1
if folder_ids:
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for folder_id in folder_ids:
try:
url = self._get_item_url(folder_id)
params = {'$select': 'id,name'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
folder_metadata = response.json()
folder_name = folder_metadata.get('name', '')
folder_path = os.path.join(local_dir, folder_name)
os.makedirs(folder_path, exist_ok=True)
folder_files = self._download_folder_recursive(
folder_id,
folder_path,
recursive
)
files_downloaded += folder_files
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
if not file_ids and not folder_ids:
raise ValueError("No folder_ids or file_ids provided for download")
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "share_point",
"config_used": config
}
except Exception as e:
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": True,
"source_type": "share_point",
"config_used": config,
"error": str(e)
}

View File

@@ -65,10 +65,6 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
if not os.path.exists(folder_name):
os.makedirs(folder_name)
# Validate docs is not empty
if not docs:
raise ValueError("No documents to embed - check file format and extension")
# Initialize vector store
if settings.VECTOR_STORE == "faiss":
docs_init = [docs.pop(0)]

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, List
from langchain_core.documents import Document as LCDocument
from langchain.docstore.document import Document as LCDocument
from application.parser.schema.base import Document

View File

@@ -10,97 +10,29 @@ from application.parser.file.epub_parser import EpubParser
from application.parser.file.html_parser import HTMLParser
from application.parser.file.markdown_parser import MarkdownParser
from application.parser.file.rst_parser import RstParser
from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.tabular_parser import PandasCSVParser,ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
from application.utils import num_tokens_from_string
from application.core.settings import settings
def get_default_file_extractor(
ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
"""Get the default file extractor.
Uses docling parsers by default for advanced document processing.
Falls back to standard parsers if docling is not installed.
"""
try:
from application.parser.file.docling_parser import (
DoclingPDFParser,
DoclingDocxParser,
DoclingPPTXParser,
DoclingXLSXParser,
DoclingHTMLParser,
DoclingImageParser,
DoclingCSVParser,
DoclingAsciiDocParser,
DoclingVTTParser,
DoclingXMLParser,
)
if ocr_enabled is None:
ocr_enabled = settings.DOCLING_OCR_ENABLED
return {
# Documents
".pdf": DoclingPDFParser(ocr_enabled=ocr_enabled),
".docx": DoclingDocxParser(),
".pptx": DoclingPPTXParser(),
".xlsx": DoclingXLSXParser(),
# Web formats
".html": DoclingHTMLParser(),
".xhtml": DoclingHTMLParser(),
# Data formats
".csv": DoclingCSVParser(),
".json": JSONParser(), # Keep JSON parser (specialized handling)
# Text/markup formats
".md": MarkdownParser(), # Keep markdown parser (specialized handling)
".mdx": MarkdownParser(),
".rst": RstParser(),
".adoc": DoclingAsciiDocParser(),
".asciidoc": DoclingAsciiDocParser(),
# Images (with OCR) - only use Docling when OCR is enabled
".png": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpeg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tiff": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tif": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".bmp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".webp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
# Media/subtitles
".vtt": DoclingVTTParser(),
# Specialized XML formats
".xml": DoclingXMLParser(),
# Formats docling doesn't support - use standard parsers
".epub": EpubParser(),
}
except ImportError:
logging.warning(
"docling is not installed. Using standard parsers. "
"For advanced document parsing, install with: pip install docling"
)
# Fallback to standard parsers
return {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx": ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json": JSONParser(),
".pptx": PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
# For backwards compatibility
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = get_default_file_extractor()
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx":ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json":JSONParser(),
".pptx":PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
class SimpleDirectoryReader(BaseReader):
@@ -151,10 +83,7 @@ class SimpleDirectoryReader(BaseReader):
self.recursive = recursive
self.exclude_hidden = exclude_hidden
# Normalize extensions to lowercase for case-insensitive matching
self.required_exts = (
[ext.lower() for ext in required_exts] if required_exts else None
)
self.required_exts = required_exts
self.num_files_limit = num_files_limit
if input_files:
@@ -183,7 +112,7 @@ class SimpleDirectoryReader(BaseReader):
continue
elif (
self.required_exts is not None
and input_file.suffix.lower() not in self.required_exts
and input_file.suffix not in self.required_exts
):
continue
else:
@@ -220,9 +149,8 @@ class SimpleDirectoryReader(BaseReader):
self.file_token_counts = {}
for input_file in self.input_files:
suffix_lower = input_file.suffix.lower()
if suffix_lower in self.file_extractor:
parser = self.file_extractor[suffix_lower]
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
@@ -304,7 +232,7 @@ class SimpleDirectoryReader(BaseReader):
if subtree:
result[item.name] = subtree
else:
if self.required_exts is not None and item.suffix.lower() not in self.required_exts:
if self.required_exts is not None and item.suffix not in self.required_exts:
continue
full_path = str(item.resolve())
@@ -323,4 +251,4 @@ class SimpleDirectoryReader(BaseReader):
return result
return build_tree(Path(base_path))
return build_tree(Path(base_path))

View File

@@ -1,330 +0,0 @@
"""Docling parser.
Uses docling library for advanced document parsing with layout detection,
table structure recognition, and unified document representation.
Supports: PDF, DOCX, PPTX, XLSX, HTML, XHTML, CSV, Markdown, AsciiDoc,
images (PNG, JPEG, TIFF, BMP, WEBP), WebVTT, and specialized XML formats.
"""
import importlib.util
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
Docling provides:
- Advanced PDF layout analysis
- Table structure recognition
- Reading order detection
- OCR for scanned documents (supports RapidOCR)
- Unified DoclingDocument format
- Export to Markdown
Uses hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
export_format: str = "markdown",
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
"""Initialize DoclingParser.
Args:
ocr_enabled: Enable OCR for bitmap/image regions in documents
table_structure: Enable table structure recognition
export_format: Output format ('markdown', 'text', 'html')
use_rapidocr: Use RapidOCR engine (default True, works well in Docker)
ocr_languages: List of OCR languages (default: ['english'])
force_full_page_ocr: Force OCR on entire page (False = smart hybrid OCR)
"""
super().__init__()
self.ocr_enabled = ocr_enabled
self.table_structure = table_structure
self.export_format = export_format
self.use_rapidocr = use_rapidocr
self.ocr_languages = ocr_languages or ["english"]
self.force_full_page_ocr = force_full_page_ocr
self._converter = None
def _create_converter(self):
"""Create a docling converter with hybrid OCR configuration.
Uses smart OCR approach:
- When ocr_enabled=True and force_full_page_ocr=False (default):
Layout model detects text vs bitmap regions, OCR only runs on bitmaps
- When ocr_enabled=True and force_full_page_ocr=True:
OCR runs on entire page (for scanned documents/images)
- When ocr_enabled=False:
No OCR, only native text extraction
Returns:
DocumentConverter instance
"""
from docling.document_converter import (
DocumentConverter,
ImageFormatOption,
InputFormat,
PdfFormatOption,
)
from docling.datamodel.pipeline_options import PdfPipelineOptions
pipeline_options = PdfPipelineOptions(
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()
if ocr_options is not None:
pipeline_options.ocr_options = ocr_options
return DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options,
),
InputFormat.IMAGE: ImageFormatOption(
pipeline_options=pipeline_options,
),
}
)
def _init_parser(self) -> Dict:
"""Initialize the docling converter with hybrid OCR."""
logger.info("Initializing DoclingParser...")
logger.info(f" ocr_enabled={self.ocr_enabled}")
logger.info(f" force_full_page_ocr={self.force_full_page_ocr}")
logger.info(f" use_rapidocr={self.use_rapidocr}")
if importlib.util.find_spec("docling.document_converter") is None:
raise ImportError(
"docling is required for DoclingParser. "
"Install it with: pip install docling"
)
# Create converter with hybrid OCR (smart: text direct, bitmaps OCR'd)
self._converter = self._create_converter()
logger.info("DoclingParser initialized successfully")
return {
"ocr_enabled": self.ocr_enabled,
"table_structure": self.table_structure,
"export_format": self.export_format,
"use_rapidocr": self.use_rapidocr,
"ocr_languages": self.ocr_languages,
"force_full_page_ocr": self.force_full_page_ocr,
}
def _get_ocr_options(self):
"""Get OCR options based on configuration.
Returns RapidOcrOptions if use_rapidocr is True and available,
otherwise returns None to use docling defaults.
"""
if not self.use_rapidocr:
return None
try:
from docling.datamodel.pipeline_options import RapidOcrOptions
return RapidOcrOptions(
lang=self.ocr_languages,
force_full_page_ocr=self.force_full_page_ocr,
)
except ImportError as e:
logger.warning(f"Failed to import RapidOcrOptions: {e}")
return None
except Exception as e:
logger.error(f"Error creating RapidOcrOptions: {e}")
return None
def _export_content(self, document) -> str:
"""Export document content in the configured format.
Handles edge case where text is nested under picture elements (e.g., OCR'd
images). If the standard export returns minimal content but document.texts
contains extracted text, falls back to direct text extraction.
"""
if self.export_format == "markdown":
content = document.export_to_markdown()
elif self.export_format == "html":
content = document.export_to_html()
else:
content = document.export_to_text()
# Handle case where text is nested under pictures (common with OCR'd images)
# Standard exports may return just "<!-- image -->" while actual text exists
stripped_content = content.strip()
is_minimal = len(stripped_content) < 50 or stripped_content == "<!-- image -->"
if is_minimal and hasattr(document, "texts") and document.texts:
# Extract text directly from document.texts
extracted_texts = [t.text for t in document.texts if t.text]
if extracted_texts:
logger.info(
f"Standard export minimal ({len(stripped_content)} chars), "
f"extracting {len(extracted_texts)} texts directly"
)
return "\n\n".join(extracted_texts)
return content
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file using docling with hybrid OCR.
Uses smart OCR approach where the layout model detects text vs bitmap
regions. Text is extracted directly, bitmaps are OCR'd only when needed.
Args:
file: Path to the file to parse
errors: Error handling mode (ignored, docling handles internally)
Returns:
Parsed document content as markdown string
"""
logger.info(f"parse_file called for: {file}")
if self._converter is None:
self._init_parser()
try:
logger.info(f"Converting file with hybrid OCR: {file}")
result = self._converter.convert(str(file))
content = self._export_content(result.document)
logger.info(f"Parse complete, content length: {len(content)} chars")
return content
except Exception as e:
logger.error(f"Error parsing file with docling: {e}", exc_info=True)
if errors == "ignore":
return f"[Error parsing file with docling: {str(e)}]"
raise
class DoclingPDFParser(DoclingParser):
"""Docling-based PDF parser with advanced features and RapidOCR support.
Uses hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
Set force_full_page_ocr=True only for fully scanned documents.
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
super().__init__(
ocr_enabled=ocr_enabled,
table_structure=table_structure,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingDocxParser(DoclingParser):
"""Docling-based DOCX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingPPTXParser(DoclingParser):
"""Docling-based PPTX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXLSXParser(DoclingParser):
"""Docling-based XLSX parser with table structure."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingHTMLParser(DoclingParser):
"""Docling-based HTML parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingImageParser(DoclingParser):
"""Docling-based image parser with OCR and RapidOCR support.
For images, force_full_page_ocr=True is used since images are entirely
visual and require full OCR to extract any text.
"""
def __init__(
self,
ocr_enabled: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = True,
):
super().__init__(
ocr_enabled=ocr_enabled,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingCSVParser(DoclingParser):
"""Docling-based CSV parser."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingMarkdownParser(DoclingParser):
"""Docling-based Markdown parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingAsciiDocParser(DoclingParser):
"""Docling-based AsciiDoc parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingVTTParser(DoclingParser):
"""Docling-based WebVTT (video text tracks) parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXMLParser(DoclingParser):
"""Docling-based XML parser (USPTO, JATS)."""
def __init__(self):
super().__init__(export_format="markdown")

View File

@@ -7,8 +7,8 @@ import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import tiktoken
from application.parser.file.base_parser import BaseParser
from application.utils import num_tokens_from_string
class MarkdownParser(BaseParser):
@@ -38,7 +38,7 @@ class MarkdownParser(BaseParser):
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str],
current_text: str):
"""Append to tups chunk."""
num_tokens = num_tokens_from_string(current_text)
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
if num_tokens > self._max_tokens:
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
for chunk in chunks:

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, List
from langchain_core.documents import Document as LCDocument
from langchain.docstore.document import Document as LCDocument
from application.parser.schema.base import Document

View File

@@ -1,11 +1,9 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -18,12 +16,9 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
# Check if the URL scheme is provided, if not, assume http
if not urlparse(url).scheme:
url = "http://" + url
visited_urls = set()
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -35,26 +30,16 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response = requests.get(current_url)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()
# Convert the loaded documents to your Document schema
for doc in docs:
metadata = dict(doc.metadata or {})
source_url = metadata.get("source") or current_url
metadata["file_path"] = self._url_to_virtual_path(source_url)
loaded_content.append(
Document(
doc.page_content,
extra_info=metadata
extra_info=doc.metadata
)
)
except Exception as e:
@@ -78,29 +63,3 @@ class CrawlerLoader(BaseRemote):
break
return loaded_content
def _url_to_virtual_path(self, url):
"""
Convert a URL to a virtual file path ending with .md.
Examples:
https://docs.docsgpt.cloud/ -> index.md
https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
https://example.com/page.html -> page.md
"""
parsed = urlparse(url)
path = parsed.path.strip("/")
if not path:
return "index.md"
# Remove common file extensions and add .md
base, ext = os.path.splitext(path)
if ext.lower() in [".html", ".htm", ".php", ".asp", ".aspx", ".jsp"]:
path = base
if not path.endswith(".md"):
path = f"{path}.md"
return path

View File

@@ -2,12 +2,10 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
import tldextract
import os
class CrawlerLoader(BaseRemote):
def __init__(self, limit=10, allow_subdomains=False):
@@ -27,12 +25,9 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
print(f"URL validation failed: {e}")
return []
# Ensure the URL has a scheme (if not, default to http)
if not urlparse(url).scheme:
url = "http://" + url
# Keep track of visited URLs to avoid revisiting the same page
visited_urls = set()
@@ -58,21 +53,13 @@ class CrawlerLoader(BaseRemote):
# Convert the HTML to Markdown for cleaner text formatting
title, language, processed_markdown = self._process_html_to_markdown(html_content, current_url)
if processed_markdown:
# Generate virtual file path from URL for consistent file-like matching
virtual_path = self._url_to_virtual_path(current_url)
# Create a Document for each visited page
documents.append(
Document(
processed_markdown, # content
None, # doc_id
None, # embedding
{
"source": current_url,
"title": title,
"language": language,
"file_path": virtual_path,
}, # extra_info
{"source": current_url, "title": title, "language": language} # extra_info
)
)
@@ -91,14 +78,9 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
except SSRFError as e:
print(f"URL validation failed for {url}: {e}")
return None
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None
@@ -154,31 +136,4 @@ class CrawlerLoader(BaseRemote):
# Exact domain match
if link_base == base_domain:
filtered.append(link)
return filtered
def _url_to_virtual_path(self, url):
"""
Convert a URL to a virtual file path ending with .md.
Examples:
https://docs.docsgpt.cloud/ -> index.md
https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
https://example.com/page.html -> page.md
"""
parsed = urlparse(url)
path = parsed.path.strip("/")
if not path:
return "index.md"
# Remove common file extensions and add .md
base, ext = os.path.splitext(path)
if ext.lower() in [".html", ".htm", ".php", ".asp", ".aspx", ".jsp"]:
path = base
# Ensure path ends with .md
if not path.endswith(".md"):
path = path + ".md"
return path
return filtered

View File

@@ -3,7 +3,6 @@ from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
from application.parser.remote.reddit_loader import RedditPostsLoaderRemote
from application.parser.remote.github_loader import GitHubLoader
from application.parser.remote.s3_loader import S3Loader
class RemoteCreator:
@@ -23,7 +22,6 @@ class RemoteCreator:
"crawler": CrawlerLoader,
"reddit": RedditPostsLoaderRemote,
"github": GitHubLoader,
"s3": S3Loader,
}
@classmethod

View File

@@ -1,427 +0,0 @@
import json
import logging
import os
import tempfile
import mimetypes
from typing import List, Optional
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
try:
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
boto3 = None
logger = logging.getLogger(__name__)
class S3Loader(BaseRemote):
"""Load documents from an AWS S3 bucket."""
def __init__(self):
if boto3 is None:
raise ImportError(
"boto3 is required for S3Loader. Install it with: pip install boto3"
)
self.s3_client = None
def _normalize_endpoint_url(self, endpoint_url: str, bucket: str) -> tuple[str, str]:
"""
Normalize endpoint URL for S3-compatible services.
Detects common mistakes like using bucket-prefixed URLs and extracts
the correct endpoint and bucket name.
Args:
endpoint_url: The provided endpoint URL
bucket: The provided bucket name
Returns:
Tuple of (normalized_endpoint_url, bucket_name)
"""
import re
from urllib.parse import urlparse
if not endpoint_url:
return endpoint_url, bucket
parsed = urlparse(endpoint_url)
host = parsed.netloc or parsed.path
# Check for DigitalOcean Spaces bucket-prefixed URL pattern
# e.g., https://mybucket.nyc3.digitaloceanspaces.com
do_match = re.match(r"^([^.]+)\.([a-z0-9]+)\.digitaloceanspaces\.com$", host)
if do_match:
extracted_bucket = do_match.group(1)
region = do_match.group(2)
correct_endpoint = f"https://{region}.digitaloceanspaces.com"
logger.warning(
f"Detected bucket-prefixed DigitalOcean Spaces URL. "
f"Extracted bucket '{extracted_bucket}' from endpoint. "
f"Using endpoint: {correct_endpoint}"
)
# If bucket wasn't provided or differs, use extracted one
if not bucket or bucket != extracted_bucket:
logger.info(f"Using extracted bucket name: '{extracted_bucket}' (was: '{bucket}')")
bucket = extracted_bucket
return correct_endpoint, bucket
# Check for just "digitaloceanspaces.com" without region
if host == "digitaloceanspaces.com":
logger.error(
"Invalid DigitalOcean Spaces endpoint: missing region. "
"Use format: https://<region>.digitaloceanspaces.com (e.g., https://lon1.digitaloceanspaces.com)"
)
return endpoint_url, bucket
def _init_client(
self,
aws_access_key_id: str,
aws_secret_access_key: str,
region_name: str = "us-east-1",
endpoint_url: Optional[str] = None,
bucket: Optional[str] = None,
) -> Optional[str]:
"""
Initialize the S3 client with credentials.
Returns:
The potentially corrected bucket name if endpoint URL was normalized
"""
from botocore.config import Config
client_kwargs = {
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
"region_name": region_name,
}
logger.info(f"Initializing S3 client with region: {region_name}")
corrected_bucket = bucket
if endpoint_url:
# Normalize the endpoint URL and potentially extract bucket name
normalized_endpoint, corrected_bucket = self._normalize_endpoint_url(endpoint_url, bucket)
logger.info(f"Original endpoint URL: {endpoint_url}")
logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
logger.info(f"Bucket name: '{corrected_bucket}'")
client_kwargs["endpoint_url"] = normalized_endpoint
# Use path-style addressing for S3-compatible services
# (DigitalOcean Spaces, MinIO, etc.)
client_kwargs["config"] = Config(s3={"addressing_style": "path"})
else:
logger.info("Using default AWS S3 endpoint")
self.s3_client = boto3.client("s3", **client_kwargs)
logger.info("S3 client initialized successfully")
return corrected_bucket
def is_text_file(self, file_path: str) -> bool:
"""Determine if a file is a text file based on extension."""
text_extensions = {
".txt",
".md",
".markdown",
".rst",
".json",
".xml",
".yaml",
".yml",
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".html",
".css",
".scss",
".sass",
".less",
".sh",
".bash",
".zsh",
".fish",
".sql",
".r",
".m",
".mat",
".ini",
".cfg",
".conf",
".config",
".env",
".gitignore",
".dockerignore",
".editorconfig",
".log",
".csv",
".tsv",
}
file_lower = file_path.lower()
for ext in text_extensions:
if file_lower.endswith(ext):
return True
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type and (
mime_type.startswith("text")
or mime_type in ["application/json", "application/xml"]
):
return True
return False
def is_supported_document(self, file_path: str) -> bool:
"""Check if file is a supported document type for parsing."""
document_extensions = {
".pdf",
".docx",
".doc",
".xlsx",
".xls",
".pptx",
".ppt",
".epub",
".odt",
".rtf",
}
file_lower = file_path.lower()
for ext in document_extensions:
if file_lower.endswith(ext):
return True
return False
def list_objects(self, bucket: str, prefix: str = "") -> List[str]:
"""
List all objects in the bucket with the given prefix.
Args:
bucket: S3 bucket name
prefix: Optional path prefix to filter objects
Returns:
List of object keys
"""
objects = []
paginator = self.s3_client.get_paginator("list_objects_v2")
logger.info(f"Listing objects in bucket: '{bucket}' with prefix: '{prefix}'")
logger.debug(f"S3 client endpoint: {self.s3_client.meta.endpoint_url}")
try:
page_count = 0
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
page_count += 1
logger.debug(f"Processing page {page_count}, keys in response: {list(page.keys())}")
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
if not key.endswith("/"):
objects.append(key)
logger.debug(f"Found object: {key}")
else:
logger.info(f"Page {page_count} has no 'Contents' key - bucket may be empty or prefix not found")
logger.info(f"Found {len(objects)} objects in bucket '{bucket}'")
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
error_message = e.response.get("Error", {}).get("Message", "")
logger.error(f"ClientError listing objects - Code: {error_code}, Message: {error_message}")
logger.error(f"Full error response: {e.response}")
logger.error(f"Bucket: '{bucket}', Prefix: '{prefix}', Endpoint: {self.s3_client.meta.endpoint_url}")
if error_code == "NoSuchBucket":
raise Exception(f"S3 bucket '{bucket}' does not exist")
elif error_code == "AccessDenied":
raise Exception(
f"Access denied to S3 bucket '{bucket}'. Check your credentials and permissions."
)
elif error_code == "NoSuchKey":
# This is unusual for ListObjectsV2 - may indicate endpoint/bucket configuration issue
logger.error(
"NoSuchKey error on ListObjectsV2 - this may indicate the bucket name "
"is incorrect or the endpoint URL format is wrong. "
"For DigitalOcean Spaces, the endpoint should be like: "
"https://<region>.digitaloceanspaces.com and bucket should be just the space name."
)
raise Exception(
f"S3 error: {e}. For S3-compatible services, verify: "
f"1) Endpoint URL format (e.g., https://nyc3.digitaloceanspaces.com), "
f"2) Bucket name is just the space/bucket name without region prefix"
)
else:
raise Exception(f"S3 error: {e}")
except NoCredentialsError:
raise Exception(
"AWS credentials not found. Please provide valid credentials."
)
return objects
def get_object_content(self, bucket: str, key: str) -> Optional[str]:
"""
Get the content of an S3 object as text.
Args:
bucket: S3 bucket name
key: Object key
Returns:
File content as string, or None if file should be skipped
"""
if not self.is_text_file(key) and not self.is_supported_document(key):
return None
try:
response = self.s3_client.get_object(Bucket=bucket, Key=key)
content = response["Body"].read()
if self.is_text_file(key):
try:
decoded_content = content.decode("utf-8").strip()
if not decoded_content:
return None
return decoded_content
except UnicodeDecodeError:
return None
elif self.is_supported_document(key):
return self._process_document(content, key)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "NoSuchKey":
return None
elif error_code == "AccessDenied":
print(f"Access denied to object: {key}")
return None
else:
print(f"Error fetching object {key}: {e}")
return None
return None
def _process_document(self, content: bytes, key: str) -> Optional[str]:
"""
Process a document file (PDF, DOCX, etc.) and extract text.
Args:
content: File content as bytes
key: Object key (filename)
Returns:
Extracted text content
"""
ext = os.path.splitext(key)[1].lower()
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
try:
from application.parser.file.bulk import SimpleDirectoryReader
reader = SimpleDirectoryReader(input_files=[tmp_path])
documents = reader.load_data()
if documents:
return "\n\n".join(doc.text for doc in documents if doc.text)
return None
except Exception as e:
print(f"Error processing document {key}: {e}")
return None
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
def load_data(self, inputs) -> List[Document]:
"""
Load documents from an S3 bucket.
Args:
inputs: JSON string or dict containing:
- aws_access_key_id: AWS access key ID
- aws_secret_access_key: AWS secret access key
- bucket: S3 bucket name
- prefix: Optional path prefix to filter objects
- region: AWS region (default: us-east-1)
- endpoint_url: Custom S3 endpoint URL (for MinIO, R2, etc.)
Returns:
List of Document objects
"""
if isinstance(inputs, str):
try:
data = json.loads(inputs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON input: {e}")
else:
data = inputs
required_fields = ["aws_access_key_id", "aws_secret_access_key", "bucket"]
missing_fields = [field for field in required_fields if not data.get(field)]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
aws_access_key_id = data["aws_access_key_id"]
aws_secret_access_key = data["aws_secret_access_key"]
bucket = data["bucket"]
prefix = data.get("prefix", "")
region = data.get("region", "us-east-1")
endpoint_url = data.get("endpoint_url", "")
logger.info(f"Loading data from S3 - Bucket: '{bucket}', Prefix: '{prefix}', Region: '{region}'")
if endpoint_url:
logger.info(f"Custom endpoint URL provided: '{endpoint_url}'")
corrected_bucket = self._init_client(
aws_access_key_id, aws_secret_access_key, region, endpoint_url or None, bucket
)
# Use the corrected bucket name if endpoint URL normalization extracted one
if corrected_bucket and corrected_bucket != bucket:
logger.info(f"Using corrected bucket name: '{corrected_bucket}' (original: '{bucket}')")
bucket = corrected_bucket
objects = self.list_objects(bucket, prefix)
documents = []
for key in objects:
content = self.get_object_content(bucket, key)
if content is None:
continue
documents.append(
Document(
text=content,
doc_id=key,
extra_info={
"title": os.path.basename(key),
"source": f"s3://{bucket}/{key}",
"bucket": bucket,
"key": key,
},
)
)
logger.info(f"Loaded {len(documents)} documents from S3 bucket '{bucket}'")
return documents

View File

@@ -1,9 +1,8 @@
import logging
import requests
import re # Import regular expression library
import defusedxml.ElementTree as ET
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -15,14 +14,7 @@ class SitemapLoader(BaseRemote):
sitemap_url= inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
sitemap_url = sitemap_url[0]
# Validate URL to prevent SSRF attacks
try:
sitemap_url = validate_url(sitemap_url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
url = sitemap_url[0]
urls = self._extract_urls(sitemap_url)
if not urls:
@@ -48,13 +40,8 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response = requests.get(sitemap_url)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []

Some files were not shown because too many files have changed in this diff Show More