mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 22:44:10 +00:00
Compare commits
1 Commits
sharepoint
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
835182461e |
@@ -1,28 +1,9 @@
|
||||
API_KEY=<LLM api key (for example, open ai key)>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
EMBEDDINGS_KEY=
|
||||
|
||||
#For Azure (you can delete it if you don't use Azure)
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
#Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
#Azure AD Application client secret
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
#Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#If you are using a Microsoft Entra ID tenant,
|
||||
#configure the AUTHORITY variable as
|
||||
#"https://login.microsoftonline.com/TENANT_GUID"
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
2
.github/workflows/bandit.yaml
vendored
2
.github/workflows/bandit.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
|
||||
2
.github/workflows/cife.yml
vendored
2
.github/workflows/cife.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
|
||||
2
.github/workflows/docker-develop-build.yml
vendored
2
.github/workflows/docker-develop-build.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
|
||||
5
.github/workflows/lint.yml
vendored
5
.github/workflows/lint.yml
vendored
@@ -7,14 +7,11 @@ on:
|
||||
pull_request:
|
||||
types: [ opened, synchronize ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Lint with Ruff
|
||||
uses: chartboost/ruff-action@v1
|
||||
|
||||
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@@ -1,9 +1,5 @@
|
||||
name: Run python tests with pytest
|
||||
on: [push, pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pytest_and_coverage:
|
||||
name: Run tests and count coverage
|
||||
@@ -12,7 +8,7 @@ jobs:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
2
.github/workflows/sync_fork.yaml
vendored
2
.github/workflows/sync_fork.yaml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
# Step 1: run a standard checkout action
|
||||
- name: Checkout target repo
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Step 2: run the sync action
|
||||
- name: Sync upstream changes
|
||||
|
||||
6
.github/workflows/vale.yml
vendored
6
.github/workflows/vale.yml
vendored
@@ -9,16 +9,12 @@ on:
|
||||
- '.vale.ini'
|
||||
- '.github/styles/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
vale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Vale linter
|
||||
uses: errata-ai/vale-action@v2
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -2,7 +2,6 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
experiments/
|
||||
|
||||
experiments
|
||||
# C extensions
|
||||
@@ -147,10 +146,6 @@ frontend/yarn-error.log*
|
||||
frontend/pnpm-debug.log*
|
||||
frontend/lerna-debug.log*
|
||||
|
||||
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
|
||||
!frontend/src/lib/
|
||||
!frontend/src/lib/**
|
||||
|
||||
frontend/node_modules
|
||||
frontend/dist
|
||||
frontend/dist-ssr
|
||||
|
||||
@@ -1,6 +1,2 @@
|
||||
# Allow lines to be as long as 120 characters.
|
||||
line-length = 120
|
||||
|
||||
[lint.per-file-ignores]
|
||||
# Integration tests use sys.path.insert() before imports for standalone execution
|
||||
"tests/integration/*.py" = ["E402"]
|
||||
line-length = 120
|
||||
40
README.md
40
README.md
@@ -26,6 +26,13 @@
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
|
||||
<br>
|
||||
<br>
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
@@ -46,11 +53,24 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full API tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
|
||||
- [x] Full GoogleAI compatibility (Jan 2025)
|
||||
- [x] Add tools (Jan 2025)
|
||||
- [x] Manually updating chunks in the app UI (Feb 2025)
|
||||
- [x] Devcontainer for easy development (Feb 2025)
|
||||
- [x] ReACT agent (March 2025)
|
||||
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
||||
- [x] New input box in the conversation menu (April 2025)
|
||||
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
||||
- [x] Agent optimisations (May 2025)
|
||||
- [x] Filesystem sources update (July 2025)
|
||||
- [x] Json Responses (August 2025)
|
||||
- [x] MCP support (August 2025)
|
||||
- [x] Google Drive integration (September 2025)
|
||||
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [ ] SharePoint integration (October 2025)
|
||||
- [ ] Deep Agents (October 2025)
|
||||
- [ ] Agent scheduling
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
@@ -145,17 +165,9 @@ We as members, contributors, and leaders, pledge to make participation in our co
|
||||
|
||||
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
|
||||
|
||||
## This project is supported by:
|
||||
|
||||
<p>This project is supported by:</p>
|
||||
<p>
|
||||
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
|
||||
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
|
||||
</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://get.neon.com/docsgpt">
|
||||
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
11
application/.env_sample
Normal file
11
application/.env_sample
Normal file
@@ -0,0 +1,11 @@
|
||||
API_KEY=your_api_key
|
||||
EMBEDDINGS_KEY=your_api_key
|
||||
API_URL=http://localhost:7091
|
||||
FLASK_APP=application/app.py
|
||||
FLASK_DEBUG=true
|
||||
|
||||
#For OPENAI on Azure
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
@@ -48,12 +48,7 @@ FROM ubuntu:24.04 as final
|
||||
RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
poppler-utils \
|
||||
&& \
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,7 +9,6 @@ class AgentCreator:
|
||||
agents = {
|
||||
"classic": ClassicAgent,
|
||||
"react": ReActAgent,
|
||||
"workflow": WorkflowAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -19,4 +16,5 @@ class AgentCreator:
|
||||
agent_class = cls.agents.get(type.lower())
|
||||
if not agent_class:
|
||||
raise ValueError(f"No agent class found for type {type}")
|
||||
|
||||
return agent_class(*args, **kwargs)
|
||||
|
||||
@@ -34,7 +34,6 @@ class BaseAgent(ABC):
|
||||
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
limited_request_mode: Optional[bool] = False,
|
||||
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
compressed_summary: Optional[str] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
@@ -65,9 +64,6 @@ class BaseAgent(ABC):
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
self.request_limit = request_limit
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
@@ -120,10 +116,10 @@ class BaseAgent(ABC):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key not in ("filled_by_llm", "value", "required")
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
if v.get("required", False):
|
||||
params["required"].append(k)
|
||||
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
@@ -219,11 +215,7 @@ class BaseAgent(ABC):
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if (
|
||||
param not in call_args
|
||||
and "value" in details
|
||||
and details["value"]
|
||||
):
|
||||
if param not in call_args and "value" in details:
|
||||
target_dict[param] = details["value"]
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
@@ -236,62 +228,31 @@ class BaseAgent(ABC):
|
||||
# Prepare tool_config and add tool_id for memory tools
|
||||
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_config = tool_data["config"]["actions"][action_name]
|
||||
tool_config = {
|
||||
"url": action_config["url"],
|
||||
"method": action_config["method"],
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if "body_content_type" in action_config:
|
||||
tool_config["body_content_type"] = action_config.get(
|
||||
"body_content_type", "application/json"
|
||||
)
|
||||
tool_config["body_encoding_rules"] = action_config.get(
|
||||
"body_encoding_rules", {}
|
||||
)
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
if hasattr(self, "conversation_id") and self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
print(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
if tool_data["name"] != "api_tool"
|
||||
else None
|
||||
)
|
||||
|
||||
artifact_id = None
|
||||
if callable(get_artifact_id):
|
||||
try:
|
||||
artifact_id = get_artifact_id(action_name, **parameters)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to extract artifact_id from tool %s for action %s",
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
)
|
||||
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
)
|
||||
@@ -315,176 +276,15 @@ class BaseAgent(ABC):
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
    """
    Count the tokens occupied by the given context messages.

    Args:
        messages: Message dicts that make up the current context

    Returns:
        The value reported by ``TokenCounter.count_message_tokens`` for
        ``messages``
    """
    # Function-scope import, kept exactly as the original had it.
    from application.api.answer.services.compression.token_counter import (
        TokenCounter,
    )

    total_tokens = TokenCounter.count_message_tokens(messages)
    return total_tokens
|
||||
|
||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
    """
    Report whether the context has reached the compression threshold.

    The threshold is the model's token limit multiplied by
    ``settings.COMPRESSION_THRESHOLD_PERCENTAGE``. As a side effect the
    measured token count is stored on ``self.current_token_count``.

    Args:
        messages: Current message list

    Returns:
        True when the token count is at or above the threshold; False when
        it is below it or when anything goes wrong while checking (the
        error is logged, never raised)
    """
    from application.core.model_utils import get_token_limit
    from application.core.settings import settings

    try:
        used_tokens = self._calculate_current_context_tokens(messages)
        self.current_token_count = used_tokens

        limit = get_token_limit(self.model_id)
        cutoff = int(limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)

        # Guard clause: still below the threshold, nothing to report.
        if used_tokens < cutoff:
            return False

        usage_pct = (used_tokens / limit) * 100
        logger.warning(
            f"Context limit approaching: {used_tokens}/{limit} tokens "
            f"({usage_pct:.1f}%)"
        )
        return True

    except Exception as e:
        # Best-effort check: a failure here is logged and treated as
        # "limit not reached".
        logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
        return False
|
||||
|
||||
def _validate_context_size(self, messages: List[Dict]) -> None:
    """
    Pre-flight validation before calling LLM. Logs warnings but never raises errors.

    The measured token count is stored on ``self.current_token_count``.
    A warning is logged when the count is at or above the model's token
    limit, and an info message when it is at or above
    ``settings.COMPRESSION_THRESHOLD_PERCENTAGE`` of that limit.

    Args:
        messages: Messages to be sent to LLM
    """
    from application.core.model_utils import get_token_limit

    try:
        current_tokens = self._calculate_current_context_tokens(messages)
        self.current_token_count = current_tokens
        context_limit = get_token_limit(self.model_id)

        # A missing or non-positive limit would make the percentage below
        # divide by zero; there is nothing meaningful to validate against.
        if not context_limit or context_limit <= 0:
            logger.warning(
                f"No usable context limit for model {self.model_id}; "
                f"skipping context size validation"
            )
            return

        percentage = (current_tokens / context_limit) * 100

        # Log based on usage level
        if current_tokens >= context_limit:
            logger.warning(
                f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
                f"({percentage:.1f}%). Model: {self.model_id}"
            )
        elif current_tokens >= int(
            context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
        ):
            logger.info(
                f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
                f"({percentage:.1f}%)"
            )
    except Exception as e:
        # Honour the "never raises" contract (same pattern as
        # _check_context_limit): validation is advisory only.
        logger.error(f"Error validating context size: {str(e)}", exc_info=True)
|
||||
|
||||
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
    """
    Truncate text by removing content from the middle, preserving start and end.

    The cut is made in characters, using the text's own measured
    characters-per-token ratio, so the resulting token count is an
    estimate rather than an exact fit.

    Args:
        text: Text to truncate
        max_tokens: Maximum tokens allowed

    Returns:
        ``text`` unchanged when it already fits; an empty string when the
        character budget is zero or negative; otherwise the head and tail
        of ``text`` joined by a truncation marker
    """
    from application.utils import num_tokens_from_string

    current_tokens = num_tokens_from_string(text)
    if current_tokens <= max_tokens:
        return text

    # Measured chars per token for this text (falls back to 4 when the
    # counter reports zero tokens).
    chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
    target_chars = int(max_tokens * chars_per_token * 0.95)  # 5% safety margin

    if target_chars <= 0:
        return ""

    # Split: keep 40% from start, 40% from end, remove middle
    start_chars = int(target_chars * 0.4)
    end_chars = int(target_chars * 0.4)

    truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"

    # Slice the tail from an explicit start index. text[-end_chars:] would
    # be text[-0:] -- the ENTIRE text -- whenever end_chars rounds down to 0
    # (target_chars of 1 or 2), producing output longer than the input.
    tail_start = max(len(text) - end_chars, 0)
    truncated = text[:start_chars] + truncation_marker + text[tail_start:]

    logger.info(
        f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
        f"(removed middle section)"
    )

    return truncated
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
) -> List[Dict]:
|
||||
"""Build messages using pre-rendered system prompt"""
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
# Append compression summary to system prompt if present
|
||||
if self.compressed_summary:
|
||||
compression_context = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{self.compressed_summary}"
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
system_tokens = num_tokens_from_string(system_prompt)
|
||||
|
||||
# Reserve 10% for response/tools
|
||||
safety_buffer = int(context_limit * 0.1)
|
||||
available_after_system = context_limit - system_tokens - safety_buffer
|
||||
|
||||
# Max tokens for query: 80% of available space (leave room for history)
|
||||
max_query_tokens = int(available_after_system * 0.8)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
# Truncate query from middle if it exceeds 80% of available context
|
||||
if query_tokens > max_query_tokens:
|
||||
query = self._truncate_text_middle(query, max_query_tokens)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
# Calculate remaining budget for chat history
|
||||
available_for_history = max(available_after_system - query_tokens, 0)
|
||||
|
||||
# Truncate chat history to fit within available budget
|
||||
working_history = self._truncate_history_to_fit(
|
||||
self.chat_history,
|
||||
available_for_history,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i in working_history:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
@@ -516,65 +316,7 @@ class BaseAgent(ABC):
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
def _truncate_history_to_fit(
    self,
    history: List[Dict],
    max_tokens: int,
) -> List[Dict]:
    """
    Truncate chat history to fit within token budget, keeping most recent messages.

    Entries are taken newest-first until the next one would exceed the
    budget; everything older than that entry is dropped. An entry's cost
    is the token count of its ``prompt`` and ``response`` (when both keys
    are present) plus one formatted line per item in ``tool_calls``.

    Args:
        history: Full chat history
        max_tokens: Maximum tokens allowed for history

    Returns:
        Truncated history (most recent messages that fit), in the original
        chronological order; an empty list when ``history`` is empty or
        ``max_tokens`` is not positive
    """
    from application.utils import num_tokens_from_string

    if not history or max_tokens <= 0:
        return []

    kept_newest_first = []
    current_tokens = 0

    # Iterate from newest to oldest
    for message in reversed(history):
        message_tokens = 0

        if "prompt" in message and "response" in message:
            message_tokens += num_tokens_from_string(message["prompt"])
            message_tokens += num_tokens_from_string(message["response"])

        if "tool_calls" in message:
            for tool_call in message["tool_calls"]:
                tool_str = (
                    f"Tool: {tool_call.get('tool_name')} | "
                    f"Action: {tool_call.get('action_name')} | "
                    f"Args: {tool_call.get('arguments')} | "
                    f"Response: {tool_call.get('result')}"
                )
                message_tokens += num_tokens_from_string(tool_str)

        if current_tokens + message_tokens > max_tokens:
            break

        current_tokens += message_tokens
        # Append (O(1)) and reverse once at the end instead of
        # insert(0, ...) per message, which is quadratic on long histories.
        kept_newest_first.append(message)

    kept_newest_first.reverse()  # Restore chronological order
    truncated = kept_newest_first

    if len(truncated) < len(history):
        logger.info(
            f"Truncated chat history from {len(history)} to {len(truncated)} messages "
            f"to fit within {max_tokens:,} token budget"
        )

    return truncated
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
# Pre-flight context validation - fail fast if over limit
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
|
||||
if (
|
||||
|
||||
@@ -235,4 +235,4 @@ class ReActAgent(BaseAgent):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content: {e}")
|
||||
return "".join(collected)
|
||||
return "".join(collected)
|
||||
|
||||
@@ -1,323 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentType(str, Enum):
    """Supported content types for request bodies.

    Members also subclass ``str``, so each one compares equal to its
    MIME-type string.
    """

    # One member per body format the serializer dispatches on.
    JSON = "application/json"
    FORM_URLENCODED = "application/x-www-form-urlencoded"
    MULTIPART_FORM_DATA = "multipart/form-data"
    TEXT_PLAIN = "text/plain"
    XML = "application/xml"
    OCTET_STREAM = "application/octet-stream"
|
||||
|
||||
|
||||
class RequestBodySerializer:
|
||||
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
|
||||
|
||||
@staticmethod
def serialize(
    body_data: Dict[str, Any],
    content_type: str = ContentType.JSON,
    encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Optional[Union[str, bytes]], Dict[str, str]]:
    """
    Serialize body data to appropriate format.

    Args:
        body_data: Dictionary of body parameters
        content_type: Content-Type header value; parameters after ``;``
            are ignored and matching is case-insensitive
        encoding_rules: OpenAPI Encoding Object rules per field (used by
            the form-urlencoded and multipart serializers only)

    Returns:
        Tuple of (serialized_body, updated_headers_dict). For empty
        ``body_data`` this is ``(None, {})``.

    Raises:
        ValueError: If serialization fails; the original exception is
            attached as the cause
    """
    if not body_data:
        return None, {}

    try:
        # Drop parameters such as "; charset=utf-8" before dispatching.
        content_type_lower = content_type.lower().split(";")[0].strip()

        if content_type_lower == ContentType.JSON:
            return RequestBodySerializer._serialize_json(body_data)

        elif content_type_lower == ContentType.FORM_URLENCODED:
            return RequestBodySerializer._serialize_form_urlencoded(
                body_data, encoding_rules
            )

        elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
            return RequestBodySerializer._serialize_multipart_form_data(
                body_data, encoding_rules
            )

        elif content_type_lower == ContentType.TEXT_PLAIN:
            return RequestBodySerializer._serialize_text_plain(body_data)

        elif content_type_lower == ContentType.XML:
            return RequestBodySerializer._serialize_xml(body_data)

        elif content_type_lower == ContentType.OCTET_STREAM:
            return RequestBodySerializer._serialize_octet_stream(body_data)

        else:
            logger.warning(
                f"Unknown content type: {content_type}, treating as JSON"
            )
            return RequestBodySerializer._serialize_json(body_data)

    except Exception as e:
        logger.error(f"Error serializing body: {str(e)}", exc_info=True)
        # Chain explicitly so callers can inspect the underlying failure.
        raise ValueError(f"Failed to serialize request body: {str(e)}") from e
|
||||
|
||||
@staticmethod
|
||||
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as JSON per OpenAPI spec."""
|
||||
try:
|
||||
serialized = json.dumps(
|
||||
body_data, separators=(",", ":"), ensure_ascii=False
|
||||
)
|
||||
headers = {"Content-Type": ContentType.JSON.value}
|
||||
return serialized, headers
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
|
||||
|
||||
@staticmethod
def _serialize_form_urlencoded(
    body_data: Dict[str, Any],
    encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
    """Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986.

    Args:
        body_data: Field name -> value; ``None`` values are skipped
        encoding_rules: Optional per-field rules; ``style``, ``explode``
            and ``contentType`` are read from each rule

    Returns:
        Tuple of (``key=value&...`` string, Content-Type header dict).
        A value that serializes to a list yields one ``key=value`` pair
        per item.
    """
    encoding_rules = encoding_rules or {}
    pairs = []

    for key, value in body_data.items():
        if value is None:
            continue

        rule = encoding_rules.get(key, {})
        style = rule.get("style", "form")
        explode = rule.get("explode", style == "form")
        content_type = rule.get("contentType", "text/plain")

        serialized_value = RequestBodySerializer._serialize_form_value(
            value, style, explode, content_type, key
        )

        # _serialize_form_value percent-encodes everything EXCEPT dicts
        # rendered as a JSON/XML document, which come back raw. Encode only
        # those here; passing the already-encoded values through urlencode
        # again would double-encode them ("%20" -> "%2520").
        is_raw_document = isinstance(value, dict) and content_type in (
            "application/json",
            "application/xml",
        )

        encoded_key = quote(str(key), safe="")
        if isinstance(serialized_value, list):
            values = serialized_value
        else:
            values = [serialized_value]

        for sv in values:
            encoded_value = quote(sv, safe="") if is_raw_document else sv
            pairs.append(f"{encoded_key}={encoded_value}")

    serialized = "&".join(pairs)
    headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
    return serialized, headers
|
||||
|
||||
@staticmethod
|
||||
def _serialize_form_value(
|
||||
value: Any, style: str, explode: bool, content_type: str, key: str
|
||||
) -> Union[str, list]:
|
||||
"""Serialize individual form value with encoding rules."""
|
||||
if isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
return json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
return RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
if style == "deepObject" and explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
elif explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
else:
|
||||
pairs = [f"{k},{v}" for k, v in value.items()]
|
||||
return RequestBodySerializer._percent_encode(",".join(pairs))
|
||||
|
||||
elif isinstance(value, (list, tuple)):
|
||||
if explode:
|
||||
return [
|
||||
RequestBodySerializer._percent_encode(str(item)) for item in value
|
||||
]
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(
|
||||
",".join(str(v) for v in value)
|
||||
)
|
||||
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _serialize_multipart_form_data(
|
||||
body_data: Dict[str, Any],
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[bytes, Dict[str, str]]:
|
||||
"""
|
||||
Serialize body as multipart/form-data per RFC7578.
|
||||
|
||||
Supports file uploads and encoding rules.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
encoding_rules = encoding_rules or {}
|
||||
boundary = f"----DocsGPT{secrets.token_hex(16)}"
|
||||
parts = []
|
||||
|
||||
for key, value in body_data.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
rule = encoding_rules.get(key, {})
|
||||
content_type = rule.get("contentType", "text/plain")
|
||||
headers_rule = rule.get("headers", {})
|
||||
|
||||
part = RequestBodySerializer._create_multipart_part(
|
||||
key, value, content_type, headers_rule
|
||||
)
|
||||
parts.append(part)
|
||||
|
||||
body_bytes = f"--{boundary}\r\n".encode("utf-8")
|
||||
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
|
||||
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary}",
|
||||
}
|
||||
return body_bytes, headers
|
||||
|
||||
@staticmethod
|
||||
def _create_multipart_part(
|
||||
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Create a single multipart/form-data part."""
|
||||
headers = [
|
||||
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
|
||||
]
|
||||
|
||||
if isinstance(value, bytes):
|
||||
if content_type == "application/octet-stream":
|
||||
value_encoded = base64.b64encode(value).decode("utf-8")
|
||||
else:
|
||||
value_encoded = value.decode("utf-8", errors="replace")
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
headers.append("Content-Transfer-Encoding: base64")
|
||||
elif isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
value_encoded = json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
elif isinstance(value, str) and content_type != "text/plain":
|
||||
try:
|
||||
if content_type == "application/json":
|
||||
json.loads(value)
|
||||
value_encoded = value
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = value
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
except json.JSONDecodeError:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
if content_type != "text/plain":
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
|
||||
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
    """Render the body as plain text.

    A single-entry mapping is sent as just its value; otherwise each entry
    becomes a ``key: value`` line.

    Returns:
        Tuple of (text body, headers carrying the text/plain Content-Type).
    """
    if len(body_data) == 1:
        rendered = str(next(iter(body_data.values())))
    else:
        rendered = "\n".join(f"{key}: {val}" for key, val in body_data.items())
    return rendered, {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
    """Render the body as an XML document.

    Returns:
        Tuple of (XML string from ``_dict_to_xml``, headers carrying the
        XML Content-Type).
    """
    return (
        RequestBodySerializer._dict_to_xml(body_data),
        {"Content-Type": ContentType.XML.value},
    )
|
||||
|
||||
@staticmethod
|
||||
def _serialize_octet_stream(
    body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
    """Render the body as a binary octet stream.

    ``bytes`` pass through untouched, ``str`` is UTF-8 encoded, and any
    other value is JSON-serialized and then UTF-8 encoded.

    Returns:
        Tuple of (payload bytes, headers carrying the octet-stream
        Content-Type).
    """
    if isinstance(body_data, bytes):
        payload = body_data
    elif isinstance(body_data, str):
        payload = body_data.encode("utf-8")
    else:
        payload = json.dumps(body_data).encode("utf-8")
    return payload, {"Content-Type": ContentType.OCTET_STREAM.value}
|
||||
|
||||
@staticmethod
|
||||
def _percent_encode(value: str, safe_chars: str = "") -> str:
|
||||
"""
|
||||
Percent-encode per RFC3986.
|
||||
|
||||
Args:
|
||||
value: String to encode
|
||||
safe_chars: Additional characters to not encode
|
||||
"""
|
||||
return quote(value, safe=safe_chars)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
|
||||
"""
|
||||
Convert dict to simple XML format.
|
||||
"""
|
||||
|
||||
def build_xml(obj: Any, name: str) -> str:
|
||||
if isinstance(obj, dict):
|
||||
inner = "".join(build_xml(v, k) for k, v in obj.items())
|
||||
return f"<{name}>{inner}</{name}>"
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
items = "".join(
|
||||
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
|
||||
for item in obj
|
||||
)
|
||||
return items
|
||||
else:
|
||||
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
|
||||
|
||||
root = build_xml(data, root_name)
|
||||
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml(value: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (
|
||||
value.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
@@ -1,280 +1,72 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT = 90 # seconds
|
||||
|
||||
|
||||
class APITool(Tool):
|
||||
"""
|
||||
API Tool
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.url = config.get("url", "")
|
||||
self.method = config.get("method", "GET")
|
||||
self.headers = config.get("headers", {})
|
||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
||||
self.query_params = config.get("query_params", {})
|
||||
self.body_content_type = config.get("body_content_type", ContentType.JSON)
|
||||
self.body_encoding_rules = config.get("body_encoding_rules", {})
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
"""Execute an API action with the given arguments."""
|
||||
return self._make_api_call(
|
||||
self.url,
|
||||
self.method,
|
||||
self.headers,
|
||||
self.query_params,
|
||||
kwargs,
|
||||
self.body_content_type,
|
||||
self.body_encoding_rules,
|
||||
self.url, self.method, self.headers, self.query_params, kwargs
|
||||
)
|
||||
|
||||
def _make_api_call(
|
||||
self,
|
||||
url: str,
|
||||
method: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an API call with proper body serialization and error handling.
|
||||
|
||||
Args:
|
||||
url: API endpoint URL
|
||||
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
|
||||
headers: Request headers dict
|
||||
query_params: URL query parameters
|
||||
body: Request body as dict
|
||||
content_type: Content-Type for serialization
|
||||
encoding_rules: OpenAPI encoding rules
|
||||
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
# if isinstance(body, dict):
|
||||
# body = json.dumps(body)
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
try:
|
||||
path_params_used = set()
|
||||
if query_params:
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
k: v for k, v in query_params.items() if k not in path_params_used
|
||||
}
|
||||
if remaining_params:
|
||||
query_string = urlencode(remaining_params)
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
body, content_type, encoding_rules
|
||||
)
|
||||
request_headers.update(body_headers)
|
||||
except ValueError as e:
|
||||
logger.error(f"Body serialization failed: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Body serialization error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
else:
|
||||
serialized_body = None
|
||||
if "Content-Type" not in request_headers and method not in [
|
||||
"GET",
|
||||
"HEAD",
|
||||
"DELETE",
|
||||
]:
|
||||
request_headers["Content-Type"] = ContentType.JSON
|
||||
logger.debug(
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
if body == "{}":
|
||||
body = None
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
).lower()
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"API call returned invalid JSON. Error: {e}",
|
||||
"data": response.text,
|
||||
}
|
||||
elif "text/" in content_type or "application/xml" in content_type:
|
||||
data = response.text
|
||||
elif not response.content:
|
||||
data = None
|
||||
else:
|
||||
print(f"Unsupported content type: {content_type}")
|
||||
data = response.content
|
||||
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.error(f"Connection error: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Connection error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"HTTP error {response.status_code}: {str(e)}")
|
||||
try:
|
||||
error_data = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
error_data = response.text
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"HTTP Error {response.status_code}",
|
||||
"data": error_data,
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Request failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code if response else None,
|
||||
"message": f"API call failed: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unexpected error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
def _parse_response(self, response: requests.Response) -> Any:
    """Parse a response body based on its Content-Type header.

    Args:
        response: The HTTP response to parse.

    Returns:
        ``None`` for an empty body; the decoded JSON value for JSON
        responses (or the raw text when the JSON is invalid); text for
        XML, plain-text and HTML responses; for any other type, the
        strictly decoded text, or a base64 string when the bytes cannot
        be decoded.
    """
    if not response.content:
        return None

    content_type = response.headers.get("Content-Type", "").lower()

    if "application/json" in content_type:
        try:
            return response.json()
        except ValueError as e:
            # json.JSONDecodeError is a ValueError subclass; catching the
            # base class matches the HTTPError handler in this file.
            logger.warning(f"Failed to parse JSON response: {str(e)}")
            return response.text

    text_markers = ("application/xml", "text/xml", "text/plain", "text/html")
    if any(marker in content_type for marker in text_markers):
        return response.text

    # Binary/unknown response. ``response.text`` substitutes replacement
    # characters instead of raising, so it can never trigger a fallback.
    # Decode strictly with the encoding ``text`` would have chosen and
    # fall back to base64 when the bytes are not valid in it.
    encoding = response.encoding or response.apparent_encoding or "utf-8"
    try:
        return response.content.decode(encoding)
    except (UnicodeDecodeError, LookupError):
        import base64

        return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
def get_actions_metadata(self):
    """Return action metadata: always an empty list, since API Tool actions are user-defined."""
    no_builtin_actions = []
    return no_builtin_actions
|
||||
|
||||
def get_config_requirements(self):
    """Return the tool's configuration requirements: always an empty dict."""
    no_requirements = {}
    return no_requirements
|
||||
|
||||
@@ -169,8 +169,6 @@ class MCPTool(Tool):
|
||||
transport_type = "http"
|
||||
else:
|
||||
transport_type = self.transport_type
|
||||
if transport_type == "stdio":
|
||||
raise ValueError("STDIO transport is disabled")
|
||||
if transport_type == "sse":
|
||||
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
||||
return SSETransport(url=self.server_url, headers=headers)
|
||||
@@ -456,12 +454,6 @@ class MCPTool(Tool):
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
"""Get configuration requirements for the MCP tool."""
|
||||
transport_enum = ["auto", "sse", "http"]
|
||||
transport_help = {
|
||||
"auto": "Automatically detect best transport",
|
||||
"sse": "Server-Sent Events (for real-time streaming)",
|
||||
"http": "HTTP streaming (recommended for production)",
|
||||
}
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
@@ -471,11 +463,14 @@ class MCPTool(Tool):
|
||||
"transport_type": {
|
||||
"type": "string",
|
||||
"description": "Transport type for connection",
|
||||
"enum": transport_enum,
|
||||
"enum": ["auto", "sse", "http", "stdio"],
|
||||
"default": "auto",
|
||||
"required": False,
|
||||
"help": {
|
||||
**transport_help,
|
||||
"auto": "Automatically detect best transport",
|
||||
"sse": "Server-Sent Events (for real-time streaming)",
|
||||
"http": "HTTP streaming (recommended for production)",
|
||||
"stdio": "Standard I/O (for local servers)",
|
||||
},
|
||||
},
|
||||
"auth_type": {
|
||||
|
||||
@@ -38,8 +38,6 @@ class NotesTool(Tool):
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
@@ -56,8 +54,6 @@ class NotesTool(Tool):
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
@@ -129,9 +125,6 @@ class NotesTool(Tool):
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
@@ -139,22 +132,17 @@ class NotesTool(Tool):
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
result = self.collection.find_one_and_update(
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True,
|
||||
return_document=True,
|
||||
upsert=True, # ✅ create if missing
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
@@ -175,13 +163,10 @@ class NotesTool(Tool):
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
result = self.collection.find_one_and_update(
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
@@ -203,21 +188,12 @@ class NotesTool(Tool):
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
result = self.collection.find_one_and_update(
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
doc = self.collection.find_one_and_delete(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
)
|
||||
if not doc:
|
||||
return "No note found to delete."
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return "Note deleted."
|
||||
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,12 +31,11 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
|
||||
# Ensure the URL has a scheme (if not, default to http)
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme:
|
||||
url = "http://" + url
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
@@ -1,342 +0,0 @@
|
||||
"""
|
||||
API Specification Parser
|
||||
|
||||
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
|
||||
to API Tool action definitions for use in DocsGPT.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METHODS = frozenset(
|
||||
{"get", "post", "put", "delete", "patch", "head", "options"}
|
||||
)
|
||||
|
||||
|
||||
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Parse an API specification into metadata and action definitions.

    Accepts OpenAPI 3.x and Swagger 2.0 documents written in JSON or YAML.

    Args:
        spec_content: Raw specification text.

    Returns:
        Tuple of (metadata dict, list of action dicts).

    Raises:
        ValueError: If the content is empty, cannot be parsed, or fails
            validation.
    """
    parsed = _load_spec(spec_content)
    _validate_spec(parsed)

    # The presence of a top-level "swagger" key marks a Swagger 2.0 document.
    swagger_flavour = "swagger" in parsed
    return (
        _extract_metadata(parsed, swagger_flavour),
        _extract_actions(parsed, swagger_flavour),
    )
|
||||
|
||||
|
||||
def _load_spec(content: str) -> Dict[str, Any]:
|
||||
"""Parse spec content from JSON or YAML string."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
raise ValueError("Empty specification content")
|
||||
try:
|
||||
if content.startswith("{"):
|
||||
return json.loads(content)
|
||||
return yaml.safe_load(content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {e.msg}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML format: {e}")
|
||||
|
||||
|
||||
def _validate_spec(spec: Dict[str, Any]) -> None:
|
||||
"""Validate spec version and required fields."""
|
||||
if not isinstance(spec, dict):
|
||||
raise ValueError("Specification must be a valid object")
|
||||
openapi_version = spec.get("openapi", "")
|
||||
swagger_version = spec.get("swagger", "")
|
||||
|
||||
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
|
||||
raise ValueError(
|
||||
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
|
||||
)
|
||||
if "paths" not in spec or not spec["paths"]:
|
||||
raise ValueError("No API paths defined in the specification")
|
||||
|
||||
|
||||
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
    """Collect title, description, version and base URL from the spec.

    Args:
        spec: Parsed specification document.
        is_swagger: Whether the document is Swagger 2.0 (forwarded to
            ``_get_base_url``).

    Returns:
        Dict with ``title``, ``description`` (cut to 500 characters),
        ``version`` and ``base_url``.
    """
    info_section = spec.get("info", {})
    resolved_base = _get_base_url(spec, is_swagger)

    metadata = {}
    metadata["title"] = info_section.get("title", "Untitled API")
    # A null description is treated like a missing one.
    metadata["description"] = (info_section.get("description", "") or "")[:500]
    metadata["version"] = info_section.get("version", "")
    metadata["base_url"] = resolved_base
    return metadata
|
||||
|
||||
|
||||
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
|
||||
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
|
||||
if is_swagger:
|
||||
schemes = spec.get("schemes", ["https"])
|
||||
host = spec.get("host", "")
|
||||
base_path = spec.get("basePath", "")
|
||||
if host:
|
||||
scheme = schemes[0] if schemes else "https"
|
||||
return f"{scheme}://{host}{base_path}".rstrip("/")
|
||||
return ""
|
||||
servers = spec.get("servers", [])
|
||||
if servers and isinstance(servers, list) and servers[0].get("url"):
|
||||
return servers[0]["url"].rstrip("/")
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
    """Extract all API operations as action definitions.

    Operations are returned in the order they appear in the specification.
    An operation that fails to convert is logged and skipped so one bad
    entry does not discard the whole spec.

    Args:
        spec: Parsed specification document.
        is_swagger: Whether the document is Swagger 2.0.

    Returns:
        List of action dicts, one per supported HTTP operation.
    """
    actions = []
    paths = spec.get("paths", {})
    base_url = _get_base_url(spec, is_swagger)

    components = spec.get("components", {})
    definitions = spec.get("definitions", {})

    for path, path_item in paths.items():
        if not isinstance(path_item, dict):
            continue
        path_params = path_item.get("parameters", [])

        # Walk the path item's own keys instead of the SUPPORTED_METHODS
        # frozenset: set iteration order varies between interpreter runs,
        # which made the resulting action order nondeterministic.
        for method, operation in path_item.items():
            if method not in SUPPORTED_METHODS or not isinstance(operation, dict):
                continue
            try:
                action = _build_action(
                    path=path,
                    method=method,
                    operation=operation,
                    path_params=path_params,
                    base_url=base_url,
                    components=components,
                    definitions=definitions,
                    is_swagger=is_swagger,
                )
                actions.append(action)
            except Exception as e:
                # Deliberate best effort: keep the remaining operations.
                logger.warning(
                    f"Failed to parse operation {method.upper()} {path}: {e}"
                )
                continue
    return actions
|
||||
|
||||
|
||||
def _build_action(
    path: str,
    method: str,
    operation: Dict[str, Any],
    path_params: List[Dict],
    base_url: str,
    components: Dict[str, Any],
    definitions: Dict[str, Any],
    is_swagger: bool,
) -> Dict[str, Any]:
    """Assemble one action definition from a single API operation.

    Args:
        path: Path template of the operation.
        method: HTTP method name; upper-cased in the result.
        operation: The operation object from the spec.
        path_params: Parameters declared on the enclosing path item.
        base_url: Base URL prepended to ``path`` when non-empty.
        components: The spec's ``components`` section, used for ``$ref``s.
        definitions: The spec's ``definitions`` section, used for ``$ref``s.
        is_swagger: Whether the spec is Swagger 2.0.

    Returns:
        Action dict with name, URL, method, description (cut to 500
        characters), parameter/header/body schemas, body content type and
        an ``active`` flag.
    """

    def as_object_schema(props: Dict) -> Dict[str, Any]:
        return {"type": "object", "properties": props}

    action_name = _generate_action_name(operation, method, path)
    target_url = f"{base_url}{path}" if base_url else path

    # Path-level parameters come first, followed by the operation's own.
    merged_params = path_params + operation.get("parameters", [])
    query_props, header_props = _categorize_parameters(
        merged_params, components, definitions
    )
    body_props, body_type = _extract_request_body(
        operation, components, definitions, is_swagger
    )

    # The summary wins; the long description is only a fallback.
    summary = operation.get("summary", "") or operation.get("description", "")

    return {
        "name": action_name,
        "url": target_url,
        "method": method.upper(),
        "description": (summary or "")[:500],
        "query_params": as_object_schema(query_props),
        "headers": as_object_schema(header_props),
        "body": as_object_schema(body_props),
        "body_content_type": body_type,
        "active": True,
    }
|
||||
|
||||
|
||||
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
|
||||
"""Generate a valid action name from operationId or method+path."""
|
||||
if operation.get("operationId"):
|
||||
name = operation["operationId"]
|
||||
else:
|
||||
path_slug = re.sub(r"[{}]", "", path)
|
||||
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
|
||||
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
|
||||
name = f"{method}_{path_slug}"
|
||||
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _categorize_parameters(
    parameters: List[Dict],
    components: Dict[str, Any],
    definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
    """Split parameters into query-parameter and header properties.

    ``query`` and ``path`` parameters land in the first dict and
    ``header`` parameters in the second; any other location is dropped.
    Parameters that cannot be resolved or have no name are skipped.

    Returns:
        Tuple of (query params by name, headers by name).
    """
    query_params = {}
    headers = {}

    for raw_param in parameters:
        param = _resolve_ref(raw_param, components, definitions)
        if not param or "name" not in param:
            continue
        # A parameter without "in" is treated as a query parameter.
        where = param.get("in", "query")
        prop = _param_to_property(param)

        if where == "header":
            headers[param["name"]] = prop
        elif where in ("query", "path"):
            query_params[param["name"]] = prop
    return query_params, headers
|
||||
|
||||
|
||||
def _param_to_property(param: Dict) -> Dict[str, Any]:
|
||||
"""Convert an API parameter to an action property definition."""
|
||||
schema = param.get("schema", {})
|
||||
param_type = schema.get("type", param.get("type", "string"))
|
||||
|
||||
mapped_type = "integer" if param_type in ("integer", "number") else "string"
|
||||
|
||||
return {
|
||||
"type": mapped_type,
|
||||
"description": (param.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": param.get("required", False),
|
||||
"required": param.get("required", False),
|
||||
}
|
||||
|
||||
|
||||
def _extract_request_body(
    operation: Dict[str, Any],
    components: Dict[str, Any],
    definitions: Dict[str, Any],
    is_swagger: bool,
) -> Tuple[Dict, str]:
    """Extract the request body's properties and its content type.

    Swagger 2.0 takes the schema of the first ``in: body`` parameter and
    the first ``consumes`` entry. OpenAPI 3.x picks the first media type
    from a fixed preference list that the ``requestBody`` declares; when
    that yields no schema, the first declared media type is used instead.

    Returns:
        Tuple of (body properties, content type). An operation without a
        body yields ``({}, "application/json")``.
    """
    preferred_types = (
        "application/json",
        "application/x-www-form-urlencoded",
        "multipart/form-data",
        "text/plain",
        "application/xml",
    )

    if is_swagger:
        consumes = operation.get("consumes", [])
        body_param = None
        for candidate in operation.get("parameters", []):
            if candidate.get("in") == "body":
                body_param = candidate
                break
        if not body_param:
            return {}, "application/json"
        selected_type = consumes[0] if consumes else "application/json"
        schema = body_param.get("schema", {})
    else:
        request_body = operation.get("requestBody", {})
        if not request_body:
            return {}, "application/json"
        request_body = _resolve_ref(request_body, components, definitions)
        content = request_body.get("content", {})

        selected_type = "application/json"
        schema = {}

        matched = next((ct for ct in preferred_types if ct in content), None)
        if matched is not None:
            selected_type = matched
            schema = content[matched].get("schema", {})
        if not schema and content:
            # No usable schema so far: fall back to the first media type
            # the spec declares.
            selected_type = next(iter(content))
            schema = content[selected_type].get("schema", {})

    return _schema_to_properties(schema, components, definitions), selected_type
|
||||
|
||||
|
||||
def _schema_to_properties(
    schema: Dict,
    components: Dict[str, Any],
    definitions: Dict[str, Any],
    depth: int = 0,
) -> Dict[str, Any]:
    """Convert an object schema into action body properties.

    Only the top-level ``properties`` of an ``object`` schema are
    converted; each becomes an ``integer`` or ``string`` property.

    Args:
        schema: Schema object, possibly a ``$ref``.
        components: The spec's ``components`` section, used for ``$ref``s.
        definitions: The spec's ``definitions`` section, used for ``$ref``s.
        depth: Nesting level; anything above 3 yields an empty result.

    Returns:
        Property definitions keyed by property name; empty for
        unresolvable or non-object schemas.
    """
    if depth > 3:
        return {}

    schema = _resolve_ref(schema, components, definitions)
    if not schema or not isinstance(schema, dict):
        return {}
    # A schema without an explicit type is treated as an object.
    if schema.get("type", "object") != "object":
        return {}

    required_names = set(schema.get("required", []))
    converted = {}
    for field_name, field_schema in schema.get("properties", {}).items():
        field = _resolve_ref(field_schema, components, definitions)
        if not isinstance(field, dict):
            continue
        field_type = field.get("type", "string")
        is_required = field_name in required_names
        converted[field_name] = {
            "type": "integer" if field_type in ("integer", "number") else "string",
            "description": (field.get("description", "") or "")[:200],
            "value": "",
            "filled_by_llm": is_required,
            "required": is_required,
        }
    return converted
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
obj: Any,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Optional[Dict]:
|
||||
"""Resolve $ref references in the specification."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj if isinstance(obj, dict) else None
|
||||
if "$ref" not in obj:
|
||||
return obj
|
||||
ref_path = obj["$ref"]
|
||||
|
||||
if ref_path.startswith("#/components/"):
|
||||
parts = ref_path.replace("#/components/", "").split("/")
|
||||
return _traverse_path(components, parts)
|
||||
elif ref_path.startswith("#/definitions/"):
|
||||
parts = ref_path.replace("#/definitions/", "").split("/")
|
||||
return _traverse_path(definitions, parts)
|
||||
logger.debug(f"Unsupported ref path: {ref_path}")
|
||||
return None
|
||||
|
||||
|
||||
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
|
||||
"""Traverse a nested dictionary using path parts."""
|
||||
try:
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj if isinstance(obj, dict) else None
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
@@ -38,8 +38,6 @@ class TodoListTool(Tool):
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["todos"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
@@ -56,8 +54,6 @@ class TodoListTool(Tool):
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "list":
|
||||
return self._list()
|
||||
|
||||
@@ -169,9 +165,6 @@ class TodoListTool(Tool):
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
@@ -197,8 +190,11 @@ class TodoListTool(Tool):
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query, {"todo_id": 1}))
|
||||
# Find all todos for this user/tool and get their IDs
|
||||
todos = list(self.collection.find(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"todo_id": 1}
|
||||
))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
@@ -211,8 +207,8 @@ class TodoListTool(Tool):
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query))
|
||||
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
todos = list(cursor)
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
@@ -246,10 +242,7 @@ class TodoListTool(Tool):
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
insert_result = self.collection.insert_one(doc)
|
||||
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
|
||||
if inserted_id is not None:
|
||||
self._last_artifact_id = str(inserted_id)
|
||||
self.collection.insert_one(doc)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
@@ -258,15 +251,15 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one(query)
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
@@ -284,16 +277,13 @@ class TodoListTool(Tool):
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}},
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}}
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
@@ -303,16 +293,13 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}},
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}}
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
@@ -322,12 +309,13 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_delete(query)
|
||||
if not doc:
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if result.deleted_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
Workflow,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAgent(BaseAgent):
    """A specialized agent that executes predefined workflows.

    The workflow graph is taken from an embedded definition passed to the
    constructor, or loaded from MongoDB by ``workflow_id`` and owner. Once the
    engine has finished, a run record is inserted into ``workflow_runs``.
    """

    def __init__(
        self,
        *args,
        workflow_id: Optional[str] = None,
        workflow: Optional[Dict[str, Any]] = None,
        workflow_owner: Optional[str] = None,
        **kwargs,
    ):
        """Store the workflow reference; the graph itself is loaded lazily in ``gen``.

        Args:
            workflow_id: Identifier of a stored workflow (must be a valid ObjectId
                string to be loaded from the database).
            workflow: Embedded workflow definition with ``nodes`` and ``edges``;
                takes precedence over ``workflow_id`` when truthy.
            workflow_owner: User that owns the stored workflow. When missing, the
                ``sub`` claim of ``decoded_token`` is used instead.
        """
        super().__init__(*args, **kwargs)
        self.workflow_id = workflow_id
        self.workflow_owner = workflow_owner
        self._workflow_data = workflow
        self._engine: Optional[WorkflowEngine] = None
        # Set right before the engine starts so the run record carries a real start time.
        self._run_started_at: Optional[datetime] = None

    @log_activity()
    def gen(
        self, query: str, log_context: LogContext = None
    ) -> Generator[Dict[str, str], None, None]:
        """Stream the events produced by running the workflow for ``query``."""
        yield from self._gen_inner(query, log_context)

    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict[str, str], None, None]:
        """Load the graph, run it through the engine and persist the run."""
        graph = self._load_workflow_graph()
        if not graph:
            yield {"type": "error", "error": "Failed to load workflow configuration."}
            return
        self._run_started_at = datetime.now(timezone.utc)
        self._engine = WorkflowEngine(graph, self)
        yield from self._engine.execute({}, query)
        self._save_workflow_run(query)

    def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
        """Return the graph from the embedded definition, else from the database."""
        if self._workflow_data:
            return self._parse_embedded_workflow()
        if self.workflow_id:
            return self._load_from_database()
        return None

    def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
        """Build a graph from the embedded definition; ``None`` if it is invalid."""
        try:
            nodes_data = self._workflow_data.get("nodes", [])
            edges_data = self._workflow_data.get("edges", [])
            owner_workflow_id = self.workflow_id or "embedded"

            workflow = Workflow(
                name=self._workflow_data.get("name", "Embedded Workflow"),
                description=self._workflow_data.get("description"),
            )

            nodes = [
                WorkflowNode(
                    id=n["id"],
                    workflow_id=owner_workflow_id,
                    type=n["type"],
                    title=n.get("title", "Node"),
                    description=n.get("description"),
                    position=n.get("position", {"x": 0, "y": 0}),
                    config=n.get("data", {}),
                )
                for n in nodes_data
            ]
            # Edges may use either the frontend (camelCase) or the stored
            # (snake_case) key names, so both spellings are accepted.
            edges = [
                WorkflowEdge(
                    id=e["id"],
                    workflow_id=owner_workflow_id,
                    source=e.get("source") or e.get("source_id"),
                    target=e.get("target") or e.get("target_id"),
                    sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
                    targetHandle=e.get("targetHandle") or e.get("target_handle"),
                )
                for e in edges_data
            ]
            return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
        except Exception as e:
            logger.error(f"Invalid embedded workflow: {e}", exc_info=True)
            return None

    def _find_versioned_docs(self, collection, graph_version: int) -> list:
        """Fetch this workflow's documents for ``graph_version`` from ``collection``.

        When nothing is found for version 1, documents that have no
        ``graph_version`` field at all are returned instead.
        """
        docs = list(
            collection.find(
                {"workflow_id": self.workflow_id, "graph_version": graph_version}
            )
        )
        if not docs and graph_version == 1:
            docs = list(
                collection.find(
                    {
                        "workflow_id": self.workflow_id,
                        "graph_version": {"$exists": False},
                    }
                )
            )
        return docs

    def _load_from_database(self) -> Optional[WorkflowGraph]:
        """Load the workflow, its nodes and its edges from MongoDB.

        Returns ``None`` (after logging) when the id is invalid, no owner can be
        determined, the workflow is not found for that owner, or loading fails.
        """
        try:
            from bson.objectid import ObjectId

            if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
                logger.error(f"Invalid workflow ID: {self.workflow_id}")
                return None
            owner_id = self.workflow_owner
            if not owner_id and isinstance(self.decoded_token, dict):
                owner_id = self.decoded_token.get("sub")
            if not owner_id:
                logger.error(
                    f"Workflow owner not available for workflow load: {self.workflow_id}"
                )
                return None

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]

            workflows_coll = db["workflows"]
            workflow_nodes_coll = db["workflow_nodes"]
            workflow_edges_coll = db["workflow_edges"]

            # The owner is part of the query so another user's workflow is never returned.
            workflow_doc = workflows_coll.find_one(
                {"_id": ObjectId(self.workflow_id), "user": owner_id}
            )
            if not workflow_doc:
                logger.error(
                    f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
                )
                return None
            workflow = Workflow(**workflow_doc)

            # Normalise the stored version: anything missing, non-numeric or <= 0 becomes 1.
            graph_version = workflow_doc.get("current_graph_version", 1)
            try:
                graph_version = int(graph_version)
                if graph_version <= 0:
                    graph_version = 1
            except (ValueError, TypeError):
                graph_version = 1

            nodes_docs = self._find_versioned_docs(workflow_nodes_coll, graph_version)
            nodes = [WorkflowNode(**doc) for doc in nodes_docs]

            edges_docs = self._find_versioned_docs(workflow_edges_coll, graph_version)
            edges = [WorkflowEdge(**doc) for doc in edges_docs]

            return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
        except Exception as e:
            logger.error(f"Failed to load workflow from database: {e}", exc_info=True)
            return None

    def _save_workflow_run(self, query: str) -> None:
        """Insert a run record for the finished execution; failures are only logged."""
        if not self._engine:
            return
        try:
            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            workflow_runs_coll = db["workflow_runs"]

            finished_at = datetime.now(timezone.utc)
            run = WorkflowRun(
                workflow_id=self.workflow_id or "unknown",
                status=self._determine_run_status(),
                inputs={"query": query},
                outputs=self._serialize_state(self._engine.state),
                steps=self._engine.get_execution_summary(),
                # Use the recorded start time so the run duration is preserved;
                # fall back to the finish time if the start was never recorded.
                created_at=self._run_started_at or finished_at,
                completed_at=finished_at,
            )

            workflow_runs_coll.insert_one(run.to_mongo_doc())
        except Exception as e:
            logger.error(f"Failed to save workflow run: {e}", exc_info=True)

    def _determine_run_status(self) -> ExecutionStatus:
        """Return FAILED if any logged node failed, otherwise COMPLETED.

        NOTE(review): an empty execution log also maps to COMPLETED, which
        includes runs where the engine stopped before executing any node.
        """
        if not self._engine or not self._engine.execution_log:
            return ExecutionStatus.COMPLETED
        for log in self._engine.execution_log:
            if log.get("status") == ExecutionStatus.FAILED.value:
                return ExecutionStatus.FAILED
        return ExecutionStatus.COMPLETED

    def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Copy ``state`` keeping scalars as-is and converting everything else with ``str``."""
        serialized: Dict[str, Any] = {}
        for key, value in state.items():
            if isinstance(value, (str, int, float, bool, type(None))):
                serialized[key] = value
            else:
                serialized[key] = str(value)
        return serialized
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
class ToolFilterMixin:
    """Mixin that filters fetched tools to only those specified in tool_ids.

    It must come before the agent class in the MRO so that ``super()`` reaches
    the agent's own ``_get_user_tools`` / ``_get_tools``.
    """

    # Assigned by the concrete agent's __init__; there is no class-level default.
    _allowed_tool_ids: List[str]

    def _filter_allowed_tools(
        self, tools: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Dict[str, Any]]:
        """Keep only the entries of ``tools`` whose ``_id`` (as a string) is allowed.

        An empty or not-yet-assigned allow-list yields no tools at all, so a
        lookup that happens before the subclass sets it fails closed.
        """
        allowed = getattr(self, "_allowed_tool_ids", None)
        if not allowed:
            return {}
        return {
            tool_id: tool
            for tool_id, tool in tools.items()
            if str(tool.get("_id", "")) in allowed
        }

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
        """Return the parent's user tools restricted to the allowed ids."""
        return self._filter_allowed_tools(super()._get_user_tools(user))

    def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
        """Return the parent's API-key tools restricted to the allowed ids."""
        return self._filter_allowed_tools(super()._get_tools(api_key))
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
    """Classic agent for a workflow node, limited to the node's selected tools."""

    def __init__(
        self,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ):
        """Initialise the classic agent, then record which tool ids it may use."""
        base_arguments = {
            "endpoint": endpoint,
            "llm_name": llm_name,
            "model_id": model_id,
            "api_key": api_key,
        }
        super().__init__(**base_arguments, **kwargs)
        # A missing or empty selection means the node gets no tools.
        self._allowed_tool_ids = tool_ids if tool_ids else []
|
||||
|
||||
|
||||
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
    """ReAct agent for a workflow node, limited to the node's selected tools."""

    def __init__(
        self,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ):
        """Initialise the ReAct agent, then record which tool ids it may use."""
        base_arguments = {
            "endpoint": endpoint,
            "llm_name": llm_name,
            "model_id": model_id,
            "api_key": api_key,
        }
        super().__init__(**base_arguments, **kwargs)
        # A missing or empty selection means the node gets no tools.
        self._allowed_tool_ids = tool_ids if tool_ids else []
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
    """Creates the tool-filtered agent class that matches a node's agent type."""

    # Registry of supported agent types and the class instantiated for each.
    _agents: Dict[AgentType, Type[BaseAgent]] = {
        AgentType.CLASSIC: WorkflowNodeClassicAgent,
        AgentType.REACT: WorkflowNodeReActAgent,
    }

    @classmethod
    def create(
        cls,
        agent_type: AgentType,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> BaseAgent:
        """Instantiate the agent registered for ``agent_type``.

        Raises:
            ValueError: If no agent class is registered for ``agent_type``.
        """
        agent_class = cls._agents.get(agent_type)
        if agent_class:
            return agent_class(
                endpoint=endpoint,
                llm_name=llm_name,
                model_id=model_id,
                api_key=api_key,
                tool_ids=tool_ids,
                **kwargs,
            )
        raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||
@@ -1,215 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
    """Kinds of node a workflow graph can contain (string-valued)."""

    START = "start"
    END = "end"
    AGENT = "agent"
    NOTE = "note"
    STATE = "state"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
    """Agent implementations selectable for an agent node (string-valued)."""

    CLASSIC = "classic"
    REACT = "react"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
    """Lifecycle states recorded for a node execution or a whole run (string-valued)."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class Position(BaseModel):
    """A node's x/y coordinates; fields other than ``x`` and ``y`` are rejected."""

    model_config = ConfigDict(extra="forbid")

    x: float = 0.0
    y: float = 0.0
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseModel):
    """Validated view of an agent node's ``config`` dict; unknown keys are kept."""

    model_config = ConfigDict(extra="allow")

    agent_type: AgentType = AgentType.CLASSIC
    # When unset, the engine falls back to the parent agent's llm_name.
    llm_name: Optional[str] = None
    system_prompt: str = "You are a helpful assistant."
    # Empty template means the engine sends the workflow's query as the prompt.
    prompt_template: str = ""
    # State key for the node's response; the engine defaults to node_<id>_output.
    output_variable: Optional[str] = None
    stream_to_user: bool = True
    # Ids of the tools the node's agent is allowed to use.
    tools: List[str] = Field(default_factory=list)
    sources: List[str] = Field(default_factory=list)
    chunks: str = "2"
    retriever: str = ""
    # When unset, the engine falls back to the parent agent's model_id.
    model_id: Optional[str] = None
    json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
    """Edge payload; each aliased field accepts its alias or its own name."""

    # populate_by_name lets callers pass source_id as well as the "source" alias, etc.
    model_config = ConfigDict(populate_by_name=True)

    id: str
    workflow_id: str
    source_id: str = Field(..., alias="source")
    target_id: str = Field(..., alias="target")
    source_handle: Optional[str] = Field(None, alias="sourceHandle")
    target_handle: Optional[str] = Field(None, alias="targetHandle")
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
    """A stored edge: the create payload plus the optional Mongo ``_id``."""

    mongo_id: Optional[str] = Field(None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Turn an ``ObjectId`` into its string form; pass other values through."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the persisted fields as a dict (``mongo_id`` is not included)."""
        persisted_fields = (
            "id",
            "workflow_id",
            "source_id",
            "target_id",
            "source_handle",
            "target_handle",
        )
        return {name: getattr(self, name) for name in persisted_fields}
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
    """Node payload; fields not declared here are kept on the model."""

    model_config = ConfigDict(extra="allow")

    id: str
    workflow_id: str
    type: NodeType
    title: str = "Node"
    description: Optional[str] = None
    position: Position = Field(default_factory=Position)
    # Free-form, node-type-specific settings.
    config: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("position", mode="before")
    @classmethod
    def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
        """Build a ``Position`` from a plain dict; leave other values untouched."""
        return Position(**v) if isinstance(v, dict) else v
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
    """A stored node: the create payload plus the optional Mongo ``_id``."""

    mongo_id: Optional[str] = Field(None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Turn an ``ObjectId`` into its string form; pass other values through."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the persisted fields as a dict (``mongo_id`` is not included)."""
        document: Dict[str, Any] = {}
        document["id"] = self.id
        document["workflow_id"] = self.workflow_id
        # The enum is stored as its plain string value.
        document["type"] = self.type.value
        document["title"] = self.title
        document["description"] = self.description
        document["position"] = self.position.model_dump()
        document["config"] = self.config
        return document
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
    """Workflow metadata payload; fields not declared here are kept on the model."""

    model_config = ConfigDict(extra="allow")

    name: str = "New Workflow"
    description: Optional[str] = None
    user: Optional[str] = None
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
    """A stored workflow: metadata plus Mongo ``_id`` and timestamps."""

    id: Optional[str] = Field(None, alias="_id")
    # Both timestamps default to the current UTC time at model creation.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Turn an ``ObjectId`` into its string form; pass other values through."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the persisted fields as a dict (``id`` is not included)."""
        persisted_fields = ("name", "description", "user", "created_at", "updated_at")
        return {name: getattr(self, name) for name in persisted_fields}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
    """A workflow together with its nodes and edges, plus simple lookups."""

    workflow: Workflow
    nodes: List[WorkflowNode] = Field(default_factory=list)
    edges: List[WorkflowEdge] = Field(default_factory=list)

    def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
        """Return the first node whose id equals ``node_id``, or ``None``."""
        return next((node for node in self.nodes if node.id == node_id), None)

    def get_start_node(self) -> Optional[WorkflowNode]:
        """Return the first node of type START, or ``None`` if there is none."""
        return next((node for node in self.nodes if node.type == NodeType.START), None)

    def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
        """Return, in stored order, the edges whose source is ``node_id``."""
        outgoing: List[WorkflowEdge] = []
        for edge in self.edges:
            if edge.source_id == node_id:
                outgoing.append(edge)
        return outgoing
|
||||
|
||||
|
||||
class NodeExecutionLog(BaseModel):
    """Record of one node execution; undeclared fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    node_id: str
    node_type: str
    status: ExecutionStatus
    started_at: datetime
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    # Copy of the workflow state captured for this step.
    state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRunCreate(BaseModel):
    """Payload for starting a run: the workflow id and string-valued inputs."""

    workflow_id: str
    inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
    """One execution of a workflow: inputs, outputs, per-node steps and timing."""

    model_config = ConfigDict(extra="allow")

    id: Optional[str] = Field(None, alias="_id")
    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    inputs: Dict[str, str] = Field(default_factory=dict)
    outputs: Dict[str, Any] = Field(default_factory=dict)
    steps: List[NodeExecutionLog] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Turn an ``ObjectId`` into its string form; pass other values through."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the persisted fields as a dict (``id`` is not included)."""
        dumped_steps = [step.model_dump() for step in self.steps]
        return dict(
            workflow_id=self.workflow_id,
            status=self.status.value,
            inputs=self.inputs,
            outputs=self.outputs,
            steps=dumped_steps,
            created_at=self.created_at,
            completed_at=self.completed_at,
        )
|
||||
@@ -1,276 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
|
||||
|
||||
class WorkflowEngine:
    """Runs a workflow graph node by node and streams progress events.

    Execution starts at the graph's start node and follows the first outgoing
    edge of each node until an end node is reached, no edge is left, a node
    fails, or ``MAX_EXECUTION_STEPS`` nodes have been executed. Shared data
    lives in ``state``; one entry per executed node is kept in ``execution_log``.
    """

    # Upper bound on executed nodes, which also stops cyclic graphs.
    MAX_EXECUTION_STEPS = 50

    def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
        """Bind the engine to ``graph`` and to the parent ``agent`` it borrows settings from."""
        self.graph = graph
        self.agent = agent
        self.state: WorkflowState = {}
        self.execution_log: List[Dict[str, Any]] = []
        # True once an agent node configured to stream has run; later streamed
        # answers are then preceded by a blank-line separator.
        self._has_streamed = False
        # Latest response of each agent node by node id, independent of the
        # state key (output_variable) the response was stored under.
        self._node_outputs: Dict[str, Any] = {}

    def execute(
        self, initial_inputs: WorkflowState, query: str
    ) -> Generator[Dict[str, Any], None, None]:
        """Run the workflow and yield its events.

        Yields ``workflow_step`` events (running / completed / failed) around
        each node, whatever the node handlers yield, and ``error`` events when
        the start node or a referenced node is missing or a node raises.
        """
        self._initialize_state(initial_inputs, query)

        start_node = self.graph.get_start_node()
        if not start_node:
            yield {"type": "error", "error": "No start node found in workflow."}
            return
        current_node_id: Optional[str] = start_node.id
        steps = 0

        while current_node_id and steps < self.MAX_EXECUTION_STEPS:
            node = self.graph.get_node_by_id(current_node_id)
            if not node:
                yield {"type": "error", "error": f"Node {current_node_id} not found."}
                break
            log_entry = self._create_log_entry(node)

            yield {
                "type": "workflow_step",
                "node_id": node.id,
                "node_type": node.type.value,
                "node_title": node.title,
                "status": "running",
            }

            try:
                yield from self._execute_node(node)
                log_entry["status"] = ExecutionStatus.COMPLETED.value
                log_entry["completed_at"] = datetime.now(timezone.utc)

                # Prefer the response recorded for this node: an agent node with
                # a custom output_variable does not write node_<id>_output.
                output_key = f"node_{node.id}_output"
                node_output = self._node_outputs.get(
                    node.id, self.state.get(output_key)
                )

                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "completed",
                    "state_snapshot": dict(self.state),
                    "output": node_output,
                }
            except Exception as e:
                logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
                log_entry["status"] = ExecutionStatus.FAILED.value
                log_entry["error"] = str(e)
                log_entry["completed_at"] = datetime.now(timezone.utc)
                log_entry["state_snapshot"] = dict(self.state)
                self.execution_log.append(log_entry)

                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "failed",
                    "state_snapshot": dict(self.state),
                    "error": str(e),
                }
                yield {"type": "error", "error": str(e)}
                break
            log_entry["state_snapshot"] = dict(self.state)
            self.execution_log.append(log_entry)

            if node.type == NodeType.END:
                break
            current_node_id = self._get_next_node_id(current_node_id)
            steps += 1

        # Warn only when work was actually cut off: a workflow that finishes on
        # exactly the last allowed step has no next node left.
        if current_node_id and steps >= self.MAX_EXECUTION_STEPS:
            logger.warning(
                f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
            )

    def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
        """Seed ``state`` with the inputs, the query and the stringified chat history."""
        self.state.update(initial_inputs)
        self.state["query"] = query
        self.state["chat_history"] = str(self.agent.chat_history)

    def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
        """Return a fresh log entry for ``node`` in RUNNING status."""
        return {
            "node_id": node.id,
            "node_type": node.type.value,
            "started_at": datetime.now(timezone.utc),
            "completed_at": None,
            "status": ExecutionStatus.RUNNING.value,
            "error": None,
            "state_snapshot": {},
        }

    def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
        """Return the target of the node's first outgoing edge, or ``None``."""
        edges = self.graph.get_outgoing_edges(current_node_id)
        if edges:
            return edges[0].target_id
        return None

    def _execute_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Dispatch ``node`` to the handler registered for its type."""
        logger.info(f"Executing node {node.id} ({node.type.value})")

        node_handlers = {
            NodeType.START: self._execute_start_node,
            NodeType.NOTE: self._execute_note_node,
            NodeType.AGENT: self._execute_agent_node,
            NodeType.STATE: self._execute_state_node,
            NodeType.END: self._execute_end_node,
        }

        handler = node_handlers.get(node.type)
        if handler:
            yield from handler(node)

    def _execute_start_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Start nodes do nothing; the empty ``yield from`` keeps this a generator."""
        yield from ()

    def _execute_note_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Note nodes do nothing; the empty ``yield from`` keeps this a generator."""
        yield from ()

    def _execute_agent_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Run the node's agent and store its concatenated answer in ``state``.

        Answer events are forwarded to the caller only when the node has
        ``stream_to_user`` enabled. The full response is stored under the
        node's ``output_variable`` or, if unset, under ``node_<id>_output``.
        """
        from application.core.model_utils import get_api_key_for_provider

        node_config = AgentNodeConfig(**node.config)

        if node_config.prompt_template:
            formatted_prompt = self._format_template(node_config.prompt_template)
        else:
            formatted_prompt = self.state.get("query", "")
        # Node-level settings win; otherwise the parent agent's values are used.
        node_llm_name = node_config.llm_name or self.agent.llm_name
        node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

        node_agent = WorkflowNodeAgentFactory.create(
            agent_type=node_config.agent_type,
            endpoint=self.agent.endpoint,
            llm_name=node_llm_name,
            model_id=node_config.model_id or self.agent.model_id,
            api_key=node_api_key,
            tool_ids=node_config.tools,
            prompt=node_config.system_prompt,
            chat_history=self.agent.chat_history,
            decoded_token=self.agent.decoded_token,
            json_schema=node_config.json_schema,
        )

        full_response = ""
        first_chunk = True
        for event in node_agent.gen(formatted_prompt):
            if "answer" in event:
                full_response += event["answer"]
                if node_config.stream_to_user:
                    if first_chunk and self._has_streamed:
                        # Separate this node's text from an earlier node's output.
                        yield {"answer": "\n\n"}
                        first_chunk = False
                    yield event

        if node_config.stream_to_user:
            self._has_streamed = True

        output_key = node_config.output_variable or f"node_{node.id}_output"
        self.state[output_key] = full_response
        self._node_outputs[node.id] = full_response

    def _execute_state_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Apply the node's state changes; yields nothing.

        With an ``operations`` list, each entry supports ``set`` (string values
        are template-formatted), ``increment`` and ``append``. Otherwise the
        ``updates`` mapping, or the single ``variable`` / ``value`` pair, is
        applied.
        """
        config = node.config
        operations = config.get("operations", [])

        if operations:
            for op in operations:
                key = op.get("key")
                operation = op.get("operation", "set")
                value = op.get("value")

                if not key:
                    continue
                if operation == "set":
                    formatted_value = (
                        self._format_template(str(value))
                        if isinstance(value, str)
                        else value
                    )
                    self.state[key] = formatted_value
                elif operation == "increment":
                    current = self.state.get(key, 0)
                    try:
                        self.state[key] = int(current) + int(value or 1)
                    except (ValueError, TypeError):
                        # Non-numeric current value or step: restart the counter at 1.
                        self.state[key] = 1
                elif operation == "append":
                    if key not in self.state:
                        self.state[key] = []
                    # An existing non-list value is left untouched.
                    if isinstance(self.state[key], list):
                        self.state[key].append(value)
        else:
            updates = config.get("updates", {})
            if not updates:
                var_name = config.get("variable")
                var_value = config.get("value")
                if var_name and isinstance(var_name, str):
                    updates = {var_name: var_value or ""}
            if isinstance(updates, dict):
                for key, value in updates.items():
                    if isinstance(value, str):
                        self.state[key] = self._format_template(value)
                    else:
                        self.state[key] = value
        yield from ()

    def _execute_end_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Emit the node's formatted ``output_template`` as an answer, if it has one."""
        config = node.config
        output_template = str(config.get("output_template", ""))
        if output_template:
            formatted_output = self._format_template(output_template)
            yield {"answer": formatted_output}

    def _format_template(self, template: str) -> str:
        """Replace every ``{{key}}`` in ``template`` with the matching state value.

        Substitution happens in a single pass over the template, so text that
        comes from a state value (for example the user's query or an agent's
        response) is inserted verbatim and never expanded again. Placeholders
        for unknown keys or for keys whose value is ``None`` are left unchanged.
        """
        import re

        replacements: Dict[str, str] = {}
        for key, value in self.state.items():
            if value is None:
                continue
            # First key wins if two keys format to the same placeholder text.
            replacements.setdefault(f"{{{{{key}}}}}", str(value))
        if not replacements:
            return template

        # Longest placeholder first so a shorter one can never pre-empt it.
        ordered = sorted(replacements, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(item) for item in ordered))
        return pattern.sub(lambda match: replacements[match.group(0)], template)

    def get_execution_summary(self) -> List[NodeExecutionLog]:
        """Return the execution log as validated ``NodeExecutionLog`` models."""
        return [
            NodeExecutionLog(
                node_id=log["node_id"],
                node_type=log["node_type"],
                status=ExecutionStatus(log["status"]),
                started_at=log["started_at"],
                completed_at=log.get("completed_at"),
                error=log.get("error"),
                state_snapshot=log.get("state_snapshot", {}),
            )
            for log in self.execution_log
        ]
|
||||
@@ -3,7 +3,6 @@ from flask import Blueprint
|
||||
from application.api import api
|
||||
from application.api.answer.routes.answer import AnswerResource
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.api.answer.routes.stream import StreamResource
|
||||
|
||||
|
||||
@@ -15,7 +14,6 @@ api.add_namespace(answer_ns)
|
||||
def init_answer_routes():
    """Register the answer-related resources on the shared ``api`` object."""
    routes = (
        (StreamResource, "/stream"),
        (AnswerResource, "/api/answer"),
        (SearchResource, "/api/search"),
    )
    # Registration order is kept identical to the explicit call sequence.
    for resource, url in routes:
        api.add_resource(resource, url)
|
||||
|
||||
|
||||
init_answer_routes()
|
||||
|
||||
@@ -40,6 +40,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
@@ -137,5 +138,5 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return make_response({"error": "An error occurred processing your request"}, 500)
|
||||
return make_response({"error": str(e)}, 500)
|
||||
return make_response(result, 200)
|
||||
|
||||
@@ -266,26 +266,6 @@ class BaseAnswerResource:
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
@@ -348,25 +328,6 @@ class BaseAnswerResource:
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata (partial stream): {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
class SearchResource(Resource):
    """Fast search endpoint for retrieving relevant documents.

    ``POST /api/search`` takes a JSON object with ``question``, ``api_key``
    and an optional ``chunks`` count, resolves the sources attached to the
    agent that owns the API key, and returns a JSON list of
    ``{"text", "title", "source"}`` results.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mongo = MongoDB.get_client()
        self.db = mongo[settings.MONGO_DB_NAME]
        self.agents_collection = self.db["agents"]

    search_model = answer_ns.model(
        "SearchModel",
        {
            "question": fields.String(
                required=True, description="Search query"
            ),
            "api_key": fields.String(
                required=True, description="API key for authentication"
            ),
            "chunks": fields.Integer(
                required=False, default=5, description="Number of results to return"
            ),
        },
    )

    def _get_sources_from_api_key(self, api_key: str) -> List[str]:
        """Get source IDs connected to the API key/agent.

        Args:
            api_key: Key used to look up the agent document.

        Returns:
            Source IDs as strings; empty when no agent matches the key.
        """
        agent_data = self.agents_collection.find_one({"key": api_key})
        if not agent_data:
            return []
        return self._get_sources_from_agent(agent_data)

    def _get_sources_from_agent(self, agent_data: Dict[str, Any]) -> List[str]:
        """Extract source IDs from an already loaded agent document.

        ``sources`` (a list of DBRefs) takes precedence; the legacy single
        ``source`` field is only consulted when the list yields nothing.
        The ``"default"`` placeholder is never returned.
        """
        source_ids: List[str] = []

        # Handle multiple sources; only DBRef entries are resolved, which
        # also skips the "default" placeholder.
        sources = agent_data.get("sources", [])
        if isinstance(sources, list):
            for source_ref in sources:
                if isinstance(source_ref, DBRef):
                    source_doc = self.db.dereference(source_ref)
                    if source_doc:
                        source_ids.append(str(source_doc["_id"]))

        # Handle single source (legacy) - sources was empty or yielded nothing
        if not source_ids:
            source = agent_data.get("source")
            if isinstance(source, DBRef):
                source_doc = self.db.dereference(source)
                if source_doc:
                    source_ids.append(str(source_doc["_id"]))
            # Skip "default" - it's a placeholder, not an actual vectorstore
            elif source and source != "default":
                # Coerce to str so a non-string id cannot break the
                # string handling in _search_vectorstores.
                source_ids.append(str(source))

        return source_ids

    def _search_vectorstores(
        self, query: str, source_ids: List[str], chunks: int
    ) -> List[Dict[str, Any]]:
        """Search across vectorstores and return results.

        Args:
            query: Search text passed to each vectorstore.
            source_ids: Vectorstore identifiers to search, in order.
            chunks: Maximum number of results to return overall.

        Returns:
            Up to ``chunks`` dicts with ``text``, ``title`` and ``source``.
            A source that raises is logged and skipped.
        """
        if not source_ids or chunks <= 0:
            return []

        results: List[Dict[str, Any]] = []
        chunks_per_source = max(1, chunks // len(source_ids))
        # Leading 200 characters of every accepted text, used to drop
        # near-identical chunks returned by several sources.
        seen_texts = set()

        for source_id in source_ids:
            if not source_id or not str(source_id).strip():
                continue

            try:
                docsearch = VectorCreator.create_vectorstore(
                    settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
                )
                # Over-fetch so duplicates can be dropped without starving
                # the result list.
                docs = docsearch.search(query, k=chunks_per_source * 2)

                for doc in docs:
                    if len(results) >= chunks:
                        break

                    if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                        page_content = doc.page_content
                        metadata = doc.metadata
                    else:
                        page_content = doc.get("text", doc.get("page_content", ""))
                        metadata = doc.get("metadata", {})

                    # A null text/metadata on one document must not abort
                    # the remaining documents of this source.
                    page_content = page_content or ""
                    metadata = metadata or {}

                    # Skip duplicates
                    fingerprint = page_content[:200]
                    if fingerprint in seen_texts:
                        continue
                    seen_texts.add(fingerprint)

                    title = metadata.get(
                        "title", metadata.get("post_title", "")
                    )
                    if not isinstance(title, str):
                        title = str(title) if title else ""

                    # Clean up title
                    if title:
                        title = title.split("/")[-1]
                    else:
                        # Use filename or first part of content as title
                        title = metadata.get("filename", page_content[:50] + "...")

                    source = metadata.get("source", source_id)

                    results.append({
                        "text": page_content,
                        "title": title,
                        "source": source,
                    })

                if len(results) >= chunks:
                    break

            except Exception as e:
                # Best effort: one failing vectorstore should not fail the
                # whole search.
                logger.error(
                    f"Error searching vectorstore {source_id}: {e}",
                    exc_info=True,
                )
                continue

        return results[:chunks]

    @answer_ns.expect(search_model)
    @answer_ns.doc(description="Search for relevant documents based on query")
    def post(self):
        """Validate the request, resolve the agent's sources and search them.

        Responds with 400 for a malformed body or missing/invalid fields,
        401 for an unknown API key, 500 when the search itself fails, and
        200 with a (possibly empty) result list otherwise.
        """
        data = request.get_json()
        if not isinstance(data, dict):
            return make_response(
                {"error": "Request body must be a JSON object"}, 400
            )

        question = data.get("question")
        api_key = data.get("api_key")

        if not question:
            return make_response({"error": "question is required"}, 400)

        if not api_key:
            return make_response({"error": "api_key is required"}, 400)

        raw_chunks = data.get("chunks")
        try:
            # A missing or null value falls back to the documented default.
            chunks = 5 if raw_chunks is None else int(raw_chunks)
        except (TypeError, ValueError):
            return make_response({"error": "chunks must be an integer"}, 400)

        # Validate API key
        agent = self.agents_collection.find_one({"key": api_key})
        if not agent:
            return make_response({"error": "Invalid API key"}, 401)

        try:
            # Reuse the agent document fetched for validation instead of
            # querying the collection a second time.
            source_ids = self._get_sources_from_agent(agent)

            if not source_ids:
                return make_response([], 200)

            # Perform search
            results = self._search_vectorstores(question, source_ids, chunks)

            return make_response(results, 200)

        except Exception as e:
            logger.error(
                f"/api/search - error: {str(e)}",
                extra={"error": str(e)},
                exc_info=True,
            )
            return make_response({"error": "Search failed"}, 500)
|
||||
@@ -40,6 +40,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
@@ -80,12 +81,6 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
processor.initialize()
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
Compression module for managing conversation context compression.
|
||||
|
||||
"""
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionResult,
|
||||
CompressionMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CompressionOrchestrator",
|
||||
"CompressionService",
|
||||
"CompressionResult",
|
||||
"CompressionMetadata",
|
||||
]
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageBuilder:
    """Builds message arrays from compressed context."""

    # Separator placed between the original system prompt and the
    # compression context appended by this class.
    _CONTEXT_SEPARATOR = "\n\n---\n\n"
    # Opening words of the two context headers; used to find a previously
    # appended context so it can be replaced rather than stacked.
    _CONTEXT_MARKERS = (
        "This session is being continued",
        "Context window limit reached",
    )
    _CONTINUATION_PROMPT = (
        "Please continue with the remaining tasks based on the context above."
    )

    @staticmethod
    def _tool_call_messages(tool_call: Dict) -> List[Dict]:
        """Return the assistant/tool message pair for one stored tool call."""
        # Stored calls without an id get a fresh one so the call and its
        # response still reference each other.
        call_id = tool_call.get("call_id") or str(uuid.uuid4())

        function_call_dict = {
            "function_call": {
                "name": tool_call.get("action_name"),
                "args": tool_call.get("arguments"),
                "call_id": call_id,
            }
        }
        function_response_dict = {
            "function_response": {
                "name": tool_call.get("action_name"),
                "response": {"result": tool_call.get("result")},
                "call_id": call_id,
            }
        }
        return [
            {"role": "assistant", "content": [function_call_dict]},
            {"role": "tool", "content": [function_response_dict]},
        ]

    @staticmethod
    def _history_messages(
        recent_queries: List[Dict], include_tool_calls: bool
    ) -> List[Dict]:
        """Expand stored queries into user/assistant (and tool) messages.

        Queries lacking either ``prompt`` or ``response`` are skipped.
        """
        history: List[Dict] = []
        for query in recent_queries:
            if "prompt" in query and "response" in query:
                history.append({"role": "user", "content": query["prompt"]})
                history.append({"role": "assistant", "content": query["response"]})

                # Add tool calls from history if present
                if include_tool_calls and "tool_calls" in query:
                    for tool_call in query["tool_calls"]:
                        history.extend(MessageBuilder._tool_call_messages(tool_call))
        return history

    @staticmethod
    def build_from_compressed_context(
        system_prompt: str,
        compressed_summary: Optional[str],
        recent_queries: List[Dict],
        include_tool_calls: bool = False,
        context_type: str = "pre_request",
    ) -> List[Dict]:
        """
        Build messages from compressed context.

        Args:
            system_prompt: Original system prompt
            compressed_summary: Compressed summary (if any)
            recent_queries: Recent uncompressed queries
            include_tool_calls: Whether to include tool calls from history
            context_type: Type of context ('pre_request' or 'mid_execution')

        Returns:
            List of message dicts ready for LLM
        """
        # Append compression summary to system prompt if present
        if compressed_summary:
            system_prompt = MessageBuilder._append_compression_context(
                system_prompt, compressed_summary, context_type
            )

        messages = [{"role": "system", "content": system_prompt}]

        # Add recent history
        messages.extend(
            MessageBuilder._history_messages(recent_queries, include_tool_calls)
        )

        # If no recent queries (everything was compressed), add a continuation user message
        if not recent_queries and compressed_summary:
            messages.append({
                "role": "user",
                "content": MessageBuilder._CONTINUATION_PROMPT,
            })
            logger.info("Added continuation user message to maintain proper turn-taking after full compression")

        return messages

    @staticmethod
    def _append_compression_context(
        system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
    ) -> str:
        """
        Append compression context to system prompt.

        A context block previously appended by this method is removed
        first. Only the separator that directly precedes one of the known
        context headers is treated as the cut point, so a ``---`` rule that
        is part of the original prompt is preserved.

        Args:
            system_prompt: Original system prompt
            compressed_summary: Summary to append
            context_type: Type of compression context

        Returns:
            Updated system prompt
        """
        separator = MessageBuilder._CONTEXT_SEPARATOR

        # Remove existing compression context if present
        cut_points = [
            position
            for position in (
                system_prompt.find(separator + marker)
                for marker in MessageBuilder._CONTEXT_MARKERS
            )
            if position != -1
        ]
        if cut_points:
            system_prompt = system_prompt[: min(cut_points)]

        # Build appropriate context message based on type
        if context_type == "mid_execution":
            context_message = (
                "\n\n---\n\n"
                "Context window limit reached during execution. "
                "Previous conversation has been compressed to fit within limits. "
                "The conversation is summarized below:\n\n"
                f"{compressed_summary}"
            )
        else:  # pre_request
            context_message = (
                "\n\n---\n\n"
                "This session is being continued from a previous conversation that "
                "has been compressed to fit within context limits. "
                "The conversation is summarized below:\n\n"
                f"{compressed_summary}"
            )

        return system_prompt + context_message

    @staticmethod
    def rebuild_messages_after_compression(
        messages: List[Dict],
        compressed_summary: Optional[str],
        recent_queries: List[Dict],
        include_current_execution: bool = False,
        include_tool_calls: bool = False,
    ) -> Optional[List[Dict]]:
        """
        Rebuild the message list after compression so tool execution can continue.

        The input list and its message dicts are not modified; when the
        system prompt changes, the result carries an updated copy.

        Args:
            messages: Original message list
            compressed_summary: Compressed summary
            recent_queries: Recent uncompressed queries
            include_current_execution: Whether to preserve current execution messages
            include_tool_calls: Whether to include tool calls from history

        Returns:
            Rebuilt message list or None if failed
        """
        # Find the system message
        system_message = next(
            (msg for msg in messages if msg.get("role") == "system"), None
        )
        if not system_message:
            logger.warning("No system message found in messages list")
            return None

        # Update system message with compressed summary
        if compressed_summary:
            content = system_message.get("content", "")
            # Work on a copy so the caller's message dict stays untouched.
            system_message = {
                **system_message,
                "content": MessageBuilder._append_compression_context(
                    content, compressed_summary, "mid_execution"
                ),
            }
            logger.info(
                "Appended compression summary to system prompt (truncated): %s",
                (
                    compressed_summary[:500] + "..."
                    if len(compressed_summary) > 500
                    else compressed_summary
                ),
            )

        rebuilt_messages = [system_message]

        # Add recent history from compressed context
        rebuilt_messages.extend(
            MessageBuilder._history_messages(recent_queries, include_tool_calls)
        )

        # If no recent queries (everything was compressed), add a continuation user message
        if not recent_queries and compressed_summary:
            rebuilt_messages.append({
                "role": "user",
                "content": MessageBuilder._CONTINUATION_PROMPT,
            })
            logger.info("Added continuation user message to maintain proper turn-taking after full compression")

        if include_current_execution:
            # Preserve any messages that were added during the current execution cycle
            recent_msg_count = 1  # system message
            for query in recent_queries:
                if "prompt" in query and "response" in query:
                    recent_msg_count += 2
                    # NOTE(review): tool calls are counted here even when
                    # include_tool_calls is False - confirm this matches
                    # how the original message list was built.
                    if "tool_calls" in query:
                        recent_msg_count += len(query["tool_calls"]) * 2

            if len(messages) > recent_msg_count:
                current_execution_messages = messages[recent_msg_count:]
                rebuilt_messages.extend(current_execution_messages)
                logger.info(
                    f"Preserved {len(current_execution_messages)} messages from current execution cycle"
                )

        logger.info(
            f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
            f"Ready to continue tool execution."
        )
        return rebuilt_messages
|
||||
@@ -1,232 +0,0 @@
|
||||
"""High-level compression orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
from application.api.answer.services.compression.types import CompressionResult
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionOrchestrator:
    """
    Facade for compression operations.

    Coordinates between all compression components and provides
    a simple interface for callers.
    """

    def __init__(
        self,
        conversation_service: ConversationService,
        threshold_checker: Optional[CompressionThresholdChecker] = None,
    ):
        """
        Initialize orchestrator.

        Args:
            conversation_service: Service for DB operations
            threshold_checker: Custom threshold checker (optional)
        """
        self.conversation_service = conversation_service
        self.threshold_checker = threshold_checker or CompressionThresholdChecker()

    def compress_if_needed(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_query_tokens: int = 500,
    ) -> CompressionResult:
        """
        Check if compression is needed and perform it if so.

        This is the main entry point for compression operations. Any
        exception is logged and converted into a failure result.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model being used for conversation
            decoded_token: User's decoded JWT token
            current_query_tokens: Estimated tokens for current query

        Returns:
            CompressionResult with summary and recent queries
        """
        try:
            # Load conversation
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id
            )

            if not conversation:
                logger.warning(
                    f"Conversation {conversation_id} not found for user {user_id}"
                )
                return CompressionResult.failure("Conversation not found")

            # Check if compression is needed
            if not self.threshold_checker.should_compress(
                conversation, model_id, current_query_tokens
            ):
                # No compression needed, return full history
                queries = conversation.get("queries", [])
                return CompressionResult.success_no_compression(queries)

            # Perform compression
            return self._perform_compression(
                conversation_id,
                conversation,
                model_id,
                decoded_token,
                user_id=user_id,
            )

        except Exception as e:
            logger.error(
                f"Error in compress_if_needed: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))

    def _perform_compression(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        model_id: str,
        decoded_token: Dict[str, Any],
        user_id: Optional[str] = None,
    ) -> CompressionResult:
        """
        Perform the actual compression operation.

        Args:
            conversation_id: Conversation ID
            conversation: Conversation document
            model_id: Model ID for conversation
            decoded_token: User token
            user_id: Owner used to reload the conversation after
                compression; falls back to the token's ``sub`` claim when
                omitted

        Returns:
            CompressionResult
        """
        try:
            # Compress all queries up to the latest. Checked first so no
            # LLM client is built when there is nothing to compress.
            queries_count = len(conversation.get("queries", []))
            compress_up_to = queries_count - 1

            if compress_up_to < 0:
                logger.warning("No queries to compress")
                return CompressionResult.success_no_compression([])

            # Determine which model to use for compression
            compression_model = (
                settings.COMPRESSION_MODEL_OVERRIDE
                if settings.COMPRESSION_MODEL_OVERRIDE
                else model_id
            )

            # Get provider and API key for compression model
            provider = get_provider_from_model_id(compression_model)
            api_key = get_api_key_for_provider(provider)

            # Create compression LLM
            compression_llm = LLMCreator.create_llm(
                provider,
                api_key=api_key,
                user_api_key=None,
                decoded_token=decoded_token,
                model_id=compression_model,
            )

            # Create compression service with DB update capability
            compression_service = CompressionService(
                llm=compression_llm,
                model_id=compression_model,
                conversation_service=self.conversation_service,
            )

            logger.info(
                f"Initiating compression for conversation {conversation_id}: "
                f"compressing all {queries_count} queries (0-{compress_up_to})"
            )

            # Perform compression and save to DB
            metadata = compression_service.compress_and_save(
                conversation_id, conversation, compress_up_to
            )

            logger.info(
                f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
                f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
            )

            # Reload conversation with updated metadata, using the same
            # owner id the conversation was originally loaded with.
            reload_user_id = (
                user_id
                if user_id is not None
                else (decoded_token or {}).get("sub")
            )
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id=reload_user_id
            )

            # Get compressed context
            compressed_summary, recent_queries = (
                compression_service.get_compressed_context(conversation)
            )

            return CompressionResult.success_with_compression(
                compressed_summary, recent_queries, metadata
            )

        except Exception as e:
            logger.error(f"Error performing compression: {str(e)}", exc_info=True)
            return CompressionResult.failure(str(e))

    def compress_mid_execution(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_conversation: Optional[Dict[str, Any]] = None,
    ) -> CompressionResult:
        """
        Perform compression during tool execution.

        Unlike ``compress_if_needed`` no threshold check is made here.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model ID
            decoded_token: User token
            current_conversation: Pre-loaded conversation (optional)

        Returns:
            CompressionResult
        """
        try:
            # Load conversation if not provided
            if current_conversation:
                conversation = current_conversation
            else:
                conversation = self.conversation_service.get_conversation(
                    conversation_id, user_id
                )

            if not conversation:
                logger.warning(
                    f"Could not load conversation {conversation_id} for mid-execution compression"
                )
                return CompressionResult.failure("Conversation not found")

            # Perform compression
            return self._perform_compression(
                conversation_id,
                conversation,
                model_id,
                decoded_token,
                user_id=user_id,
            )

        except Exception as e:
            logger.error(
                f"Error in mid-execution compression: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Compression prompt building logic."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionPromptBuilder:
    """Builds prompts for LLM compression calls."""

    def __init__(self, version: str = "v1.0"):
        """
        Initialize prompt builder.

        Args:
            version: Prompt template version to use

        Raises:
            FileNotFoundError: If the template for *version* doesn't exist
        """
        self.version = version
        self.system_prompt = self._load_prompt(version)

    def _load_prompt(self, version: str) -> str:
        """
        Load prompt template from file.

        The template is read from ``prompts/compression/<version>.txt``
        under the directory four levels above this module.

        Args:
            version: Version string (e.g., 'v1.0')

        Returns:
            Prompt template content

        Raises:
            FileNotFoundError: If prompt template file doesn't exist
        """
        base_dir = Path(__file__).resolve().parents[4]
        prompt_path = base_dir / "prompts" / "compression" / f"{version}.txt"

        try:
            # Explicit encoding: the platform default is not guaranteed to
            # be UTF-8.
            return prompt_path.read_text(encoding="utf-8")
        except FileNotFoundError as err:
            logger.error(f"Compression prompt template not found: {prompt_path}")
            raise FileNotFoundError(
                f"Compression prompt template '{version}' not found at {prompt_path}. "
                f"Please ensure the template file exists."
            ) from err

    def build_prompt(
        self,
        queries: List[Dict[str, Any]],
        existing_compressions: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, str]]:
        """
        Build messages for compression LLM call.

        Args:
            queries: List of query objects to compress
            existing_compressions: List of previous compression points

        Returns:
            List of message dicts for LLM (system prompt, then user prompt)
        """
        # Build conversation text
        conversation_text = self._format_conversation(queries)

        # Add existing compression context if present
        context_parts: List[str] = []
        if existing_compressions:
            context_parts.append(
                "\n\nIMPORTANT: This conversation has been compressed before. "
                "Previous compression summaries:\n\n"
            )
            for number, comp in enumerate(existing_compressions, start=1):
                context_parts.append(
                    f"--- Compression {number} (up to message {comp.get('query_index', 'unknown')}) ---\n"
                    f"{comp.get('compressed_summary', '')}\n\n"
                )
            context_parts.append(
                "Your task is to create a NEW summary that incorporates the context from "
                "previous compressions AND the new messages below. The final summary should "
                "be comprehensive and include all important information from both previous "
                "compressions and new messages.\n\n"
            )
        existing_compression_context = "".join(context_parts)

        user_prompt = (
            f"{existing_compression_context}"
            f"Here is the conversation to summarize:\n\n"
            f"{conversation_text}"
        )

        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
        """
        Format conversation queries into readable text for compression.

        Args:
            queries: List of query objects

        Returns:
            Formatted conversation text
        """
        conversation_lines = []

        for i, query in enumerate(queries):
            conversation_lines.append(f"--- Message {i + 1} ---")
            conversation_lines.append(f"User: {query.get('prompt', '')}")

            # Add tool calls if present
            tool_calls = query.get("tool_calls", [])
            if tool_calls:
                conversation_lines.append("\nTool Calls:")
                for tc in tool_calls:
                    tool_name = tc.get("tool_name", "unknown")
                    action_name = tc.get("action_name", "unknown")
                    arguments = tc.get("arguments", {})
                    result = tc.get("result", "")
                    if result is None:
                        result = ""
                    status = tc.get("status", "unknown")

                    # Include full tool result for complete compression context
                    conversation_lines.append(
                        f" - {tool_name}.{action_name}({arguments}) "
                        f"[{status}] → {result}"
                    )

            # Add agent thought if present
            thought = query.get("thought", "")
            if thought:
                conversation_lines.append(f"\nAgent Thought: {thought}")

            # Add assistant response
            conversation_lines.append(f"\nAssistant: {query.get('response', '')}")

            # Add sources if present
            sources = query.get("sources", [])
            if sources:
                conversation_lines.append(f"\nSources Used: {len(sources)} documents")

            conversation_lines.append("")  # Empty line between messages

        return "\n".join(conversation_lines)
|
||||
@@ -1,306 +0,0 @@
|
||||
"""Core compression service with simplified responsibilities."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.api.answer.services.compression.prompt_builder import (
|
||||
CompressionPromptBuilder,
|
||||
)
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionService:
    """
    Service for compressing conversation history.

    Asks the LLM to summarize the oldest part of a conversation and
    describes the outcome as CompressionMetadata. Handles DB updates when
    a conversation_service is supplied.
    """

    def __init__(
        self,
        llm,
        model_id: str,
        conversation_service=None,
        prompt_builder: Optional[CompressionPromptBuilder] = None,
    ):
        """
        Initialize compression service.

        Args:
            llm: LLM instance to use for compression
            model_id: Model ID for compression
            conversation_service: Service for DB operations (optional, for DB updates)
            prompt_builder: Custom prompt builder (optional); defaults to a
                CompressionPromptBuilder using settings.COMPRESSION_PROMPT_VERSION
        """
        self.llm = llm
        self.model_id = model_id
        self.conversation_service = conversation_service
        self.prompt_builder = prompt_builder or CompressionPromptBuilder(
            version=settings.COMPRESSION_PROMPT_VERSION
        )

    @staticmethod
    def _all_queries(conversation: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the conversation's queries, treating a missing or None value as empty."""
        queries = conversation.get("queries")
        if queries is None:
            logger.error("Conversation queries is None - returning empty list")
            return []
        return queries

    def compress_conversation(
        self,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation history up to specified index.

        Args:
            conversation: Full conversation document
            compress_up_to_index: Last query index to include in compression

        Returns:
            CompressionMetadata with compression details

        Raises:
            ValueError: If compress_up_to_index is invalid
        """
        try:
            # `.get(key, default)` does not protect against a stored None,
            # so normalise explicitly before using the values.
            queries = conversation.get("queries") or []

            if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
                raise ValueError(
                    f"Invalid compress_up_to_index: {compress_up_to_index} "
                    f"(conversation has {len(queries)} queries)"
                )

            # Get queries to compress (the index is inclusive)
            queries_to_compress = queries[: compress_up_to_index + 1]

            # Check if there are existing compressions
            existing_compressions = (
                conversation.get("compression_metadata") or {}
            ).get("compression_points") or []

            if existing_compressions:
                logger.info(
                    f"Found {len(existing_compressions)} previous compression(s) - "
                    f"will incorporate into new summary"
                )

            # Calculate original token count
            original_tokens = TokenCounter.count_query_tokens(queries_to_compress)

            # Log tool call stats
            self._log_tool_call_stats(queries_to_compress)

            # Build compression prompt
            messages = self.prompt_builder.build_prompt(
                queries_to_compress, existing_compressions
            )

            # Call LLM to generate compression
            logger.info(
                f"Starting compression: {len(queries_to_compress)} queries "
                f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
                f"using model {self.model_id}"
            )

            response = self.llm.gen(
                model=self.model_id, messages=messages, max_tokens=4000
            )

            # Extract summary from response
            compressed_summary = self._extract_summary(response)

            # Calculate compressed token count
            compressed_tokens = TokenCounter.count_message_tokens(
                [{"content": compressed_summary}]
            )

            # Calculate compression ratio; 0 marks an empty summary instead
            # of dividing by zero.
            compression_ratio = (
                original_tokens / compressed_tokens if compressed_tokens > 0 else 0
            )

            logger.info(
                f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
                f"({compression_ratio:.1f}x compression)"
            )

            # Build compression metadata
            return CompressionMetadata(
                timestamp=datetime.now(timezone.utc),
                query_index=compress_up_to_index,
                compressed_summary=compressed_summary,
                original_token_count=original_tokens,
                compressed_token_count=compressed_tokens,
                compression_ratio=compression_ratio,
                model_used=self.model_id,
                compression_prompt_version=self.prompt_builder.version,
            )

        except Exception as e:
            # Log with traceback here, then let the caller decide how to recover.
            logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
            raise

    def compress_and_save(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation and save to database.

        Args:
            conversation_id: Conversation ID
            conversation: Full conversation document
            compress_up_to_index: Last query index to include

        Returns:
            CompressionMetadata

        Raises:
            ValueError: If conversation_service not provided or invalid index
        """
        if not self.conversation_service:
            raise ValueError(
                "conversation_service required for compress_and_save operation"
            )

        # Perform compression
        metadata = self.compress_conversation(conversation, compress_up_to_index)

        # Save to database
        self.conversation_service.update_compression_metadata(
            conversation_id, metadata.to_dict()
        )

        logger.info(f"Compression metadata saved to database for {conversation_id}")

        return metadata

    def get_compressed_context(
        self, conversation: Dict[str, Any]
    ) -> tuple[Optional[str], List[Dict[str, Any]]]:
        """
        Get compressed summary + recent uncompressed messages.

        Falls back to the full history (summary None) when the conversation
        is not compressed, has no compression points, the latest point has
        no usable query_index, or an unexpected error occurs.

        Args:
            conversation: Full conversation document

        Returns:
            (compressed_summary, recent_messages)
        """
        try:
            # A stored None must behave like a missing key.
            compression_metadata = conversation.get("compression_metadata") or {}

            if not compression_metadata.get("is_compressed"):
                logger.debug("No compression metadata found - using full history")
                return None, self._all_queries(conversation)

            compression_points = compression_metadata.get("compression_points") or []

            if not compression_points:
                logger.debug("No compression points found - using full history")
                return None, self._all_queries(conversation)

            # Get the most recent compression point
            latest_compression = compression_points[-1]
            compressed_summary = latest_compression.get("compressed_summary")
            last_compressed_index = latest_compression.get("query_index")
            compressed_tokens = latest_compression.get("compressed_token_count", 0)
            original_tokens = latest_compression.get("original_token_count", 0)

            # Without a valid index the slice below is meaningless, so use
            # the full history rather than failing on `None + 1`.
            if not isinstance(last_compressed_index, int) or last_compressed_index < 0:
                logger.warning(
                    "Latest compression point has invalid query_index %r - "
                    "using full history",
                    last_compressed_index,
                )
                return None, self._all_queries(conversation)

            # Get only messages after compression point
            queries = self._all_queries(conversation)
            total_queries = len(queries)
            recent_queries = queries[last_compressed_index + 1 :]

            logger.info(
                f"Using compressed context: summary ({compressed_tokens} tokens, "
                f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
                f"(messages {last_compressed_index + 1}-{total_queries - 1})"
            )

            return compressed_summary, recent_queries

        except Exception as e:
            logger.error(
                f"Error getting compressed context: {str(e)}", exc_info=True
            )
            return None, self._all_queries(conversation)

    def _extract_summary(self, llm_response: str) -> str:
        """
        Extract clean summary from LLM response.

        Prefers the text inside <summary> tags; without them, strips any
        <analysis> section and uses what remains.

        Args:
            llm_response: Raw LLM response

        Returns:
            Cleaned summary text, or the raw response if extraction fails
        """
        try:
            # Try to extract content within <summary> tags
            summary_match = re.search(
                r"<summary>(.*?)</summary>", llm_response, re.DOTALL
            )
            if summary_match:
                return summary_match.group(1).strip()

            # If no summary tags, remove analysis tags and use the rest
            return re.sub(
                r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
            ).strip()

        except Exception as e:
            logger.warning(f"Error extracting summary: {str(e)}, using full response")
            return llm_response

    def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
        """Log statistics about tool calls in queries."""
        total_tool_calls = 0
        total_tool_result_chars = 0
        tool_call_breakdown = {}

        for q in queries:
            # `or []` covers queries that store tool_calls as None.
            for tc in q.get("tool_calls") or []:
                total_tool_calls += 1
                tool_name = tc.get("tool_name", "unknown")
                action_name = tc.get("action_name", "unknown")
                key = f"{tool_name}.{action_name}"
                tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1

                # Track total tool result size
                result = tc.get("result", "")
                if result:
                    total_tool_result_chars += len(str(result))

        if total_tool_calls > 0:
            tool_breakdown_str = ", ".join(
                f"{tool}({count})"
                for tool, count in sorted(tool_call_breakdown.items())
            )
            tool_result_kb = total_tool_result_chars / 1024
            logger.info(
                f"Tool call breakdown: {tool_breakdown_str} "
                f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
            )
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Compression threshold checking logic."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionThresholdChecker:
    """Determines if compression is needed based on token thresholds."""

    def __init__(self, threshold_percentage: float = None):
        """
        Initialize threshold checker.

        Args:
            threshold_percentage: Fraction of the context window to use as
                threshold; it is multiplied directly with the context limit
                (defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
        """
        # Compare against None instead of relying on truthiness so that an
        # explicit 0.0 is honoured rather than replaced by the default.
        if threshold_percentage is None:
            threshold_percentage = settings.COMPRESSION_THRESHOLD_PERCENTAGE
        self.threshold_percentage = threshold_percentage

    def _limits_for(self, model_id: str) -> tuple:
        """
        Return (context_limit, threshold) in tokens for the given model.

        Raises:
            ValueError: If the model reports a non-positive context limit,
                which would make the usage percentage undefined
        """
        context_limit = get_token_limit(model_id)
        if context_limit <= 0:
            raise ValueError(
                f"Non-positive context limit {context_limit} for model {model_id}"
            )
        return context_limit, int(context_limit * self.threshold_percentage)

    def should_compress(
        self,
        conversation: Dict[str, Any],
        model_id: str,
        current_query_tokens: int = 500,
    ) -> bool:
        """
        Determine if compression is needed.

        Args:
            conversation: Full conversation document
            model_id: Target model for this request
            current_query_tokens: Estimated tokens for current query

        Returns:
            True if tokens >= threshold% of context window; False otherwise
            or when the check itself fails
        """
        try:
            # Calculate total tokens in conversation
            total_tokens = TokenCounter.count_conversation_tokens(conversation)
            total_tokens += current_query_tokens

            # Get context window limit for model and the derived threshold
            context_limit, threshold = self._limits_for(model_id)

            compression_needed = total_tokens >= threshold
            percentage_used = (total_tokens / context_limit) * 100

            if compression_needed:
                logger.warning(
                    f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
                )
            else:
                logger.info(
                    f"Compression check: {total_tokens}/{context_limit} tokens "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
                )

            return compression_needed

        except Exception as e:
            # A failed check must not block the request: report and skip compression.
            logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
            return False

    def check_message_tokens(self, messages: list, model_id: str) -> bool:
        """
        Check if message list exceeds threshold.

        Args:
            messages: List of message dicts
            model_id: Target model

        Returns:
            True if at or above threshold; False otherwise or when the
            check itself fails
        """
        try:
            current_tokens = TokenCounter.count_message_tokens(messages)
            context_limit, threshold = self._limits_for(model_id)

            if current_tokens >= threshold:
                logger.warning(
                    f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
                    f"({(current_tokens/context_limit)*100:.1f}%)"
                )
                return True

            return False

        except Exception as e:
            logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
            return False
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Token counting utilities for compression."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'content' field
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens across multiple query objects.
|
||||
|
||||
Args:
|
||||
queries: List of query objects from conversation
|
||||
include_tool_calls: Whether to count tool call tokens
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
|
||||
for query in queries:
|
||||
# Count prompt and response tokens
|
||||
if "prompt" in query:
|
||||
total_tokens += num_tokens_from_string(query["prompt"])
|
||||
if "response" in query:
|
||||
total_tokens += num_tokens_from_string(query["response"])
|
||||
if "thought" in query:
|
||||
total_tokens += num_tokens_from_string(query.get("thought", ""))
|
||||
|
||||
# Count tool call tokens
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
tool_call_string = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
total_tokens += num_tokens_from_string(tool_call_string)
|
||||
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_conversation_tokens(
|
||||
conversation: Dict[str, Any], include_system_prompt: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
Calculate total tokens in a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation document
|
||||
include_system_prompt: Whether to include system prompt in count
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
total_tokens = TokenCounter.count_query_tokens(queries)
|
||||
|
||||
# Add system prompt tokens if requested
|
||||
if include_system_prompt:
|
||||
# Rough estimate for system prompt
|
||||
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
|
||||
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating conversation tokens: {str(e)}")
|
||||
return 0
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Type definitions for compression module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class CompressionMetadata:
    """Record describing a single compression operation."""

    timestamp: datetime
    query_index: int
    compressed_summary: str
    original_token_count: int
    compressed_token_count: int
    compression_ratio: float
    model_used: str
    compression_prompt_version: str

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict entry for DB storage."""
        # Listed explicitly so the stored key order stays fixed.
        stored_keys = (
            "timestamp",
            "query_index",
            "compressed_summary",
            "original_token_count",
            "compressed_token_count",
            "compression_ratio",
            "model_used",
            "compression_prompt_version",
        )
        return {key: getattr(self, key) for key in stored_keys}
|
||||
|
||||
|
||||
@dataclass
class CompressionResult:
    """Outcome of a compression attempt, successful or not."""

    success: bool
    compressed_summary: Optional[str] = None
    recent_queries: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Optional[CompressionMetadata] = None
    error: Optional[str] = None
    compression_performed: bool = False

    @classmethod
    def success_with_compression(
        cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
    ) -> "CompressionResult":
        """Build a successful result for the case where compression ran."""
        outcome = cls(success=True, compression_performed=True)
        outcome.compressed_summary = summary
        outcome.recent_queries = queries
        outcome.metadata = metadata
        return outcome

    @classmethod
    def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
        """Build a successful result for the case where no compression was needed."""
        outcome = cls(success=True, compression_performed=False)
        outcome.recent_queries = queries
        return outcome

    @classmethod
    def failure(cls, error: str) -> "CompressionResult":
        """Build a failed result carrying the error description."""
        outcome = cls(success=False, compression_performed=False)
        outcome.error = error
        return outcome

    def as_history(self) -> List[Dict[str, str]]:
        """
        Reduce the recent queries to prompt/response pairs.

        Returns:
            One dict per recent query holding only 'prompt' and 'response'

        Raises:
            KeyError: If a recent query lacks 'prompt' or 'response'
        """
        history = []
        for entry in self.recent_queries:
            history.append({"prompt": entry["prompt"], "response": entry["response"]})
        return history
|
||||
@@ -62,8 +62,6 @@ class ConversationService:
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in the database"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
@@ -150,12 +148,9 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=500
|
||||
model=model_id, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
@@ -185,103 +180,3 @@ class ConversationService:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
|
||||
def update_compression_metadata(
    self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
    """
    Record a new compression point on a conversation.

    Marks the conversation as compressed, stores the time of the latest
    compression and appends the point with $push/$slice so that only the
    most recent settings.COMPRESSION_MAX_HISTORY_POINTS entries are kept.
    Older points are redundant because each compression incorporates the
    previous ones.

    Args:
        conversation_id: Conversation ID
        compression_metadata: Compression point data

    Raises:
        Exception: Any error from building or running the update is
            logged and re-raised
    """
    try:
        # Everything is built inside the try so that an invalid id is
        # logged the same way as a failed write.
        selector = {"_id": ObjectId(conversation_id)}
        flags = {
            "compression_metadata.is_compressed": True,
            "compression_metadata.last_compression_at": compression_metadata.get(
                "timestamp"
            ),
        }
        # Negative $slice keeps the newest N array elements.
        bounded_append = {
            "$each": [compression_metadata],
            "$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
        }
        self.conversations_collection.update_one(
            selector,
            {
                "$set": flags,
                "$push": {"compression_metadata.compression_points": bounded_append},
            },
        )
        logger.info(
            f"Updated compression metadata for conversation {conversation_id}"
        )
    except Exception as e:
        logger.error(
            f"Error updating compression metadata: {str(e)}", exc_info=True
        )
        raise
|
||||
|
||||
def append_compression_message(
    self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
    """
    Push a synthetic query holding the compression summary onto the
    conversation, making the summary visible in the DB alongside normal
    queries.

    Does nothing when the metadata carries no summary. Errors are logged
    and not re-raised.

    Args:
        conversation_id: Conversation ID
        compression_metadata: Compression point data
    """
    try:
        summary_text = compression_metadata.get("compressed_summary", "")
        if not summary_text:
            return

        # Falls back to the current UTC time when no timestamp was recorded.
        recorded_at = compression_metadata.get("timestamp", datetime.now(timezone.utc))

        synthetic_query = {
            "prompt": "[Context Compression Summary]",
            "response": summary_text,
            "thought": "",
            "sources": [],
            "tool_calls": [],
            "timestamp": recorded_at,
            "attachments": [],
            "model_id": compression_metadata.get("model_used"),
        }
        self.conversations_collection.update_one(
            {"_id": ObjectId(conversation_id)},
            {"$push": {"queries": synthetic_query}},
        )
        logger.info(f"Appended compression summary to conversation {conversation_id}")
    except Exception as e:
        logger.error(
            f"Error appending compression summary: {str(e)}", exc_info=True
        )
|
||||
|
||||
def get_compression_metadata(
    self, conversation_id: str
) -> Optional[Dict[str, Any]]:
    """
    Fetch the compression metadata stored on a conversation.

    Args:
        conversation_id: Conversation ID

    Returns:
        The stored compression metadata, or None when the conversation
        does not exist, has no metadata, or the lookup fails
    """
    try:
        # The projection limits the fetched document to the metadata field.
        document = self.conversations_collection.find_one(
            {"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
        )
        if not document:
            return None
        return document.get("compression_metadata")
    except Exception as e:
        logger.error(
            f"Error getting compression metadata: {str(e)}", exc_info=True
        )
        return None
|
||||
|
||||
@@ -10,8 +10,6 @@ from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
@@ -92,21 +90,17 @@ class StreamProcessor:
|
||||
self.shared_token = None
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
)
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
self.compressed_summary: Optional[str] = None
|
||||
self.compressed_summary_tokens: int = 0
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._configure_agent()
|
||||
self._validate_and_set_model()
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
self._process_attachments()
|
||||
|
||||
@@ -118,69 +112,14 @@ class StreamProcessor:
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Check if compression is enabled and needed
|
||||
if settings.ENABLE_CONVERSATION_COMPRESSION:
|
||||
self._handle_compression(conversation)
|
||||
else:
|
||||
# Original behavior - load all history
|
||||
self.history = [
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _handle_compression(self, conversation: Dict[str, Any]):
    """
    Load chat history through the compression orchestrator.

    On success the history is built from the orchestrator's recent
    queries and, when a compression was performed, the summary and its
    token count are stored on the instance. If the orchestrator reports
    failure or anything raises, the full stored history is used instead.

    Args:
        conversation: Full conversation document
    """

    def full_history():
        # Original behavior: every stored query as a prompt/response pair.
        return [
            {"prompt": query["prompt"], "response": query["response"]}
            for query in conversation.get("queries", [])
        ]

    try:
        # The orchestrator owns all compression decisions.
        result = self.compression_orchestrator.compress_if_needed(
            conversation_id=self.conversation_id,
            user_id=self.initial_user_id,
            model_id=self.model_id,
            decoded_token=self.decoded_token,
        )

        if result.success:
            # Keep the summary only when a compression actually happened.
            if result.compression_performed and result.compressed_summary:
                self.compressed_summary = result.compressed_summary
                self.compressed_summary_tokens = TokenCounter.count_message_tokens(
                    [{"content": result.compressed_summary}]
                )
                logger.info(
                    f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
                    f"+ {len(result.recent_queries)} recent messages"
                )

            # Build history from recent queries
            self.history = result.as_history()
        else:
            logger.error(f"Compression failed: {result.error}, using full history")
            self.history = full_history()

    except Exception as e:
        logger.error(
            f"Error handling compression, falling back to standard history: {str(e)}",
            exc_info=True,
        )
        # Fallback to original behavior
        self.history = full_history()
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
"""Process any attachments in the request"""
|
||||
@@ -223,20 +162,11 @@ class StreamProcessor:
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (
|
||||
f" and {len(available_models) - 5} more"
|
||||
if len(available_models) > 5
|
||||
else ""
|
||||
)
|
||||
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
# Check if agent has a default model configured
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(agent_default_model):
|
||||
self.model_id = agent_default_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
@@ -309,10 +239,6 @@ class StreamProcessor:
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
|
||||
# Preserve model configuration from agent
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
@@ -365,16 +291,12 @@ class StreamProcessor:
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": api_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
"default_model_id": data_key.get("default_model_id", ""),
|
||||
}
|
||||
)
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("workflow"):
|
||||
self.agent_config["workflow"] = data_key["workflow"]
|
||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
@@ -393,7 +315,6 @@ class StreamProcessor:
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": self.agent_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
"default_model_id": data_key.get("default_model_id", ""),
|
||||
}
|
||||
)
|
||||
self.decoded_token = (
|
||||
@@ -403,9 +324,6 @@ class StreamProcessor:
|
||||
)
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("workflow"):
|
||||
self.agent_config["workflow"] = data_key["workflow"]
|
||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
@@ -417,32 +335,26 @@ class StreamProcessor:
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
else:
|
||||
agent_type = settings.AGENT_NAME
|
||||
if self.data.get("workflow") and isinstance(
|
||||
self.data.get("workflow"), dict
|
||||
):
|
||||
agent_type = "workflow"
|
||||
self.agent_config["workflow"] = self.data["workflow"]
|
||||
if isinstance(self.decoded_token, dict):
|
||||
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
|
||||
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": self.data.get("prompt_id", "default"),
|
||||
"agent_type": agent_type,
|
||||
"agent_type": settings.AGENT_NAME,
|
||||
"user_api_key": None,
|
||||
"json_schema": None,
|
||||
"default_model_id": "",
|
||||
}
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
model_id=self.model_id, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"history_token_limit": history_token_limit,
|
||||
}
|
||||
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
@@ -746,38 +658,17 @@ class StreamProcessor:
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
agent_type = self.agent_config["agent_type"]
|
||||
|
||||
# Base agent kwargs
|
||||
agent_kwargs = {
|
||||
"endpoint": "stream",
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
"model_id": self.model_id,
|
||||
"api_key": system_api_key,
|
||||
"user_api_key": self.agent_config["user_api_key"],
|
||||
"prompt": rendered_prompt,
|
||||
"chat_history": self.history,
|
||||
"retrieved_docs": self.retrieved_docs,
|
||||
"decoded_token": self.decoded_token,
|
||||
"attachments": self.attachments,
|
||||
"json_schema": self.agent_config.get("json_schema"),
|
||||
"compressed_summary": self.compressed_summary,
|
||||
}
|
||||
|
||||
# Workflow-specific kwargs for workflow agents
|
||||
if agent_type == "workflow":
|
||||
workflow_config = self.agent_config.get("workflow")
|
||||
if isinstance(workflow_config, str):
|
||||
agent_kwargs["workflow_id"] = workflow_config
|
||||
elif isinstance(workflow_config, dict):
|
||||
agent_kwargs["workflow"] = workflow_config
|
||||
workflow_owner = self.agent_config.get("workflow_owner")
|
||||
if workflow_owner:
|
||||
agent_kwargs["workflow_owner"] = workflow_owner
|
||||
|
||||
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
|
||||
|
||||
agent.conversation_id = self.conversation_id
|
||||
agent.initial_user_id = self.initial_user_id
|
||||
|
||||
return agent
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=self.model_id,
|
||||
api_key=system_api_key,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
retrieved_docs=self.retrieved_docs,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import base64
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
@@ -37,18 +35,6 @@ connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
# Fixed callback status path to prevent open redirect
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"


def build_callback_redirect(params: dict) -> str:
    """Return the callback-status URL with ``params`` as its query string.

    The path component is always the module constant ``CALLBACK_STATUS_PATH``;
    only the query string varies, and every key and value in it is
    percent-encoded by ``urlencode``.
    """
    query_string = urlencode(params)
    return CALLBACK_STATUS_PATH + "?" + query_string
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
@@ -89,8 +75,8 @@ class ConnectorAuth(Resource):
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
@@ -107,37 +93,18 @@ class ConnectorsCallback(Resource):
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict.get("provider")
|
||||
state_object_id = state_dict.get("object_id")
|
||||
|
||||
# Validate provider
|
||||
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Invalid provider"
|
||||
}))
|
||||
provider = state_dict["provider"]
|
||||
state_object_id = state_dict["object_id"]
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "cancelled",
|
||||
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
@@ -146,19 +113,20 @@ class ConnectorsCallback(Resource):
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
sanitized_token_info = {
|
||||
"access_token": token_info.get("access_token"),
|
||||
"refresh_token": token_info.get("refresh_token"),
|
||||
"token_uri": token_info.get("token_uri"),
|
||||
"expiry": token_info.get("expiry")
|
||||
}
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
@@ -173,39 +141,26 @@ class ConnectorsCallback(Resource):
|
||||
)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "success",
|
||||
"message": "Authentication successful",
|
||||
"provider": provider,
|
||||
"session_token": session_token,
|
||||
"user_email": user_email
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
|
||||
}))
|
||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False),
|
||||
"search_query": fields.String(required=False)
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
def post(self):
|
||||
@@ -213,8 +168,11 @@ class ConnectorFiles(Resource):
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
folder_id = data.get('folder_id')
|
||||
limit = data.get('limit', 10)
|
||||
|
||||
page_token = data.get('page_token')
|
||||
search_query = data.get('search_query')
|
||||
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
@@ -227,12 +185,15 @@ class ConnectorFiles(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
|
||||
generic_keys = {'provider', 'session_token'}
|
||||
input_config = {
|
||||
k: v for k, v in data.items() if k not in generic_keys
|
||||
'limit': limit,
|
||||
'list_only': True,
|
||||
'session_token': session_token,
|
||||
'folder_id': folder_id,
|
||||
'page_token': page_token
|
||||
}
|
||||
input_config['list_only'] = True
|
||||
if search_query:
|
||||
input_config['search_query'] = search_query
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
|
||||
@@ -267,8 +228,8 @@ class ConnectorFiles(Resource):
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
|
||||
current_app.logger.error(f"Error loading connector files: {e}")
|
||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
@@ -299,7 +260,12 @@ class ConnectorValidateSession(Resource):
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
sanitized_token_info = {
|
||||
"access_token": refreshed_token_info.get("access_token"),
|
||||
"refresh_token": refreshed_token_info.get("refresh_token"),
|
||||
"token_uri": refreshed_token_info.get("token_uri"),
|
||||
"expiry": refreshed_token_info.get("expiry")
|
||||
}
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
@@ -316,21 +282,15 @@ class ConnectorValidateSession(Resource):
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
|
||||
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
|
||||
|
||||
response_data = {
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token'),
|
||||
**provider_extras,
|
||||
}
|
||||
|
||||
return make_response(jsonify(response_data), 200)
|
||||
"access_token": token_info.get('access_token')
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
|
||||
current_app.logger.error(f"Error validating connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/disconnect")
|
||||
@@ -351,8 +311,8 @@ class ConnectorDisconnect(Resource):
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sync")
|
||||
@@ -458,8 +418,8 @@ class ConnectorSync(Resource):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Failed to sync connector source"
|
||||
}),
|
||||
"error": str(err)
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
@@ -470,32 +430,17 @@ class ConnectorCallbackStatus(Resource):
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
# Validate and sanitize status to a known value
|
||||
status_raw = request.args.get('status', 'error')
|
||||
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
|
||||
|
||||
# Escape all user-controlled values for HTML context
|
||||
message = html.escape(request.args.get('message', ''))
|
||||
provider_raw = request.args.get('provider', 'connector')
|
||||
provider = html.escape(provider_raw.replace('_', ' ').title())
|
||||
status = request.args.get('status', 'error')
|
||||
message = request.args.get('message', '')
|
||||
provider = request.args.get('provider', 'connector')
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = html.escape(request.args.get('user_email', ''))
|
||||
|
||||
def safe_js_string(value: str) -> str:
|
||||
"""Safely encode a string for embedding in inline JavaScript."""
|
||||
js_encoded = json.dumps(value)
|
||||
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
|
||||
|
||||
js_status = safe_js_string(status)
|
||||
js_session_token = safe_js_string(session_token)
|
||||
js_user_email = safe_js_string(user_email)
|
||||
js_provider_type = safe_js_string(provider_raw)
|
||||
|
||||
user_email = request.args.get('user_email', '')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider} Authentication</title>
|
||||
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
@@ -505,14 +450,13 @@ class ConnectorCallbackStatus(Resource):
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = {js_status};
|
||||
const sessionToken = {js_session_token};
|
||||
const userEmail = {js_user_email};
|
||||
const providerType = {js_provider_type};
|
||||
const status = "{status}";
|
||||
const sessionToken = "{session_token}";
|
||||
const userEmail = "{user_email}";
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: providerType + '_auth_success',
|
||||
type: '{provider}_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
@@ -526,17 +470,17 @@ class ConnectorCallbackStatus(Resource):
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider} Authentication</h2>
|
||||
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory, jsonify
|
||||
from flask import Blueprint, request, send_from_directory
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
import logging
|
||||
@@ -24,16 +24,6 @@ current_dir = os.path.dirname(
|
||||
internal = Blueprint("internal", __name__)
|
||||
|
||||
|
||||
@internal.before_request
def verify_internal_key():
    """Reject internal-API requests that do not present the configured key.

    Registered as a ``before_request`` hook on the ``internal`` blueprint.
    When ``settings.INTERNAL_KEY`` is falsy the check is skipped and the
    request proceeds unauthenticated.

    Returns:
        ``None`` to let the request continue, or a ``(response, 401)`` tuple
        when the ``X-Internal-Key`` header is missing or does not match.
    """
    import hmac  # function-local so the file's top-level imports stay untouched

    expected_key = settings.INTERNAL_KEY
    if not expected_key:
        # NOTE(review): an unset key disables the check for every internal
        # endpoint; confirm this fail-open default is intended.
        return None

    supplied_key = request.headers.get("X-Internal-Key")
    # compare_digest takes time independent of where the first mismatch is,
    # so response timing does not reveal how much of the key was guessed.
    # Assumes settings.INTERNAL_KEY is a str; verify against the settings model.
    keys_match = bool(supplied_key) and hmac.compare_digest(
        supplied_key.encode("utf-8"), expected_key.encode("utf-8")
    )
    if not keys_match:
        logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
        return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
    return None
|
||||
|
||||
|
||||
@internal.route("/api/download", methods=["get"])
|
||||
def download_file():
|
||||
user = secure_filename(request.args.get("user"))
|
||||
@@ -61,7 +51,6 @@ def upload_index_files():
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
file_name_map = request.form.get("file_name_map")
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
@@ -71,14 +60,6 @@ def upload_index_files():
|
||||
directory_structure = {}
|
||||
else:
|
||||
directory_structure = {}
|
||||
if file_name_map:
|
||||
try:
|
||||
file_name_map = json.loads(file_name_map)
|
||||
except Exception:
|
||||
logger.error("Error parsing file_name_map")
|
||||
file_name_map = None
|
||||
else:
|
||||
file_name_map = None
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{id}"
|
||||
@@ -106,43 +87,41 @@ def upload_index_files():
|
||||
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
update_fields = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(id)},
|
||||
{"$set": update_fields},
|
||||
{
|
||||
"$set": {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
insert_doc = {
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
insert_doc["file_name_map"] = file_name_map
|
||||
sources_collection.insert_one(insert_doc)
|
||||
sources_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -3,6 +3,5 @@
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
from .folders import agents_folders_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
||||
|
||||
@@ -1,261 +0,0 @@
|
||||
"""
|
||||
Agent folders management routes.
|
||||
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from bson.objectid import ObjectId
|
||||
from flask import jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/")
class AgentFolders(Resource):
    """Collection endpoint for the caller's agent folders (list and create)."""

    @api.doc(description="Get all folders for the user")
    def get(self):
        """List every folder stored for the authenticated user.

        Returns:
            401 when the request carries no decoded token; 400 with the
            exception text when the lookup or serialization raises; otherwise
            200 with ``{"folders": [...]}``. Timestamps are rendered with
            ``isoformat()`` or reported as ``None`` when absent.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            folders = list(agent_folders_collection.find({"user": user}))
            result = [
                {
                    "id": str(f["_id"]),
                    "name": f["name"],
                    "parent_id": f.get("parent_id"),
                    "created_at": f["created_at"].isoformat() if f.get("created_at") else None,
                    "updated_at": f["updated_at"].isoformat() if f.get("updated_at") else None,
                }
                for f in folders
            ]
            return make_response(jsonify({"folders": result}), 200)
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)

    @api.doc(description="Create a new folder")
    @api.expect(
        api.model(
            "CreateFolder",
            {
                "name": fields.String(required=True, description="Folder name"),
                "parent_id": fields.String(required=False, description="Parent folder ID"),
            },
        )
    )
    def post(self):
        """Create a folder for the authenticated user.

        The JSON body must contain a non-empty ``name``. When ``parent_id`` is
        supplied it must identify a folder owned by the same user.

        Returns:
            401 without a decoded token; 400 when ``name`` is missing or when
            the lookup/insert raises (including a malformed ``parent_id``);
            404 when the parent folder is not found; otherwise 201 with the
            new folder's ``id``, ``name`` and ``parent_id``.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data or not data.get("name"):
            return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)

        parent_id = data.get("parent_id")

        try:
            # The parent lookup lives inside the try so that a malformed id
            # (ObjectId raises) produces the same 400 response as other failures.
            if parent_id:
                parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
                if not parent:
                    return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)

            now = datetime.datetime.now(datetime.timezone.utc)
            folder = {
                "user": user,
                "name": data["name"],
                "parent_id": parent_id,
                "created_at": now,
                "updated_at": now,
            }
            result = agent_folders_collection.insert_one(folder)
            return make_response(
                jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
                201,
            )
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
    """Item endpoint for one agent folder: read, update, delete."""

    @api.doc(description="Get a specific folder with its agents")
    def get(self, folder_id):
        """Return the folder together with its agents and direct subfolders.

        Returns:
            401 without a decoded token; 404 when no folder with this id is
            stored for the user; 400 with the exception text when a lookup
            raises (including a malformed ``folder_id``); otherwise 200 with
            ``id``, ``name``, ``parent_id``, ``agents`` and ``subfolders``.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
            if not folder:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)

            # Agents and child folders are matched on the id as a string,
            # not as an ObjectId.
            agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
            agents_list = [
                {"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
                for a in agents
            ]
            subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
            subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]

            return make_response(
                jsonify({
                    "id": str(folder["_id"]),
                    "name": folder["name"],
                    "parent_id": folder.get("parent_id"),
                    "agents": agents_list,
                    "subfolders": subfolders_list,
                }),
                200,
            )
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)

    @api.doc(description="Update a folder")
    def put(self, folder_id):
        """Rename the folder and/or move it under another parent.

        Only the ``name`` and ``parent_id`` keys of the JSON body are read.
        A non-empty ``parent_id`` must identify a folder owned by the same
        user, matching the check performed when a folder is created.

        Returns:
            401 without a decoded token; 400 when the body is empty, when the
            folder is named as its own parent, or when an operation raises;
            404 when the folder or the requested parent is not found;
            otherwise 200.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data:
            return make_response(jsonify({"success": False, "message": "No data provided"}), 400)

        try:
            update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
            if "name" in data:
                update_fields["name"] = data["name"]
            if "parent_id" in data:
                new_parent_id = data["parent_id"]
                if new_parent_id == folder_id:
                    return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
                if new_parent_id:
                    # Refuse to point at a folder that does not exist for this user.
                    parent = agent_folders_collection.find_one({"_id": ObjectId(new_parent_id), "user": user})
                    if not parent:
                        return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
                # NOTE(review): longer cycles (A -> B -> A) are still not detected here.
                update_fields["parent_id"] = new_parent_id

            result = agent_folders_collection.update_one(
                {"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
            )
            if result.matched_count == 0:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
            return make_response(jsonify({"success": True}), 200)
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)

    @api.doc(description="Delete a folder")
    def delete(self, folder_id):
        """Delete the folder after detaching its agents and child folders.

        The user's agents and folders that reference ``folder_id`` have the
        reference unset first; the folder document is removed afterwards.

        Returns:
            401 without a decoded token; 404 when nothing was deleted; 400
            with the exception text when an operation raises; otherwise 200.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            agents_collection.update_many(
                {"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
            )
            agent_folders_collection.update_many(
                {"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
            )
            result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
            if result.deleted_count == 0:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
            return make_response(jsonify({"success": True}), 200)
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
    """Endpoint that assigns one agent to a folder or clears its folder."""

    @api.doc(description="Move an agent to a folder or remove from folder")
    @api.expect(
        api.model(
            "MoveAgent",
            {
                "agent_id": fields.String(required=True, description="Agent ID to move"),
                "folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
            },
        )
    )
    def post(self):
        """Set or unset ``folder_id`` on an agent owned by the caller.

        Returns:
            401 without a decoded token; 400 when ``agent_id`` is missing or
            an operation raises (including malformed ids); 404 when the agent
            or the target folder is not found for this user; otherwise 200.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data or not data.get("agent_id"):
            return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)

        agent_id = data["agent_id"]
        folder_id = data.get("folder_id")

        try:
            # One owner-scoped filter serves both the existence check and the
            # write, so the update can only touch the caller's agent.
            agent_filter = {"_id": ObjectId(agent_id), "user": user}
            agent = agents_collection.find_one(agent_filter)
            if not agent:
                return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)

            if folder_id:
                folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
                if not folder:
                    return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
                update = {"$set": {"folder_id": folder_id}}
            else:
                # A missing or empty folder_id removes the agent from its folder.
                update = {"$unset": {"folder_id": ""}}
            agents_collection.update_one(agent_filter, update)

            return make_response(jsonify({"success": True}), 200)
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
    """Endpoint that assigns several agents to a folder in one request."""

    @api.doc(description="Move multiple agents to a folder")
    @api.expect(
        api.model(
            "BulkMoveAgents",
            {
                "agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
                "folder_id": fields.String(required=False, description="Target folder ID"),
            },
        )
    )
    def post(self):
        """Set or unset ``folder_id`` on every listed agent owned by the caller.

        Returns:
            401 without a decoded token; 400 when ``agent_ids`` is missing or
            an operation raises; 404 when the target folder is not found for
            this user; otherwise 200.
        """
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        payload = request.get_json()
        if not (payload and payload.get("agent_ids")):
            return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)

        agent_ids = payload["agent_ids"]
        folder_id = payload.get("folder_id")

        try:
            if folder_id:
                # The destination must exist for this user before anything is written.
                target = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
                if not target:
                    return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
                change = {"$set": {"folder_id": folder_id}}
            else:
                change = {"$unset": {"folder_id": ""}}

            object_ids = [ObjectId(aid) for aid in agent_ids]
            agents_collection.update_many({"_id": {"$in": object_ids}, "user": user}, change)
            return make_response(jsonify({"success": True}), 200)
        except Exception as e:
            return make_response(jsonify({"success": False, "message": str(e)}), 400)
|
||||
@@ -11,7 +11,6 @@ from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
@@ -19,9 +18,6 @@ from application.api.user.base import (
|
||||
resolve_tool_details,
|
||||
storage,
|
||||
users_collection,
|
||||
workflow_edges_collection,
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.utils import (
|
||||
@@ -34,189 +30,6 @@ from application.utils import (
|
||||
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
|
||||
|
||||
|
||||
# Field groups shared by every agent type, listed in stored-document order.
_AGENT_IDENTITY_FIELDS = ["user", "name", "description", "agent_type", "status", "key"]
_AGENT_COMMON_TAIL_FIELDS = [
    "folder_id",
    "limited_token_mode",
    "token_limit",
    "limited_request_mode",
    "request_limit",
    "createdAt",
    "updatedAt",
    "lastUsedAt",
]
# Fields that only exist on classic-style (retrieval) agents.
_CLASSIC_ONLY_FIELDS = [
    "image",
    "source",
    "sources",
    "chunks",
    "retriever",
    "prompt_id",
    "tools",
    "json_schema",
    "models",
    "default_model_id",
]

# Per-agent-type validation rules and the whitelist of stored fields.
#   required_* : keys that must be present for a published / draft agent
#   validate_* : keys whose values are additionally validated
#   require_source : classic agents need a "source" or "sources" to publish
#   fields : the only keys build_agent_document() keeps in the document
AGENT_TYPE_SCHEMAS = {
    "classic": {
        "required_published": ["name", "description", "chunks", "retriever", "prompt_id"],
        "required_draft": ["name"],
        "validate_published": ["name", "description", "prompt_id"],
        "validate_draft": [],
        "require_source": True,
        "fields": _AGENT_IDENTITY_FIELDS + _CLASSIC_ONLY_FIELDS + _AGENT_COMMON_TAIL_FIELDS,
    },
    "workflow": {
        "required_published": ["name", "workflow"],
        "required_draft": ["name"],
        "validate_published": ["name", "workflow"],
        "validate_draft": [],
        "fields": _AGENT_IDENTITY_FIELDS + ["workflow"] + _AGENT_COMMON_TAIL_FIELDS,
    },
}

# "react" and "openai" agents share the classic schema object itself (an alias,
# not a copy).
for _alias in ("react", "openai"):
    AGENT_TYPE_SCHEMAS[_alias] = AGENT_TYPE_SCHEMAS["classic"]
|
||||
|
||||
|
||||
def normalize_workflow_reference(workflow_value):
    """Reduce a workflow reference from a form/JSON payload to a bare ID.

    Accepted shapes:
      * ``None``            -> ``None``
      * dict                -> first truthy of ``id`` / ``_id`` / ``workflow_id``
      * str                 -> stripped; if it is JSON encoding a string or a
                               dict, the ID is taken from the decoded value,
                               otherwise the stripped text itself is returned
      * anything else       -> ``str(workflow_value)``
    """

    def id_from_mapping(mapping):
        # Same precedence for inline dicts and JSON-decoded dicts.
        return mapping.get("id") or mapping.get("_id") or mapping.get("workflow_id")

    if workflow_value is None:
        return None
    if isinstance(workflow_value, dict):
        return id_from_mapping(workflow_value)
    if not isinstance(workflow_value, str):
        return str(workflow_value)

    stripped = workflow_value.strip()
    if not stripped:
        return ""

    try:
        decoded = json.loads(stripped)
    except json.JSONDecodeError:
        # Not JSON at all (e.g. a plain ID string): use the text as-is.
        return stripped

    if isinstance(decoded, str):
        return decoded.strip()
    if isinstance(decoded, dict):
        return id_from_mapping(decoded)
    # Valid JSON of another type (number, list, ...): keep the original text.
    return stripped
|
||||
|
||||
|
||||
def validate_workflow_access(workflow_value, user, required=False):
    """Resolve a workflow reference and confirm it belongs to ``user``.

    Returns a ``(workflow_id, error_response)`` pair. Exactly one side is
    meaningful: on success ``error_response`` is ``None``; on failure
    ``workflow_id`` is ``None`` and ``error_response`` is a ready-to-return
    Flask response (400 for a missing-but-required or malformed ID, 404 when
    no workflow with that ID exists for this user). When the reference is
    empty and not required, ``(None, None)`` is returned.
    """

    def rejected(message, status):
        return None, make_response(
            jsonify({"success": False, "message": message}), status
        )

    workflow_id = normalize_workflow_reference(workflow_value)

    if not workflow_id:
        if not required:
            return None, None
        return rejected("Workflow is required", 400)

    if not ObjectId.is_valid(workflow_id):
        return rejected("Invalid workflow ID format", 400)

    # The lookup is scoped by owner, so another user's workflow reads as
    # "not found".
    owned = workflows_collection.find_one(
        {"_id": ObjectId(workflow_id), "user": user}
    )
    if not owned:
        return rejected("Workflow not found", 404)

    return workflow_id, None
|
||||
|
||||
|
||||
def build_agent_document(
    data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
    """Build the agent document to store, shaped by the agent type's schema.

    Args:
        data: Request payload (form or JSON mapping) with the agent settings.
        user: Owner identifier written to the document.
        key: API key for the agent (may be an empty string).
        agent_type: Requested type; empty or unknown values fall back to
            ``"classic"``.
        image_url: Stored image location for non-workflow agents.
        source_field: Single source reference for non-workflow agents.
        sources_list: Multiple source references for non-workflow agents.

    Returns:
        A dict containing only the keys listed in the schema's ``"fields"``.

    Raises:
        ValueError: If a supplied limit is non-blank but not an integer.
    """

    def as_flag(raw):
        # String payloads enable a flag only with the literal "True"; any
        # other type is judged by plain truthiness.
        if isinstance(raw, str):
            return raw == "True"
        return bool(raw)

    def as_limit(name):
        # A missing, null or blank limit falls back to the configured default
        # rather than failing inside int().
        raw = data.get(name)
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            raw = settings.DEFAULT_AGENT_LIMITS[name]
        return int(raw)

    if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
        agent_type = "classic"
    schema = AGENT_TYPE_SCHEMAS[agent_type]
    allowed_fields = set(schema["fields"])

    # One timestamp so createdAt and updatedAt are identical on creation.
    now = datetime.datetime.now(datetime.timezone.utc)
    base_doc = {
        "user": user,
        "name": data.get("name"),
        "description": data.get("description", ""),
        "agent_type": agent_type,
        "status": data.get("status"),
        "key": key,
        "folder_id": data.get("folder_id"),
        "createdAt": now,
        "updatedAt": now,
        "lastUsedAt": None,
    }

    if agent_type == "workflow":
        base_doc["workflow"] = data.get("workflow")
    else:
        base_doc.update(
            {
                "image": image_url or "",
                "source": source_field or "",
                "sources": sources_list or [],
                "chunks": data.get("chunks", ""),
                "retriever": data.get("retriever", ""),
                "prompt_id": data.get("prompt_id", ""),
                "tools": data.get("tools", []),
                "json_schema": data.get("json_schema"),
                "models": data.get("models", []),
                "default_model_id": data.get("default_model_id", ""),
            }
        )

    # Limits are parsed only when the schema stores them, so a bad value in an
    # irrelevant field cannot fail the build.
    if "limited_token_mode" in allowed_fields:
        base_doc["limited_token_mode"] = as_flag(data.get("limited_token_mode"))
    if "token_limit" in allowed_fields:
        base_doc["token_limit"] = as_limit("token_limit")
    if "limited_request_mode" in allowed_fields:
        base_doc["limited_request_mode"] = as_flag(data.get("limited_request_mode"))
    if "request_limit" in allowed_fields:
        base_doc["request_limit"] = as_limit("request_limit")

    return {k: v for k, v in base_doc.items() if k in allowed_fields}
|
||||
|
||||
|
||||
@agents_ns.route("/get_agent")
|
||||
class GetAgent(Resource):
|
||||
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
|
||||
@@ -254,7 +67,7 @@ class GetAgent(Resource):
|
||||
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
|
||||
or source_ref == "default"
|
||||
],
|
||||
"chunks": agent.get("chunks", "2"),
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
@@ -284,8 +97,6 @@ class GetAgent(Resource):
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
"folder_id": agent.get("folder_id"),
|
||||
"workflow": agent.get("workflow"),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
@@ -335,7 +146,7 @@ class GetAgents(Resource):
|
||||
isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
)
|
||||
],
|
||||
"chunks": agent.get("chunks", "2"),
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
@@ -365,13 +176,9 @@ class GetAgents(Resource):
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
"folder_id": agent.get("folder_id"),
|
||||
"workflow": agent.get("workflow"),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent
|
||||
or "retriever" in agent
|
||||
or agent.get("agent_type") == "workflow"
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
|
||||
@@ -399,22 +206,16 @@ class CreateAgent(Resource):
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=False, description="Chunks count"),
|
||||
"retriever": fields.String(required=False, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=False, description="Prompt ID"),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(
|
||||
required=False,
|
||||
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"workflow": fields.String(
|
||||
required=False, description="Workflow ID for workflow-type agents"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
@@ -441,9 +242,6 @@ class CreateAgent(Resource):
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
"folder_id": fields.String(
|
||||
required=False, description="Folder ID to organize the agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -509,10 +307,9 @@ class CreateAgent(Resource):
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Invalid JSON schema: {e}")
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid JSON schema format"}
|
||||
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
@@ -526,34 +323,18 @@ class CreateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
agent_type = data.get("agent_type", "")
|
||||
# Default to classic schema for empty or unknown agent types
|
||||
|
||||
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
|
||||
schema = AGENT_TYPE_SCHEMAS["classic"]
|
||||
# Set agent_type to classic if it was empty
|
||||
|
||||
if not agent_type:
|
||||
agent_type = "classic"
|
||||
else:
|
||||
schema = AGENT_TYPE_SCHEMAS[agent_type]
|
||||
is_published = data.get("status") == "published"
|
||||
if agent_type == "workflow":
|
||||
workflow_id, workflow_error = validate_workflow_access(
|
||||
data.get("workflow"), user, required=is_published
|
||||
)
|
||||
if workflow_error:
|
||||
return workflow_error
|
||||
data["workflow"] = workflow_id
|
||||
if data.get("status") == "published":
|
||||
required_fields = schema["required_published"]
|
||||
validate_fields = schema["validate_published"]
|
||||
required_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"agent_type",
|
||||
]
|
||||
# Require either source or sources (but not both)
|
||||
|
||||
if (
|
||||
schema.get("require_source")
|
||||
and not data.get("source")
|
||||
and not data.get("sources")
|
||||
):
|
||||
if not data.get("source") and not data.get("sources"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -563,9 +344,10 @@ class CreateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
validate_fields = ["name", "description", "prompt_id", "agent_type"]
|
||||
else:
|
||||
required_fields = schema["required_draft"]
|
||||
validate_fields = schema["validate_draft"]
|
||||
required_fields = ["name"]
|
||||
validate_fields = []
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
invalid_fields = validate_required_fields(data, validate_fields)
|
||||
if missing_fields:
|
||||
@@ -577,50 +359,74 @@ class CreateAgent(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
||||
)
|
||||
folder_id = data.get("folder_id")
|
||||
if folder_id:
|
||||
if not ObjectId.is_valid(folder_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid folder ID format"}),
|
||||
400,
|
||||
)
|
||||
folder = agent_folders_collection.find_one(
|
||||
{"_id": ObjectId(folder_id), "user": user}
|
||||
)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}), 404
|
||||
)
|
||||
try:
|
||||
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
|
||||
|
||||
sources_list = []
|
||||
source_field = ""
|
||||
if data.get("sources") and len(data.get("sources", [])) > 0:
|
||||
for source_id in data.get("sources", []):
|
||||
if source_id == "default":
|
||||
sources_list.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
sources_list.append(DBRef("sources", ObjectId(source_id)))
|
||||
source_field = ""
|
||||
else:
|
||||
source_value = data.get("source", "")
|
||||
if source_value == "default":
|
||||
source_field = "default"
|
||||
elif ObjectId.is_valid(source_value):
|
||||
source_field = DBRef("sources", ObjectId(source_value))
|
||||
new_agent = build_agent_document(
|
||||
data, user, key, agent_type, image_url, source_field, sources_list
|
||||
)
|
||||
|
||||
if agent_type != "workflow":
|
||||
if new_agent.get("chunks") == "":
|
||||
new_agent["chunks"] = "2"
|
||||
if (
|
||||
new_agent.get("source") == ""
|
||||
and new_agent.get("retriever") == ""
|
||||
and not new_agent.get("sources")
|
||||
):
|
||||
new_agent["retriever"] = "classic"
|
||||
else:
|
||||
source_field = ""
|
||||
new_agent = {
|
||||
"user": user,
|
||||
"name": data.get("name"),
|
||||
"description": data.get("description", ""),
|
||||
"image": image_url,
|
||||
"source": source_field,
|
||||
"sources": sources_list,
|
||||
"chunks": data.get("chunks", ""),
|
||||
"retriever": data.get("retriever", ""),
|
||||
"prompt_id": data.get("prompt_id", ""),
|
||||
"tools": data.get("tools", []),
|
||||
"agent_type": data.get("agent_type", ""),
|
||||
"status": data.get("status"),
|
||||
"json_schema": data.get("json_schema"),
|
||||
"limited_token_mode": (
|
||||
data.get("limited_token_mode") == "True"
|
||||
if isinstance(data.get("limited_token_mode"), str)
|
||||
else bool(data.get("limited_token_mode", False))
|
||||
),
|
||||
"token_limit": int(
|
||||
data.get(
|
||||
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
)
|
||||
),
|
||||
"limited_request_mode": (
|
||||
data.get("limited_request_mode") == "True"
|
||||
if isinstance(data.get("limited_request_mode"), str)
|
||||
else bool(data.get("limited_request_mode", False))
|
||||
),
|
||||
"request_limit": int(
|
||||
data.get(
|
||||
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
)
|
||||
),
|
||||
"createdAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
"models": data.get("models", []),
|
||||
"default_model_id": data.get("default_model_id", ""),
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
if (
|
||||
new_agent["source"] == ""
|
||||
and new_agent["retriever"] == ""
|
||||
and not new_agent["sources"]
|
||||
):
|
||||
new_agent["retriever"] = "classic"
|
||||
resp = agents_collection.insert_one(new_agent)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
@@ -649,22 +455,16 @@ class UpdateAgent(Resource):
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=False, description="Chunks count"),
|
||||
"retriever": fields.String(required=False, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=False, description="Prompt ID"),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(
|
||||
required=False,
|
||||
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"workflow": fields.String(
|
||||
required=False, description="Workflow ID for workflow-type agents"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
@@ -691,9 +491,6 @@ class UpdateAgent(Resource):
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
"folder_id": fields.String(
|
||||
required=False, description="Folder ID to organize the agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -787,8 +584,6 @@ class UpdateAgent(Resource):
|
||||
"request_limit",
|
||||
"models",
|
||||
"default_model_id",
|
||||
"folder_id",
|
||||
"workflow",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
@@ -945,10 +740,10 @@ class UpdateAgent(Resource):
|
||||
)
|
||||
elif field == "token_limit":
|
||||
token_limit = data.get("token_limit")
|
||||
# Convert to int and store
|
||||
update_fields[field] = int(token_limit) if token_limit else 0
|
||||
|
||||
# Validate consistency with mode
|
||||
|
||||
if update_fields[field] > 0 and not data.get("limited_token_mode"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -973,42 +768,6 @@ class UpdateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
elif field == "folder_id":
|
||||
folder_id = data.get("folder_id")
|
||||
if folder_id:
|
||||
if not ObjectId.is_valid(folder_id):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid folder ID format",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
folder = agent_folders_collection.find_one(
|
||||
{"_id": ObjectId(folder_id), "user": user}
|
||||
)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
update_fields[field] = folder_id
|
||||
else:
|
||||
update_fields[field] = None
|
||||
elif field == "workflow":
|
||||
workflow_required = (
|
||||
data.get("status", existing_agent.get("status")) == "published"
|
||||
and data.get("agent_type", existing_agent.get("agent_type"))
|
||||
== "workflow"
|
||||
)
|
||||
workflow_id, workflow_error = validate_workflow_access(
|
||||
data.get("workflow"), user, required=workflow_required
|
||||
)
|
||||
if workflow_error:
|
||||
return workflow_error
|
||||
update_fields[field] = workflow_id
|
||||
else:
|
||||
value = data[field]
|
||||
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||
@@ -1037,82 +796,46 @@ class UpdateAgent(Resource):
|
||||
)
|
||||
newly_generated_key = None
|
||||
final_status = update_fields.get("status", existing_agent.get("status"))
|
||||
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
|
||||
|
||||
if final_status == "published":
|
||||
if agent_type == "workflow":
|
||||
required_published_fields = {
|
||||
"name": "Agent name",
|
||||
}
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in required_published_fields.items():
|
||||
final_value = update_fields.get(
|
||||
req_field, existing_agent.get(req_field)
|
||||
)
|
||||
if not final_value:
|
||||
missing_published_fields.append(field_label)
|
||||
required_published_fields = {
|
||||
"name": "Agent name",
|
||||
"description": "Agent description",
|
||||
"chunks": "Chunks count",
|
||||
"prompt_id": "Prompt",
|
||||
"agent_type": "Agent type",
|
||||
}
|
||||
|
||||
workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
|
||||
if not workflow_id:
|
||||
missing_published_fields.append("Workflow")
|
||||
elif not ObjectId.is_valid(workflow_id):
|
||||
missing_published_fields.append("Valid workflow")
|
||||
else:
|
||||
workflow = workflows_collection.find_one(
|
||||
{"_id": ObjectId(workflow_id), "user": user}
|
||||
)
|
||||
if not workflow:
|
||||
missing_published_fields.append("Workflow access")
|
||||
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
else:
|
||||
required_published_fields = {
|
||||
"name": "Agent name",
|
||||
"description": "Agent description",
|
||||
"chunks": "Chunks count",
|
||||
"prompt_id": "Prompt",
|
||||
"agent_type": "Agent type",
|
||||
}
|
||||
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in required_published_fields.items():
|
||||
final_value = update_fields.get(
|
||||
req_field, existing_agent.get(req_field)
|
||||
)
|
||||
if not final_value:
|
||||
missing_published_fields.append(field_label)
|
||||
source_val = update_fields.get("source", existing_agent.get("source"))
|
||||
sources_val = update_fields.get(
|
||||
"sources", existing_agent.get("sources", [])
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in required_published_fields.items():
|
||||
final_value = update_fields.get(
|
||||
req_field, existing_agent.get(req_field)
|
||||
)
|
||||
if not final_value:
|
||||
missing_published_fields.append(field_label)
|
||||
source_val = update_fields.get("source", existing_agent.get("source"))
|
||||
sources_val = update_fields.get(
|
||||
"sources", existing_agent.get("sources", [])
|
||||
)
|
||||
|
||||
has_valid_source = (
|
||||
isinstance(source_val, DBRef)
|
||||
or source_val == "default"
|
||||
or (isinstance(sources_val, list) and len(sources_val) > 0)
|
||||
has_valid_source = (
|
||||
isinstance(source_val, DBRef)
|
||||
or source_val == "default"
|
||||
or (isinstance(sources_val, list) and len(sources_val) > 0)
|
||||
)
|
||||
|
||||
if not has_valid_source:
|
||||
missing_published_fields.append("Source")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if not has_valid_source:
|
||||
missing_published_fields.append("Source")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if not existing_agent.get("key"):
|
||||
newly_generated_key = str(uuid.uuid4())
|
||||
update_fields["key"] = newly_generated_key
|
||||
@@ -1184,29 +907,6 @@ class DeleteAgent(Resource):
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
deleted_id = str(deleted_agent["_id"])
|
||||
|
||||
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
|
||||
"workflow"
|
||||
):
|
||||
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
|
||||
if workflow_id and ObjectId.is_valid(workflow_id):
|
||||
workflow_oid = ObjectId(workflow_id)
|
||||
owned_workflow = workflows_collection.find_one(
|
||||
{"_id": workflow_oid, "user": user}, {"_id": 1}
|
||||
)
|
||||
if owned_workflow:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
|
||||
else:
|
||||
current_app.logger.warning(
|
||||
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
|
||||
)
|
||||
elif workflow_id:
|
||||
current_app.logger.warning(
|
||||
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1315,16 +1015,19 @@ class AdoptAgent(Resource):
|
||||
def post(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
if not (agent_id := request.args.get("id")):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID required"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": "system"}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(jsonify({"status": "Not found"}), 404)
|
||||
|
||||
new_agent = agent.copy()
|
||||
new_agent.pop("_id", None)
|
||||
new_agent["user"] = decoded_token["sub"]
|
||||
@@ -1434,4 +1137,4 @@ class RemoveSharedAgent(Resource):
|
||||
current_app.logger.error(f"Error removing shared agent: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Server error"}), 500
|
||||
)
|
||||
)
|
||||
|
||||
@@ -255,8 +255,8 @@ class ShareAgent(Resource):
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
|
||||
@@ -99,8 +99,11 @@ class StoreAttachment(Resource):
|
||||
})
|
||||
|
||||
if not tasks:
|
||||
error_msg = "No valid files to upload"
|
||||
if errors:
|
||||
error_msg += f". Errors: {errors}"
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "No valid files to upload"}),
|
||||
jsonify({"status": "error", "message": error_msg, "errors": errors}),
|
||||
400,
|
||||
)
|
||||
|
||||
@@ -132,7 +135,7 @@ class StoreAttachment(Resource):
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@attachments_ns.route("/images/<path:image_path>")
|
||||
|
||||
@@ -31,17 +31,12 @@ sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
agent_folders_collection = db["agent_folders"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
workflow_runs_collection = db["workflow_runs"]
|
||||
workflows_collection = db["workflows"]
|
||||
workflow_nodes_collection = db["workflow_nodes"]
|
||||
workflow_edges_collection = db["workflow_edges"]
|
||||
|
||||
|
||||
try:
|
||||
@@ -51,25 +46,6 @@ try:
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
workflows_collection.create_index(
|
||||
[("user", 1)], name="workflow_user_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1)], name="node_workflow_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="node_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1)], name="edge_workflow_index", background=True
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="edge_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
|
||||
@@ -5,7 +5,8 @@ Main user API routes - registers all namespace modules.
|
||||
from flask import Blueprint
|
||||
|
||||
from application.api import api
|
||||
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
|
||||
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
|
||||
|
||||
from .analytics import analytics_ns
|
||||
from .attachments import attachments_ns
|
||||
from .conversations import conversations_ns
|
||||
@@ -14,7 +15,6 @@ from .prompts import prompts_ns
|
||||
from .sharing import sharing_ns
|
||||
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
|
||||
from .tools import tools_mcp_ns, tools_ns
|
||||
from .workflows import workflows_ns
|
||||
|
||||
|
||||
user = Blueprint("user", __name__)
|
||||
@@ -31,11 +31,10 @@ api.add_namespace(conversations_ns)
|
||||
# Models
|
||||
api.add_namespace(models_ns)
|
||||
|
||||
# Agents (main, sharing, webhooks, folders)
|
||||
# Agents (main, sharing, webhooks)
|
||||
api.add_namespace(agents_ns)
|
||||
api.add_namespace(agents_sharing_ns)
|
||||
api.add_namespace(agents_webhooks_ns)
|
||||
api.add_namespace(agents_folders_ns)
|
||||
|
||||
# Prompts
|
||||
api.add_namespace(prompts_ns)
|
||||
@@ -51,6 +50,3 @@ api.add_namespace(sources_upload_ns)
|
||||
# Tools (main, MCP)
|
||||
api.add_namespace(tools_ns)
|
||||
api.add_namespace(tools_mcp_ns)
|
||||
|
||||
# Workflows
|
||||
api.add_namespace(workflows_ns)
|
||||
|
||||
@@ -220,23 +220,8 @@ class GetPubliclySharedConversations(Resource):
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
):
|
||||
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
|
||||
# conversation_id is now stored as an ObjectId, not a DBRef
|
||||
conversation_id = shared["conversation_id"]
|
||||
if isinstance(conversation_id, DBRef):
|
||||
conversation_id = conversation_id.id
|
||||
elif isinstance(conversation_id, dict):
|
||||
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
|
||||
if "$id" in conversation_id:
|
||||
conv_id = conversation_id["$id"]
|
||||
# $id might be a dict like {"$oid": "..."} or a string
|
||||
if isinstance(conv_id, dict) and "$oid" in conv_id:
|
||||
conversation_id = ObjectId(conv_id["$oid"])
|
||||
else:
|
||||
conversation_id = ObjectId(conv_id)
|
||||
elif "_id" in conversation_id:
|
||||
conversation_id = ObjectId(conversation_id["_id"])
|
||||
elif isinstance(conversation_id, str):
|
||||
conversation_id = ObjectId(conversation_id)
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": conversation_id}
|
||||
)
|
||||
|
||||
@@ -55,14 +55,9 @@ class GetChunks(Resource):
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
chunk_file_path = metadata.get("file_path", "")
|
||||
# Check if the chunk matches the requested path
|
||||
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
|
||||
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
|
||||
source_match = chunk_source and chunk_source.endswith(path)
|
||||
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
|
||||
# Check if the chunk's source matches the requested path
|
||||
|
||||
if not (source_match or file_path_match):
|
||||
if not chunk_source or not chunk_source.endswith(path):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
@@ -21,21 +20,6 @@ sources_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
def _get_provider_from_remote_data(remote_data):
|
||||
if not remote_data:
|
||||
return None
|
||||
if isinstance(remote_data, dict):
|
||||
return remote_data.get("provider")
|
||||
if isinstance(remote_data, str):
|
||||
try:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(remote_data_obj, dict):
|
||||
return remote_data_obj.get("provider")
|
||||
return None
|
||||
|
||||
|
||||
@sources_ns.route("/sources")
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
@@ -57,7 +41,6 @@ class CombinedJson(Resource):
|
||||
|
||||
try:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
provider = _get_provider_from_remote_data(index.get("remote_data"))
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["_id"]),
|
||||
@@ -68,7 +51,6 @@ class CombinedJson(Resource):
|
||||
"tokens": index.get("tokens", ""),
|
||||
"retriever": index.get("retriever", "classic"),
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
@@ -125,7 +107,6 @@ class PaginatedSources(Resource):
|
||||
|
||||
paginated_docs = []
|
||||
for doc in documents:
|
||||
provider = _get_provider_from_remote_data(doc.get("remote_data"))
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
@@ -135,7 +116,6 @@ class PaginatedSources(Resource):
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
@@ -260,7 +240,7 @@ class ManageSync(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
required_fields = ["source_id", "sync_frequency"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
@@ -289,72 +269,6 @@ class ManageSync(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sync_source")
class SyncSource(Resource):
    """Endpoint that queues an immediate sync task for one of the caller's sources."""

    sync_source_model = api.model(
        "SyncSourceModel",
        {"source_id": fields.String(required=True, description="Source ID")},
    )

    @api.expect(sync_source_model)
    @api.doc(description="Trigger an immediate sync for a source")
    def post(self):
        """Validate the request and enqueue ``sync_source`` for the given source.

        Returns 401 without a decoded token, 400 for an invalid ID, a
        connector-type source, a source without ``remote_data`` or a failure
        to enqueue the task, 404 when the source is not found for this user,
        and 200 with the task ID on success.
        """

        def failure(message, status):
            # Shared shape of every "message"-carrying error in this handler.
            return make_response(
                jsonify({"success": False, "message": message}), status
            )

        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")

        payload = request.get_json()
        missing_fields = check_required_fields(payload, ["source_id"])
        if missing_fields:
            return missing_fields

        source_id = payload["source_id"]
        if not ObjectId.is_valid(source_id):
            return failure("Invalid source ID", 400)

        # Scope the lookup to the caller so other users' sources are invisible.
        doc = sources_collection.find_one(
            {"_id": ObjectId(source_id), "user": user}
        )
        if not doc:
            return failure("Source not found", 404)

        source_type = doc.get("type", "")
        if source_type.startswith("connector"):
            return failure(
                "Connector sources must be synced via /api/connectors/sync", 400
            )

        source_data = doc.get("remote_data")
        if not source_data:
            return failure("Source is not syncable", 400)

        try:
            task = sync_source.delay(
                source_data=source_data,
                job_name=doc.get("name", ""),
                user=user,
                loader=source_type,
                sync_frequency=doc.get("sync_frequency", "never"),
                retriever=doc.get("retriever", "classic"),
                doc_id=source_id,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error starting sync for source {source_id}: {err}",
                exc_info=True,
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
@@ -406,4 +320,4 @@ class DirectoryStructure(Resource):
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
@@ -64,27 +64,19 @@ class UploadFile(Resource):
|
||||
safe_user = safe_filename(user)
|
||||
dir_name = safe_filename(job_name)
|
||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||
file_name_map = {}
|
||||
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
for file in files:
|
||||
original_filename = os.path.basename(file.filename)
|
||||
original_filename = file.filename
|
||||
safe_file = safe_filename(original_filename)
|
||||
if original_filename:
|
||||
file_name_map[safe_file] = original_filename
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||
file.save(temp_file_path)
|
||||
|
||||
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
|
||||
# which are technically zip archives but should be processed as-is
|
||||
is_office_format = safe_file.lower().endswith(
|
||||
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
|
||||
)
|
||||
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
|
||||
if zipfile.is_zipfile(temp_file_path):
|
||||
try:
|
||||
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(path=temp_dir)
|
||||
@@ -145,7 +137,6 @@ class UploadFile(Resource):
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
file_name_map=file_name_map,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
@@ -191,8 +182,6 @@ class UploadRemote(Resource):
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] == "s3":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
@@ -345,14 +334,6 @@ class ManageSourceFiles(Resource):
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
file_name_map = source.get("file_name_map") or {}
|
||||
if isinstance(file_name_map, str):
|
||||
try:
|
||||
file_name_map = json.loads(file_name_map)
|
||||
except Exception:
|
||||
file_name_map = {}
|
||||
if not isinstance(file_name_map, dict):
|
||||
file_name_map = {}
|
||||
|
||||
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||
return make_response(
|
||||
@@ -374,35 +355,19 @@ class ManageSourceFiles(Resource):
|
||||
400,
|
||||
)
|
||||
added_files = []
|
||||
map_updated = False
|
||||
|
||||
target_dir = source_file_path
|
||||
if parent_dir:
|
||||
target_dir = f"{source_file_path}/{parent_dir}"
|
||||
for file in files:
|
||||
if file.filename:
|
||||
original_filename = os.path.basename(file.filename)
|
||||
safe_filename_str = safe_filename(original_filename)
|
||||
safe_filename_str = safe_filename(file.filename)
|
||||
file_path = f"{target_dir}/{safe_filename_str}"
|
||||
|
||||
# Save file to storage
|
||||
|
||||
storage.save_file(file, file_path)
|
||||
added_files.append(safe_filename_str)
|
||||
if original_filename:
|
||||
relative_key = (
|
||||
f"{parent_dir}/{safe_filename_str}"
|
||||
if parent_dir
|
||||
else safe_filename_str
|
||||
)
|
||||
file_name_map[relative_key] = original_filename
|
||||
map_updated = True
|
||||
|
||||
if map_updated:
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
@@ -449,7 +414,6 @@ class ManageSourceFiles(Resource):
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
@@ -458,15 +422,6 @@ class ManageSourceFiles(Resource):
|
||||
if storage.file_exists(full_path):
|
||||
storage.delete_file(full_path)
|
||||
removed_files.append(file_path)
|
||||
if file_path in file_name_map:
|
||||
file_name_map.pop(file_path, None)
|
||||
map_updated = True
|
||||
|
||||
if map_updated and isinstance(file_name_map, dict):
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
@@ -549,20 +504,6 @@ class ManageSourceFiles(Resource):
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
if directory_path and file_name_map:
|
||||
prefix = f"{directory_path.rstrip('/')}/"
|
||||
keys_to_remove = [
|
||||
key
|
||||
for key in file_name_map.keys()
|
||||
if key == directory_path or key.startswith(prefix)
|
||||
]
|
||||
if keys_to_remove:
|
||||
for key in keys_to_remove:
|
||||
file_name_map.pop(key, None)
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
@@ -633,9 +574,8 @@ class TaskStatus(Resource):
|
||||
):
|
||||
task_meta = str(task_meta) # Convert to a string representation
|
||||
except ConnectionError as err:
|
||||
current_app.logger.error(f"Connection error getting task status: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Service unavailable"}), 503
|
||||
jsonify({"success": False, "message": str(err)}), 503
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
|
||||
@@ -8,25 +8,13 @@ from application.worker import (
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest(
|
||||
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
directory,
|
||||
formats,
|
||||
job_name,
|
||||
file_path,
|
||||
filename,
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
)
|
||||
def ingest(self, directory, formats, job_name, user, file_path, filename):
|
||||
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -50,30 +38,6 @@ def schedule_syncs(self, frequency):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
def sync_source(
    self, source_data, job_name, user, loader, sync_frequency, retriever, doc_id
):
    """Celery task that forwards its arguments to ``sync`` and returns its result."""
    return sync(
        self, source_data, job_name, user, loader, sync_frequency, retriever, doc_id
    )
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def store_attachment(self, file_info, user):
|
||||
resp = attachment_worker(self, file_info, user)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import unquote, urlencode
|
||||
from email.quoprimime import unquote
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -43,16 +43,6 @@ class TestMCPServerConfig(Resource):
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
transport_type = (config.get("transport_type") or "auto").lower()
|
||||
allowed_transports = {"auto", "sse", "http"}
|
||||
if transport_type not in allowed_transports:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Unsupported transport_type"}),
|
||||
400,
|
||||
)
|
||||
config.pop("command", None)
|
||||
config.pop("args", None)
|
||||
config["transport_type"] = transport_type
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
@@ -74,17 +64,12 @@ class TestMCPServerConfig(Resource):
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
# Sanitize the response to avoid exposing internal error details
|
||||
if not result.get("success") and "message" in result:
|
||||
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
|
||||
result["message"] = "Connection test failed"
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": "Connection test failed"}
|
||||
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
@@ -125,16 +110,6 @@ class MCPServerSave(Resource):
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
transport_type = (config.get("transport_type") or "auto").lower()
|
||||
allowed_transports = {"auto", "sse", "http"}
|
||||
if transport_type not in allowed_transports:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Unsupported transport_type"}),
|
||||
400,
|
||||
)
|
||||
config.pop("command", None)
|
||||
config.pop("args", None)
|
||||
config["transport_type"] = transport_type
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
@@ -259,7 +234,7 @@ class MCPServerSave(Resource):
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": "Failed to save MCP server"}
|
||||
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
@@ -288,12 +263,9 @@ class MCPOAuthCallback(Resource):
|
||||
error = request.args.get("error")
|
||||
|
||||
if error:
|
||||
params = {
|
||||
"status": "error",
|
||||
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
|
||||
"provider": "mcp_tool"
|
||||
}
|
||||
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
if not code or not state:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
|
||||
@@ -320,7 +292,7 @@ class MCPOAuthCallback(Resource):
|
||||
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
|
||||
)
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
|
||||
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
|
||||
)
|
||||
|
||||
|
||||
@@ -354,8 +326,8 @@ class MCPOAuthStatus(Resource):
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
|
||||
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
@@ -415,136 +414,3 @@ class DeleteTool(Resource):
|
||||
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
return {"success": True}, 200
|
||||
|
||||
|
||||
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
    """Parse an uploaded or inline API specification into metadata and actions."""

    @api.doc(
        description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
    )
    def post(self):
        """Return the metadata and actions that ``parse_spec`` extracts.

        The spec text is read from the multipart ``file`` field when present,
        otherwise from the ``spec_content`` key of a JSON body.

        Returns:
            401 without a decoded token; 400 when no spec is supplied, the
            upload is not valid UTF-8, the content is empty or not a string,
            or ``parse_spec`` raises ``ValueError``; 500 for any other parsing
            error; 200 with ``metadata`` and ``actions`` on success.
        """

        def bad_request(message):
            return make_response(
                jsonify({"success": False, "message": message}), 400
            )

        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)

        if "file" in request.files:
            file = request.files["file"]
            if not file.filename:
                return bad_request("No file selected")
            try:
                spec_content = file.read().decode("utf-8")
            except UnicodeDecodeError:
                return bad_request("Invalid file encoding")
        elif request.is_json:
            data = request.get_json()
            # A JSON array or scalar has no ``.get``; answer 400 instead of
            # letting the AttributeError surface as a 500.
            if not isinstance(data, dict):
                return bad_request("No spec provided")
            spec_content = data.get("spec_content", "")
        else:
            return bad_request("No spec provided")

        # ``spec_content`` from JSON is untrusted and may be a number, list or
        # object; reject it before ``.strip()`` is called on it.
        if spec_content and not isinstance(spec_content, str):
            return bad_request("Spec content must be a string")
        if not spec_content or not spec_content.strip():
            return bad_request("Empty spec content")

        try:
            metadata, actions = parse_spec(spec_content)
            return make_response(
                jsonify(
                    {
                        "success": True,
                        "metadata": metadata,
                        "actions": actions,
                    }
                ),
                200,
            )
        except ValueError as e:
            current_app.logger.error(f"Spec validation error: {e}")
            return make_response(
                jsonify({"success": False, "error": "Invalid specification format"}),
                400,
            )
        except Exception as err:
            current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "error": "Failed to parse specification"}),
                500,
            )
|
||||
|
||||
|
||||
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
    """Fetch a note or todo artifact belonging to the authenticated user."""

    @api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
    def get(self, artifact_id: str):
        """Return the artifact stored under ``artifact_id``.

        The ``notes`` collection is checked first, then ``todos``. A matching
        todo returns every todo with the same ``user_id`` and ``tool_id``.
        Responds 401 without a decoded token, 400 for an ID that ``ObjectId``
        rejects, and 404 when neither collection has a matching document.
        """

        def iso_or_none(document, key):
            # Timestamps are optional; emit None when absent or falsy.
            stamp = document.get(key)
            return stamp.isoformat() if stamp else None

        def found(artifact):
            return make_response(
                jsonify({"success": True, "artifact": artifact}), 200
            )

        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user_id = token.get("sub")

        try:
            obj_id = ObjectId(artifact_id)
        except Exception:
            return make_response(
                jsonify({"success": False, "message": "Invalid artifact ID"}), 400
            )

        from application.core.mongo_db import MongoDB
        from application.core.settings import settings

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]

        note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
        if note_doc:
            content = note_doc.get("note", "")
            return found(
                {
                    "artifact_type": "note",
                    "data": {
                        "content": content,
                        "line_count": len(content.split("\n")) if content else 0,
                        "updated_at": iso_or_none(note_doc, "updated_at"),
                    },
                }
            )

        todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
        if todo_doc:
            # Return all todos for the tool
            siblings = list(
                db["todos"].find(
                    {"user_id": user_id, "tool_id": todo_doc.get("tool_id")}
                )
            )
            items = [
                {
                    "todo_id": entry.get("todo_id"),
                    "title": entry.get("title", ""),
                    "status": entry.get("status", "open"),
                    "created_at": iso_or_none(entry, "created_at"),
                    "updated_at": iso_or_none(entry, "updated_at"),
                }
                for entry in siblings
            ]
            statuses = [item["status"] for item in items]
            return found(
                {
                    "artifact_type": "todo_list",
                    "data": {
                        "items": items,
                        "total_count": len(items),
                        "open_count": statuses.count("open"),
                        "completed_count": statuses.count("completed"),
                    },
                }
            )

        return make_response(
            jsonify({"success": False, "message": "Artifact not found"}), 404
        )
|
||||
|
||||
@@ -1,378 +0,0 @@
|
||||
"""Centralized utilities for API routes."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from bson.errors import InvalidId
|
||||
from bson.objectid import ObjectId
|
||||
from flask import jsonify, make_response, request, Response
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
def get_user_id() -> Optional[str]:
    """
    Read the ``sub`` claim from the current request's decoded token.

    Returns:
        The value stored under ``sub``, or None when the request has no
        (or a falsy) ``decoded_token`` attribute.
    """
    token = getattr(request, "decoded_token", None)
    if not token:
        return None
    return token.get("sub")
|
||||
|
||||
|
||||
def require_auth(func: Callable) -> Callable:
    """
    Decorator that rejects the call with a 401 when no user ID is available.

    Usage:
        @require_auth
        def get(self):
            user_id = get_user_id()
            ...
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        # Only run the handler when the request resolves to a user ID.
        if get_user_id():
            return func(*args, **kwargs)
        return error_response("Unauthorized", 401)

    return guarded
|
||||
|
||||
|
||||
def success_response(
    data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
    """
    Build a JSON response whose body starts with ``{"success": True}``.

    Args:
        data: Optional dictionary merged into the body (skipped when falsy)
        status: HTTP status code (default: 200)

    Returns:
        Flask Response object

    Example:
        return success_response({"users": [...], "total": 10})
    """
    body = {"success": True}
    body.update(data or {})
    return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """
    Build a JSON error response with ``success`` set to False.

    Args:
        message: Error message string
        status: HTTP status code (default: 400)
        **kwargs: Extra fields merged into the body; a key that collides
            with ``success`` or ``message`` replaces that value

    Returns:
        Flask Response object

    Example:
        return error_response("Resource not found", 404)
        return error_response("Invalid input", 400, errors=["field1", "field2"])
    """
    body = {"success": False, "message": message, **kwargs}
    return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def validate_object_id(
    id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
    """
    Validate and convert string to ObjectId.

    Args:
        id_string: String to convert
        resource_name: Name of resource for error message

    Returns:
        Tuple of (ObjectId or None, error_response or None). Exactly one of
        the two elements is None.

    Example:
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error
    """
    # ObjectId(None) does not raise: it generates a fresh ID. A missing value
    # must therefore be rejected explicitly or it would be reported as valid.
    if id_string is None:
        return None, error_response(f"Invalid {resource_name} ID format")
    try:
        return ObjectId(id_string), None
    except (InvalidId, TypeError):
        return None, error_response(f"Invalid {resource_name} ID format")
|
||||
|
||||
|
||||
def validate_pagination(
    default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
    """
    Read ``limit`` and ``skip`` from the query string and validate them.

    Args:
        default_limit: Limit used when the request has no ``limit`` argument
        max_limit: Upper bound applied to the requested limit

    Returns:
        Tuple of (limit, skip, None) on success, or (0, 0, error_response)
        when a value is not an integer, the limit is below 1, or skip is
        negative

    Example:
        limit, skip, error = validate_pagination()
        if error:
            return error
    """
    params = request.args
    try:
        # The requested limit is capped at max_limit, never rejected for size.
        limit = min(int(params.get("limit", default_limit)), max_limit)
        skip = int(params.get("skip", 0))
    except ValueError:
        return 0, 0, error_response("Invalid pagination parameters")
    if limit >= 1 and skip >= 0:
        return limit, skip, None
    return 0, 0, error_response("Invalid pagination parameters")
|
||||
|
||||
|
||||
def check_resource_ownership(
    collection: Collection,
    resource_id: ObjectId,
    user_id: str,
    resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
    """
    Fetch a document by ``_id`` restricted to the given ``user``.

    Args:
        collection: MongoDB collection
        resource_id: Resource ObjectId
        user_id: User ID string matched against the document's ``user`` field
        resource_name: Name of resource for error messages

    Returns:
        Tuple of (document, None) when found, otherwise
        (None, 404 error_response)

    Example:
        workflow, error = check_resource_ownership(
            workflows_collection,
            workflow_id,
            user_id,
            "Workflow"
        )
        if error:
            return error
    """
    # A single query covers both "does not exist" and "belongs to someone else".
    document = collection.find_one({"_id": resource_id, "user": user_id})
    if document:
        return document, None
    return None, error_response(f"{resource_name} not found", 404)
|
||||
|
||||
|
||||
def serialize_object_id(
    obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
    """
    Replace an ID field with its string form, in place.

    Args:
        obj: Dictionary to modify
        id_field: Key holding the value to stringify
        new_field: Key that receives the string value

    Returns:
        The same dictionary object; unchanged when ``id_field`` is absent

    Example:
        user = serialize_object_id(user_doc)
        # user["id"] = "507f1f77bcf86cd799439011"
    """
    if id_field not in obj:
        return obj
    obj[new_field] = str(obj[id_field])
    # When both names are equal the stringified value must stay in place.
    if new_field != id_field:
        obj.pop(id_field, None)
    return obj
|
||||
|
||||
|
||||
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
    """
    Run ``serializer`` over every item and collect the results in order.

    Args:
        items: List of dictionaries
        serializer: Function to apply to each item

    Returns:
        New list holding one serialized result per input item

    Example:
        workflows = serialize_list(workflow_docs, serialize_workflow)
    """
    return list(map(serializer, items))
|
||||
|
||||
|
||||
def paginated_response(
    collection: Collection,
    query: Dict[str, Any],
    serializer: Callable[[Dict], Dict],
    limit: int,
    skip: int,
    sort_field: str = "created_at",
    sort_order: int = -1,
    response_key: str = "items",
) -> Response:
    """
    Query one page of a collection and wrap it in a success response.

    Args:
        collection: MongoDB collection
        query: Query dictionary, used for both the page and the total count
        serializer: Function to serialize each item
        limit: Items per page
        skip: Number of items to skip
        sort_field: Field to sort by
        sort_order: Sort order (1=asc, -1=desc)
        response_key: Key name for items in response

    Returns:
        Flask Response containing the serialized page under ``response_key``
        plus ``total``, ``limit`` and ``skip``

    Example:
        return paginated_response(
            workflows_collection,
            {"user": user_id},
            serialize_workflow,
            limit, skip,
            response_key="workflows"
        )
    """
    cursor = collection.find(query).sort(sort_field, sort_order)
    page = list(cursor.skip(skip).limit(limit))
    # The total counts every match for the query, not just the current page.
    total = collection.count_documents(query)

    body = {
        response_key: serialize_list(page, serializer),
        "total": total,
        "limit": limit,
        "skip": skip,
    }
    return success_response(body)
|
||||
|
||||
|
||||
def require_fields(required: List[str]) -> Callable:
    """
    Decorator to validate required fields in request JSON.

    The wrapped handler is only invoked when the request body is a non-empty
    JSON object in which every name listed in ``required`` maps to a truthy
    value. Falsy values (``""``, ``0``, ``False``, ``None``, empty containers)
    are treated the same as a missing field.

    Args:
        required: List of required field names

    Returns:
        Decorator function

    Example:
        @require_fields(["name", "description"])
        def post(self):
            data = request.get_json()
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            if not data:
                return error_response("Request body required")
            # A JSON array, string or number is valid JSON but has no fields;
            # reject it here instead of failing on ``data.get`` below.
            if not isinstance(data, dict):
                return error_response("Request body must be a JSON object")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(f"Missing required fields: {', '.join(missing)}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def safe_db_operation(
    operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
    """
    Call ``operation`` and convert any exception into an error response.

    Args:
        operation: Zero-argument callable that performs the database work.
        error_message: Prefix of the message used when the call raises; the
            text of the exception is appended after a colon.

    Returns:
        ``(result, None)`` when the call succeeds, otherwise
        ``(None, error_response)``.

    Example:
        result, error = safe_db_operation(
            lambda: collection.insert_one(doc),
            "Failed to create resource"
        )
        if error:
            return error
    """
    try:
        outcome = operation()
    except Exception as exc:
        failure = error_response(f"{error_message}: {str(exc)}")
        return None, failure
    return outcome, None
|
||||
|
||||
|
||||
def validate_enum(
    value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
    """
    Check that ``value`` is one of the ``allowed`` values.

    Args:
        value: Value to check.
        allowed: Accepted values.
        field_name: Name used in the error message.

    Returns:
        None when the value is accepted, otherwise an error response whose
        message lists every allowed value in single quotes.

    Example:
        error = validate_enum(status, ["draft", "published"], "status")
        if error:
            return error
    """
    if value in allowed:
        return None

    choices = ", ".join(f"'{v}'" for v in allowed)
    return error_response(f"Invalid {field_name}. Must be one of: {choices}")
|
||||
|
||||
|
||||
def extract_sort_params(
    default_field: str = "created_at",
    default_order: str = "desc",
    allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
    """
    Extract and validate sort parameters from request.

    Reads the ``sort`` and ``order`` query-string arguments. Both fall back
    to their defaults when the supplied value is not acceptable: ``sort``
    when it is not in ``allowed_fields``, ``order`` when it is neither
    "asc" nor "desc" (compared case-insensitively).

    Args:
        default_field: Default sort field
        default_order: Default sort order ("asc" or "desc")
        allowed_fields: List of allowed sort fields (None or empty = no validation)

    Returns:
        Tuple of (sort_field, sort_order) where sort_order is -1 for
        descending and 1 for ascending

    Example:
        sort_field, sort_order = extract_sort_params(
            allowed_fields=["name", "date", "status"]
        )
    """
    sort_field = request.args.get("sort", default_field)
    if allowed_fields and sort_field not in allowed_fields:
        sort_field = default_field

    sort_order_str = request.args.get("order", default_order).lower()
    # An unrecognised direction falls back to the caller's default, the same
    # way an unrecognised sort field does, rather than silently meaning "asc".
    if sort_order_str not in ("asc", "desc"):
        sort_order_str = default_order.lower()

    sort_order = -1 if sort_order_str == "desc" else 1
    return sort_field, sort_order
|
||||
@@ -1,3 +0,0 @@
|
||||
from .routes import workflows_ns
|
||||
|
||||
__all__ = ["workflows_ns"]
|
||||
@@ -1,353 +0,0 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api.user.base import (
|
||||
workflow_edges_collection,
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
get_user_id,
|
||||
require_auth,
|
||||
require_fields,
|
||||
safe_db_operation,
|
||||
success_response,
|
||||
validate_object_id,
|
||||
)
|
||||
|
||||
workflows_ns = Namespace("workflows", path="/api")
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
    """Convert a workflow document into its API representation.

    ``_id`` is required and returned as a string under ``id``. The two
    timestamps are returned in ISO format, or as None when absent or falsy.
    """

    def _iso(field: str):
        stamp = w.get(field)
        return stamp.isoformat() if stamp else None

    return {
        "id": str(w["_id"]),
        "name": w.get("name"),
        "description": w.get("description"),
        "created_at": _iso("created_at"),
        "updated_at": _iso("updated_at"),
    }
|
||||
|
||||
|
||||
def serialize_node(n: Dict) -> Dict:
    """Convert a stored workflow node into its API representation.

    ``id`` and ``type`` are required. The stored ``config`` value is exposed
    under the ``data`` key and defaults to an empty dict.
    """
    node_id = n["id"]
    node_type = n["type"]
    return {
        "id": node_id,
        "type": node_type,
        "title": n.get("title"),
        "description": n.get("description"),
        "position": n.get("position"),
        "data": n.get("config", {}),
    }
|
||||
|
||||
|
||||
def serialize_edge(e: Dict) -> Dict:
    """Convert a stored workflow edge into its API representation.

    ``id`` is required. The remaining stored fields are renamed to the API
    key names and are None when missing.
    """
    # (API key, stored key) pairs, in output order.
    renamed = (
        ("source", "source_id"),
        ("target", "target_id"),
        ("sourceHandle", "source_handle"),
        ("targetHandle", "target_handle"),
    )
    serialized = {"id": e["id"]}
    for api_key, stored_key in renamed:
        serialized[api_key] = e.get(stored_key)
    return serialized
|
||||
|
||||
|
||||
def get_workflow_graph_version(workflow: Dict) -> int:
    """Get current graph version with legacy fallback.

    Args:
        workflow: Workflow document; ``current_graph_version`` is optional.

    Returns:
        The stored version converted to ``int`` when it is a positive
        number, otherwise 1 (missing, non-numeric, non-finite, zero or
        negative values all fall back to 1).
    """
    raw_version = workflow.get("current_graph_version", 1)
    try:
        version = int(raw_version)
    except (ValueError, TypeError, OverflowError):
        # OverflowError is raised by int() for infinite floats; treat it like
        # any other unusable value instead of letting it propagate.
        return 1
    return version if version > 0 else 1
|
||||
|
||||
|
||||
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
    """Load the graph documents of a workflow for one graph version.

    Documents tagged with ``graph_version`` are returned when any exist.
    Only for version 1 with no tagged documents is a second query made, for
    documents that have no ``graph_version`` field at all (legacy data).
    """
    versioned = list(
        collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
    )
    if versioned or graph_version != 1:
        return versioned

    legacy_query = {"workflow_id": workflow_id, "graph_version": {"$exists": False}}
    return list(collection.find(legacy_query))
|
||||
|
||||
|
||||
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
    """Validate workflow graph structure.

    Checks performed, in the order their messages are reported:
      * at least one node is present (the only check run when there is none);
      * ``nodes`` / ``edges`` are lists of objects (malformed input stops
        validation with a single message instead of raising);
      * exactly one node of type "start" and at least one of type "end";
      * every edge ``source`` / ``target`` refers to an existing node id;
      * the start node has at least one outgoing edge;
      * every node has a truthy ``id`` and ``type``;
      * node ids are hashable values.

    Args:
        nodes: Node objects as sent by the client.
        edges: Edge objects as sent by the client; None is treated as no edges.

    Returns:
        List of human-readable error messages; empty when the graph is valid.
    """
    errors: List[str] = []

    if not nodes:
        errors.append("Workflow must have at least one node")
        return errors

    # The payload comes straight from request JSON, so its shape is not
    # guaranteed; report malformed input rather than failing on ``.get``.
    if not isinstance(nodes, (list, tuple)) or not all(
        isinstance(node, dict) for node in nodes
    ):
        errors.append("Nodes must be a list of objects")
        return errors

    if edges is None:
        edges = []
    elif not isinstance(edges, (list, tuple)) or not all(
        isinstance(edge, dict) for edge in edges
    ):
        errors.append("Edges must be a list of objects")
        return errors

    start_nodes = [n for n in nodes if n.get("type") == "start"]
    if len(start_nodes) != 1:
        errors.append("Workflow must have exactly one start node")

    end_nodes = [n for n in nodes if n.get("type") == "end"]
    if not end_nodes:
        errors.append("Workflow must have at least one end node")

    node_ids = set()
    unhashable_ids = []
    for node in nodes:
        node_id = node.get("id")
        try:
            node_ids.add(node_id)
        except TypeError:
            # e.g. a list or dict used as id; reported after the other checks.
            unhashable_ids.append(node_id)

    def _is_known(ref) -> bool:
        """Return True when ``ref`` is the id of one of the nodes."""
        try:
            return ref in node_ids
        except TypeError:
            # An unhashable reference can never equal a collected node id.
            return False

    for edge in edges:
        source_id = edge.get("source")
        target_id = edge.get("target")
        if not _is_known(source_id):
            errors.append(f"Edge references non-existent source: {source_id}")
        if not _is_known(target_id):
            errors.append(f"Edge references non-existent target: {target_id}")

    if start_nodes:
        start_id = start_nodes[0].get("id")
        if not any(e.get("source") == start_id for e in edges):
            errors.append("Start node must have at least one outgoing edge")

    for node in nodes:
        if not node.get("id"):
            errors.append("All nodes must have an id")
        if not node.get("type"):
            errors.append(f"Node {node.get('id', 'unknown')} must have a type")

    for bad_id in unhashable_ids:
        errors.append(f"Node id must be a string or number: {bad_id!r}")

    return errors
|
||||
|
||||
|
||||
def create_workflow_nodes(
    workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
    """Store the given nodes for one graph version of a workflow.

    Nothing is written when ``nodes_data`` is empty. Each node must carry
    ``id`` and ``type``; the client's ``data`` value is stored as ``config``.
    """
    if not nodes_data:
        return

    documents = []
    for node in nodes_data:
        documents.append(
            {
                "id": node["id"],
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "type": node["type"],
                "title": node.get("title", ""),
                "description": node.get("description", ""),
                "position": node.get("position", {"x": 0, "y": 0}),
                "config": node.get("data", {}),
            }
        )
    workflow_nodes_collection.insert_many(documents)
|
||||
|
||||
|
||||
def create_workflow_edges(
    workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
    """Store the given edges for one graph version of a workflow.

    Nothing is written when ``edges_data`` is empty. Each edge must carry
    ``id``; the client's camelCase handle keys are stored in snake_case.
    """
    if not edges_data:
        return

    documents = []
    for edge in edges_data:
        documents.append(
            {
                "id": edge["id"],
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "source_id": edge.get("source"),
                "target_id": edge.get("target"),
                "source_handle": edge.get("sourceHandle"),
                "target_handle": edge.get("targetHandle"),
            }
        )
    workflow_edges_collection.insert_many(documents)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
    """Collection endpoint used to create workflows."""

    @require_auth
    @require_fields(["name"])
    def post(self):
        """Create a new workflow with nodes and edges.

        The request body must contain a non-blank string ``name`` and may
        contain ``description``, ``nodes`` and ``edges``. The graph is
        validated before anything is written. The workflow document is
        inserted first; if storing its nodes or edges then fails, everything
        written for the new workflow is deleted again.

        Returns:
            A 201 success response with the new workflow ``id``, or an error
            response describing the validation or storage failure.
        """
        user_id = get_user_id()
        data = request.get_json()

        # require_fields only guarantees a truthy value; a number would fail
        # on .strip() and a blank string would be stored as an empty name.
        raw_name = data.get("name")
        if not isinstance(raw_name, str) or not raw_name.strip():
            return error_response("Workflow name must be a non-empty string")
        name = raw_name.strip()

        # An explicit JSON null is treated the same as an omitted key.
        nodes_data = data.get("nodes") or []
        edges_data = data.get("edges") or []

        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )

        now = datetime.now(timezone.utc)
        workflow_doc = {
            "name": name,
            "description": data.get("description", ""),
            "user": user_id,
            "created_at": now,
            "updated_at": now,
            "current_graph_version": 1,
        }

        result, error = safe_db_operation(
            lambda: workflows_collection.insert_one(workflow_doc),
            "Failed to create workflow",
        )
        if error:
            return error

        workflow_id = str(result.inserted_id)

        try:
            create_workflow_nodes(workflow_id, nodes_data, 1)
            create_workflow_edges(workflow_id, edges_data, 1)
        except Exception as e:
            # Roll back the partially created workflow.
            workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
            workflow_edges_collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": result.inserted_id})
            return error_response(f"Failed to create workflow structure: {str(e)}")

        return success_response({"id": workflow_id}, 201)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
    """Item endpoint to read, replace or delete one workflow of the caller."""

    @staticmethod
    def _discard_graph_version(workflow_id: str, graph_version: int) -> None:
        """Delete the node and edge documents stored for one graph version."""
        version_filter = {"workflow_id": workflow_id, "graph_version": graph_version}
        workflow_nodes_collection.delete_many(version_filter)
        workflow_edges_collection.delete_many(version_filter)

    @require_auth
    def get(self, workflow_id: str):
        """Get workflow details with nodes and edges.

        Returns:
            A success response with the serialized workflow and the nodes and
            edges of its current graph version, or the error response from
            id validation / the ownership check.
        """
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        graph_version = get_workflow_graph_version(workflow)
        nodes = fetch_graph_documents(
            workflow_nodes_collection, workflow_id, graph_version
        )
        edges = fetch_graph_documents(
            workflow_edges_collection, workflow_id, graph_version
        )

        return success_response(
            {
                "workflow": serialize_workflow(workflow),
                "nodes": [serialize_node(n) for n in nodes],
                "edges": [serialize_edge(e) for e in edges],
            }
        )

    @require_auth
    @require_fields(["name"])
    def put(self, workflow_id: str):
        """Update workflow and replace nodes/edges.

        The new graph is written under the next graph version, then the
        workflow document is updated to point at that version, and finally
        documents of all other versions are removed. If writing the graph or
        updating the workflow fails, the documents of the new version are
        deleted and an error response is returned. A failure while removing
        old versions is only logged.
        """
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        data = request.get_json()

        # require_fields only guarantees a truthy value; a number would fail
        # on .strip() and a blank string would be stored as an empty name.
        raw_name = data.get("name")
        if not isinstance(raw_name, str) or not raw_name.strip():
            return error_response("Workflow name must be a non-empty string")
        name = raw_name.strip()

        # An explicit JSON null is treated the same as an omitted key.
        nodes_data = data.get("nodes") or []
        edges_data = data.get("edges") or []

        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )

        current_graph_version = get_workflow_graph_version(workflow)
        next_graph_version = current_graph_version + 1
        try:
            create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
            create_workflow_edges(workflow_id, edges_data, next_graph_version)
        except Exception as e:
            self._discard_graph_version(workflow_id, next_graph_version)
            return error_response(f"Failed to update workflow structure: {str(e)}")

        now = datetime.now(timezone.utc)
        _, error = safe_db_operation(
            lambda: workflows_collection.update_one(
                {"_id": obj_id},
                {
                    "$set": {
                        "name": name,
                        "description": data.get("description", ""),
                        "updated_at": now,
                        "current_graph_version": next_graph_version,
                    }
                },
            ),
            "Failed to update workflow",
        )
        if error:
            # The workflow still points at the previous version; drop the
            # documents written for the version that never became current.
            self._discard_graph_version(workflow_id, next_graph_version)
            return error

        try:
            workflow_nodes_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
            )
            workflow_edges_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
            )
        except Exception as cleanup_err:
            # Best effort: the update itself already succeeded.
            current_app.logger.warning(
                f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
            )

        return success_response()

    @require_auth
    def delete(self, workflow_id: str):
        """Delete workflow and its graph.

        Nodes and edges of every graph version are removed before the
        workflow document itself.
        """
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        try:
            workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
            workflow_edges_collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
        except Exception as e:
            return error_response(f"Failed to delete workflow: {str(e)}")

        return success_response()
|
||||
@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
|
||||
options={"verify_exp": False},
|
||||
)
|
||||
return decoded_token
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
return {
|
||||
"message": "Authentication error: invalid token",
|
||||
"message": f"Authentication error: {str(e)}",
|
||||
"error": "invalid_token",
|
||||
}
|
||||
else:
|
||||
|
||||
@@ -21,4 +21,3 @@ def config_loggers(*args, **kwargs):
|
||||
|
||||
|
||||
celery = make_celery()
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
|
||||
@@ -6,6 +6,3 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
accept_content = ['json']
|
||||
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
@@ -8,8 +8,8 @@ from application.core.model_settings import (
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
# Base image attachment types supported by most vision-capable LLMs
|
||||
IMAGE_ATTACHMENTS = [
|
||||
OPENAI_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
@@ -17,42 +17,75 @@ IMAGE_ATTACHMENTS = [
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
|
||||
# When excluded, PDFs are synthetically processed by converting pages to images.
|
||||
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
|
||||
|
||||
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
GOOGLE_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-5.1",
|
||||
id="gpt-4o",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5.1",
|
||||
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
|
||||
display_name="GPT-4 Omni",
|
||||
description="Latest and most capable model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-5-mini",
|
||||
id="gpt-4o-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5 Mini",
|
||||
description="Faster, cost-effective variant of GPT-5.1",
|
||||
display_name="GPT-4 Omni Mini",
|
||||
description="Fast and efficient",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
context_window=128000,
|
||||
),
|
||||
)
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-4-turbo",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4 Turbo",
|
||||
description="Fast GPT-4 with 128k context",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-4",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4",
|
||||
description="Most capable model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-3.5-turbo",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-3.5 Turbo",
|
||||
description="Fast and cost-effective",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=4096,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -64,7 +97,6 @@ ANTHROPIC_MODELS = [
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
@@ -75,7 +107,6 @@ ANTHROPIC_MODELS = [
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
@@ -86,7 +117,6 @@ ANTHROPIC_MODELS = [
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
@@ -97,7 +127,6 @@ ANTHROPIC_MODELS = [
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
@@ -130,9 +159,9 @@ GOOGLE_MODELS = [
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-3-pro-preview",
|
||||
id="gemini-2.5-pro",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 3 Pro",
|
||||
display_name="Gemini 2.5 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
@@ -156,43 +185,28 @@ GROQ_MODELS = [
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="openai/gpt-oss-120b",
|
||||
id="llama-3.1-8b-instant",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="GPT-OSS 120B",
|
||||
description="Open-source GPT model optimized for speed",
|
||||
display_name="Llama 3.1 8B",
|
||||
description="Ultra-fast inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="mixtral-8x7b-32768",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Mixtral 8x7B",
|
||||
description="High-speed inference with tools",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=32768,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
OPENROUTER_MODELS = [
|
||||
AvailableModel(
|
||||
id="qwen/qwen3-coder:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Qwen 3 Coder",
|
||||
description="Latest Qwen model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="google/gemma-3-27b-it:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Gemma 3 27B",
|
||||
description="Latest Gemma model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
@@ -207,18 +221,3 @@ AZURE_OPENAI_MODELS = [
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
    """Build the registry entry for a model served by an OpenAI-compatible endpoint.

    The model name is used both as id and display name, the provider is
    OPENAI, tool support is enabled and the OpenAI attachment types apply.
    """
    capabilities = ModelCapabilities(
        supports_tools=True,
        supported_attachment_types=OPENAI_ATTACHMENTS,
    )
    summary = f"Custom OpenAI-compatible model at {base_url}"
    return AvailableModel(
        id=model_name,
        provider=ModelProvider.OPENAI,
        display_name=model_name,
        description=summary,
        base_url=base_url,
        capabilities=capabilities,
    )
|
||||
|
||||
@@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
GROQ = "groq"
|
||||
@@ -85,13 +84,9 @@ class ModelRegistry:
|
||||
|
||||
self.models.clear()
|
||||
|
||||
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
|
||||
if not settings.OPENAI_BASE_URL:
|
||||
self._add_docsgpt_models(settings)
|
||||
if (
|
||||
settings.OPENAI_API_KEY
|
||||
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
|
||||
or settings.OPENAI_BASE_URL
|
||||
self._add_docsgpt_models(settings)
|
||||
if settings.OPENAI_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openai" and settings.API_KEY
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
@@ -110,69 +105,39 @@ class ModelRegistry:
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.OPEN_ROUTER_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
|
||||
):
|
||||
self._add_openrouter_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
if settings.LLM_NAME:
|
||||
# Parse LLM_NAME (may be comma-separated)
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
# First model in the list becomes default
|
||||
for model_name in model_names:
|
||||
if model_name in self.models:
|
||||
self.default_model_id = model_name
|
||||
|
||||
if settings.LLM_NAME and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
elif settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
# Backward compat: try exact match if no parsed model found
|
||||
if not self.default_model_id and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
|
||||
if not self.default_model_id:
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
|
||||
if not self.default_model_id and self.models:
|
||||
else:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import (
|
||||
OPENAI_MODELS,
|
||||
create_custom_openai_model,
|
||||
)
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
|
||||
using_local_endpoint = bool(
|
||||
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
|
||||
)
|
||||
|
||||
if using_local_endpoint:
|
||||
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
|
||||
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
|
||||
if settings.LLM_NAME:
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
for model_name in model_names:
|
||||
custom_model = create_custom_openai_model(
|
||||
model_name, settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.models[model_name] = custom_model
|
||||
logger.info(
|
||||
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
|
||||
)
|
||||
else:
|
||||
# Standard OpenAI API usage - add standard models if API key is valid
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
|
||||
for model in OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
@@ -229,21 +194,6 @@ class ModelRegistry:
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_openrouter_models(self, settings):
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
if settings.OPEN_ROUTER_API_KEY:
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
|
||||
for model in OPENROUTER_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
@@ -273,15 +223,6 @@ class ModelRegistry:
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _parse_model_names(self, llm_name: str) -> List[str]:
|
||||
"""
|
||||
Parse LLM_NAME which may contain comma-separated model names.
|
||||
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
|
||||
"""
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"openrouter": settings.OPEN_ROUTER_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
|
||||
@@ -2,8 +2,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -11,19 +10,12 @@ current_dir = os.path.dirname(
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(extra="ignore")
|
||||
|
||||
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||
LLM_PROVIDER: str = "docsgpt"
|
||||
LLM_NAME: Optional[str] = (
|
||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
)
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
@@ -43,10 +35,8 @@ class Settings(BaseSettings):
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
VECTOR_STORE: str = (
|
||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||
)
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
@@ -65,20 +55,13 @@ class Settings(BaseSettings):
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
|
||||
# Microsoft Entra ID (Azure AD) integration
|
||||
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
|
||||
|
||||
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
|
||||
|
||||
@@ -88,8 +71,10 @@ class Settings(BaseSettings):
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
OPEN_ROUTER_API_KEY: Optional[str] = None
|
||||
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
@@ -139,7 +124,7 @@ class Settings(BaseSettings):
|
||||
MILVUS_TOKEN: Optional[str] = ""
|
||||
|
||||
# LanceDB vectorstore config
|
||||
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
|
||||
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
|
||||
LANCEDB_TABLE_NAME: Optional[str] = (
|
||||
"docsgpts" # Name of the table to use for storing vectors
|
||||
)
|
||||
@@ -159,44 +144,6 @@ class Settings(BaseSettings):
|
||||
# Tool pre-fetch settings
|
||||
ENABLE_TOOL_PREFETCH: bool = True
|
||||
|
||||
# Conversation Compression Settings
|
||||
ENABLE_CONVERSATION_COMPRESSION: bool = True
|
||||
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
|
||||
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
@field_validator(
    "API_KEY",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "GOOGLE_API_KEY",
    "GROQ_API_KEY",
    "HUGGINGFACE_API_KEY",
    "EMBEDDINGS_KEY",
    "FALLBACK_LLM_API_KEY",
    "QDRANT_API_KEY",
    "ELEVENLABS_API_KEY",
    "INTERNAL_KEY",
    mode="before",
)
@classmethod
def normalize_api_key(cls, v: Optional[str]) -> Optional[str]:
    """
    Collapse "empty" key values to a real ``None``.

    A value loaded from ``.env`` arrives as text, so an unset key may show
    up as ``"None"``, ``"none"``, ``""`` or a whitespace-only string. All of
    those become ``None``; any other string is returned with surrounding
    whitespace removed. Non-string values (including ``None``) are passed
    through unchanged.
    """
    if not isinstance(v, str):
        # Covers None as well as any non-string value.
        return v
    cleaned = v.strip()
    if cleaned and cleaned.lower() != "none":
        return cleaned
    return None
|
||||
|
||||
|
||||
# Project root is one level above application/
|
||||
path = Path(__file__).parent.parent.parent.absolute()
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
|
||||
|
||||
This module provides functions to validate URLs before making HTTP requests,
|
||||
blocking access to internal networks, cloud metadata services, and other
|
||||
potentially dangerous endpoints.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Set
|
||||
|
||||
|
||||
class SSRFError(Exception):
    """Signals that a URL was rejected by the SSRF validation checks."""
|
||||
|
||||
|
||||
# Hostnames that are rejected by name, before any DNS lookup happens.
BLOCKED_HOSTNAMES: Set[str] = {
    "metadata",
    "metadata.google.internal",
    "localhost.localdomain",
    "localhost",
}

# Well-known cloud instance-metadata endpoints, stored as canonical IP text.
METADATA_IPS: Set[str] = {
    "fd00:ec2::254",  # AWS IPv6 metadata
    "169.254.170.2",  # AWS ECS task metadata
    "169.254.169.254",  # AWS, GCP, Azure metadata
}

# Only plain web schemes may be used for outbound requests.
ALLOWED_SCHEMES: Set[str] = {"https", "http"}
|
||||
|
||||
|
||||
def is_private_ip(ip_str: str) -> bool:
    """
    Check if an IP address is private, loopback, link-local or otherwise internal.

    In addition to the ``ipaddress`` classification flags this also treats
    the shared address space 100.64.0.0/10 as internal (``is_private`` is
    documented to exclude it), and evaluates the embedded IPv4 address of an
    IPv4-mapped IPv6 literal (``::ffff:a.b.c.d``) so the verdict does not
    depend on the Python version's handling of mapped addresses.

    Args:
        ip_str: IP address as a string

    Returns:
        True if the IP is private/internal, False otherwise (including when
        ``ip_str`` is not a parseable IP address)
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # If we can't parse it as an IP, it is not an internal IP literal.
        return False

    candidates = [ip]
    mapped = getattr(ip, "ipv4_mapped", None)  # only IPv6 addresses have this
    if mapped is not None:
        candidates.append(mapped)

    shared_space = ipaddress.ip_network("100.64.0.0/10")
    return any(
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_reserved
        or addr.is_multicast
        or addr.is_unspecified
        or (addr.version == 4 and addr in shared_space)
        for addr in candidates
    )
|
||||
|
||||
|
||||
def is_metadata_ip(ip_str: str) -> bool:
    """
    Check if an IP address is a cloud metadata service IP.

    The address is parsed and compared in canonical form, so alternative
    spellings (upper-case or uncompressed IPv6) and IPv4-mapped IPv6
    literals such as ``::ffff:169.254.169.254`` are recognised as well.

    Args:
        ip_str: IP address as a string

    Returns:
        True if the IP is a metadata service, False otherwise
    """
    if ip_str in METADATA_IPS:
        return True
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # Not an IP literal, and not an exact entry either.
        return False

    spellings = {str(ip)}
    mapped = getattr(ip, "ipv4_mapped", None)  # only IPv6 addresses have this
    if mapped is not None:
        spellings.add(str(mapped))
    return not spellings.isdisjoint(METADATA_IPS)
|
||||
|
||||
|
||||
def resolve_hostname(hostname: str) -> Optional[str]:
    """
    Resolve a hostname to an IPv4 address.

    Args:
        hostname: The hostname to resolve

    Returns:
        The resolved IP address, or None if resolution fails or the name
        cannot be encoded for lookup
    """
    try:
        return socket.gethostbyname(hostname)
    except (socket.gaierror, ValueError):
        # gaierror: the lookup itself failed.
        # ValueError: the name was rejected before any lookup, e.g. the IDNA
        # encoding step raises UnicodeError (a ValueError subclass) for an
        # empty or over-long label. Treat both as "does not resolve".
        return None
|
||||
|
||||
|
||||
def validate_url(url: str, allow_localhost: bool = False) -> str:
    """
    Validate a URL to prevent SSRF attacks.

    This function checks that:
    1. The URL can be parsed and has an allowed scheme (http or https)
    2. The hostname is not a blocked hostname
    3. The resolved IP is not a private/internal IP
    4. The resolved IP is not a cloud metadata service

    Every rejection, including a malformed URL or a hostname that cannot be
    encoded for lookup, is reported as SSRFError.

    NOTE(review): the hostname is resolved here and again by whatever HTTP
    client later fetches the URL, so this check alone does not protect
    against DNS answers that change between the two lookups.

    Args:
        url: The URL to validate
        allow_localhost: If True, allow localhost connections (for testing only)

    Returns:
        The validated URL (with scheme added if missing)

    Raises:
        SSRFError: If the URL fails validation
    """
    try:
        # Ensure URL has a scheme
        if not urlparse(url).scheme:
            url = "http://" + url
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError as exc:
        # urlparse raises ValueError for e.g. unbalanced IPv6 brackets.
        raise SSRFError(f"Malformed URL: {exc}") from exc

    # Check scheme
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")

    if not hostname:
        raise SSRFError("URL must have a valid hostname.")

    # A trailing root dot ("localhost.") names the same host, so ignore it
    # when matching against the blocklist.
    hostname_lower = hostname.lower().rstrip(".")

    # Check blocked hostnames
    if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
        raise SSRFError(f"Access to '{hostname}' is not allowed.")

    # Check if hostname is an IP address directly
    try:
        literal_ip = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP address, it's a hostname - resolve it below
        literal_ip = None

    if literal_ip is not None:
        ip_str = str(literal_ip)
        if is_metadata_ip(ip_str):
            raise SSRFError("Access to cloud metadata services is not allowed.")
        if is_private_ip(ip_str) and not allow_localhost:
            raise SSRFError("Access to private/internal IP addresses is not allowed.")
        return url

    # Resolve hostname and check the IP
    try:
        resolved_ip = resolve_hostname(hostname)
    except ValueError:
        # Name could not even be encoded for a lookup (UnicodeError included).
        resolved_ip = None
    if resolved_ip is None:
        raise SSRFError(f"Unable to resolve hostname: {hostname}")

    if is_metadata_ip(resolved_ip):
        raise SSRFError("Access to cloud metadata services is not allowed.")

    if is_private_ip(resolved_ip) and not allow_localhost:
        raise SSRFError("Access to private/internal networks is not allowed.")

    return url
|
||||
|
||||
|
||||
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
    """
    Validate a URL and return a tuple with validation result.

    This is a non-throwing version of validate_url for cases where
    you want to handle validation failures gracefully. Besides SSRFError it
    also absorbs ValueError, which URL parsing raises for malformed input
    (for example an unterminated ``[`` in the host part).

    Args:
        url: The URL to validate
        allow_localhost: If True, allow localhost connections (for testing only)

    Returns:
        Tuple of (is_valid, validated_url_or_original, error_message_or_none)
    """
    try:
        validated = validate_url(url, allow_localhost)
    except (SSRFError, ValueError) as e:
        return (False, url, str(e))
    return (True, validated, None)
|
||||
@@ -1,13 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
@@ -26,7 +20,6 @@ class AnthropicLLM(BaseLLM):
|
||||
|
||||
self.HUMAN_PROMPT = HUMAN_PROMPT
|
||||
self.AI_PROMPT = AI_PROMPT
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
@@ -77,115 +70,3 @@ class AnthropicLLM(BaseLLM):
|
||||
finally:
|
||||
if hasattr(stream_response, "close"):
|
||||
stream_response.close()
|
||||
|
||||
def get_supported_attachment_types(self):
    """
    List the attachment MIME types this provider accepts directly.

    Only image types are listed; PDF is not among them.

    Returns:
        list: MIME type strings, in a fixed order
    """
    image_subtypes = ("png", "jpeg", "jpg", "webp", "gif")
    return [f"image/{subtype}" for subtype in image_subtypes]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
    """
    Process attachments for Anthropic Claude API.

    Image attachments are appended, in Claude's base64 image block format,
    to the content of the last user message (a new user message is added
    when none exists). The caller's message dictionaries are not modified:
    the message that receives the images is copied first.

    Args:
        messages (list): List of message dictionaries.
        attachments (list): List of attachment dictionaries with content and metadata.

    Returns:
        list: Messages formatted with image content for Claude API. When
        there are no attachments the input list itself is returned.
    """
    if not attachments:
        return messages

    prepared_messages = messages.copy()

    # Find the last user message to attach images to
    user_message_index = next(
        (
            i
            for i in range(len(prepared_messages) - 1, -1, -1)
            if prepared_messages[i].get("role") == "user"
        ),
        None,
    )
    if user_message_index is None:
        prepared_messages.append({"role": "user", "content": []})
        user_message_index = len(prepared_messages) - 1

    # list.copy() is shallow, so work on a private copy of the target message
    # (and of its content list) instead of editing the caller's objects.
    target = dict(prepared_messages[user_message_index])
    existing = target.get("content")
    if isinstance(existing, str):
        # Convert content to list format if it's a string
        parts = [{"type": "text", "text": existing}]
    elif isinstance(existing, list):
        parts = list(existing)
    else:
        parts = []
    target["content"] = parts
    prepared_messages[user_message_index] = target

    for attachment in attachments:
        mime_type = attachment.get("mime_type")
        if not (mime_type and mime_type.startswith("image/")):
            continue

        try:
            # Pre-converted images (from PDF-to-image conversion) already
            # carry base64 text under 'data'; others are read from storage.
            if "data" in attachment:
                base64_image = attachment["data"]
            else:
                base64_image = self._get_base64_image(attachment)
        except Exception as e:
            # Best effort: one unreadable image must not abort the request.
            logger.error(
                f"Error processing image attachment: {e}", exc_info=True
            )
            if "content" in attachment:
                parts.append(
                    {
                        "type": "text",
                        "text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
                    }
                )
        else:
            # Claude uses a specific format for images
            parts.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": mime_type,
                        "data": base64_image,
                    },
                }
            )

    return prepared_messages
|
||||
|
||||
def _get_base64_image(self, attachment):
|
||||
"""
|
||||
Convert an image file to base64 encoding.
|
||||
|
||||
Args:
|
||||
attachment (dict): Attachment dictionary with path and metadata.
|
||||
|
||||
Returns:
|
||||
str: Base64-encoded image data.
|
||||
"""
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
try:
|
||||
with self.storage.get_file(file_path) as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
@@ -1,19 +1,75 @@
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
DOCSGPT_API_KEY = "sk-docsgpt-public"
|
||||
DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
||||
DOCSGPT_MODEL = "docsgpt"
|
||||
|
||||
class DocsGPTAPILLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=DOCSGPT_API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=DOCSGPT_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = "sk-docsgpt-public"
|
||||
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
tool_call = {
|
||||
"id": item["function_call"]["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item["function_call"]["name"],
|
||||
"arguments": json.dumps(cleaned_args),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
)
|
||||
elif "function_response" in item:
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["function_response"][
|
||||
"call_id"
|
||||
],
|
||||
"content": json.dumps(
|
||||
item["function_response"]["response"]["result"]
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
@@ -23,19 +79,23 @@ class DocsGPTAPILLM(OpenAILLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
return super()._raw_gen(
|
||||
baseself,
|
||||
DOCSGPT_MODEL,
|
||||
messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
engine=engine,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
messages = self._clean_messages_openai(messages)
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt",
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
@@ -45,16 +105,34 @@ class DocsGPTAPILLM(OpenAILLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
return super()._raw_gen_stream(
|
||||
baseself,
|
||||
DOCSGPT_MODEL,
|
||||
messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
engine=engine,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
messages = self._clean_messages_openai(messages)
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt",
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
try:
|
||||
for line in response:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
@@ -10,13 +11,11 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
@@ -34,12 +33,6 @@ class GoogleLLM(BaseLLM):
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -142,38 +135,12 @@ class GoogleLLM(BaseLLM):
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
|
||||
Returns:
|
||||
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
|
||||
combined system instruction.
|
||||
"""
|
||||
"""Convert OpenAI format messages to Google AI format."""
|
||||
cleaned_messages = []
|
||||
system_instructions = []
|
||||
|
||||
def _extract_system_text(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item and item["text"] is not None:
|
||||
parts.append(item["text"])
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
# Gemini only accepts user/model in the contents list.
|
||||
if role == "system":
|
||||
sys_text = _extract_system_text(content)
|
||||
if sys_text:
|
||||
system_instructions.append(sys_text)
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
@@ -192,27 +159,12 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
# Create function call part with thought_signature if present
|
||||
# For Gemini 3 models, we need to include thought_signature
|
||||
if "thought_signature" in item:
|
||||
# Use Part constructor with functionCall and thoughtSignature
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use helper method when no thought_signature
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
)
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
@@ -236,8 +188,7 @@ class GoogleLLM(BaseLLM):
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
|
||||
return cleaned_messages, system_instruction
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
"""
|
||||
@@ -323,77 +274,6 @@ class GoogleLLM(BaseLLM):
|
||||
genai_tools.append(genai_tool)
|
||||
return genai_tools
|
||||
|
||||
def _extract_preview_from_message(self, message):
|
||||
"""Get a short, human-readable preview from the last message."""
|
||||
try:
|
||||
if hasattr(message, "parts"):
|
||||
for part in reversed(message.parts):
|
||||
if getattr(part, "text", None):
|
||||
return part.text
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call:
|
||||
name = getattr(function_call, "name", "") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
function_response = getattr(part, "function_response", None)
|
||||
if function_response:
|
||||
name = getattr(function_response, "name", "") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for item in reversed(content):
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, dict):
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
if item.get("function_call"):
|
||||
fn = item["function_call"]
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
return "function_call"
|
||||
if item.get("function_response"):
|
||||
resp = item["function_response"]
|
||||
if isinstance(resp, dict):
|
||||
name = resp.get("name") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
return "function_response"
|
||||
if "text" in message and isinstance(message["text"], str):
|
||||
return message["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return str(message)
|
||||
|
||||
def _summarize_messages_for_log(self, messages, preview_chars=20):
|
||||
"""Return a compact summary for logging to avoid huge payloads."""
|
||||
message_count = len(messages) if messages else 0
|
||||
last_preview = ""
|
||||
if messages:
|
||||
last_preview = self._extract_preview_from_message(messages[-1]) or ""
|
||||
last_preview = str(last_preview).replace("\n", " ")
|
||||
if len(last_preview) > preview_chars:
|
||||
last_preview = f"{last_preview[:preview_chars]}..."
|
||||
return f"count={message_count}, last='{last_preview}'"
|
||||
|
||||
@staticmethod
|
||||
def _get_text_value(part):
|
||||
"""Get text from both SDK objects and dict-shaped test doubles."""
|
||||
if isinstance(part, dict):
|
||||
value = part.get("text")
|
||||
return value if isinstance(value, str) else ""
|
||||
value = getattr(part, "text", None)
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
@staticmethod
|
||||
def _is_thought_part(part):
|
||||
"""Detect Gemini thinking parts when available."""
|
||||
if isinstance(part, dict):
|
||||
return bool(part.get("thought"))
|
||||
return bool(getattr(part, "thought", False))
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -407,12 +287,12 @@ class GoogleLLM(BaseLLM):
|
||||
):
|
||||
"""Generate content using Google AI API without streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
@@ -445,15 +325,16 @@ class GoogleLLM(BaseLLM):
|
||||
):
|
||||
"""Generate content using Google AI API with streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
@@ -468,12 +349,8 @@ class GoogleLLM(BaseLLM):
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
logging.info(
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
@@ -490,23 +367,10 @@ class GoogleLLM(BaseLLM):
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
yield part
|
||||
continue
|
||||
|
||||
part_text = self._get_text_value(part)
|
||||
if not part_text:
|
||||
continue
|
||||
|
||||
if self._is_thought_part(part):
|
||||
yield {"type": "thought", "thought": part_text}
|
||||
else:
|
||||
yield part_text
|
||||
elif part.text:
|
||||
yield part.text
|
||||
elif hasattr(chunk, "text"):
|
||||
chunk_text = self._get_text_value(chunk)
|
||||
if chunk_text:
|
||||
if self._is_thought_part(chunk):
|
||||
yield {"type": "thought", "thought": chunk_text}
|
||||
else:
|
||||
yield chunk_text
|
||||
yield chunk.text
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
@@ -1,15 +1,37 @@
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class GroqLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or GROQ_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
class GroqLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
|
||||
)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
for line in response:
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
@@ -17,7 +16,6 @@ class ToolCall:
|
||||
name: str
|
||||
arguments: Union[str, Dict]
|
||||
index: Optional[int] = None
|
||||
thought_signature: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "ToolCall":
|
||||
@@ -105,7 +103,6 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
Prepare messages with attachments and provider-specific formatting.
|
||||
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Original messages
|
||||
@@ -119,40 +116,11 @@ class LLMHandler(ABC):
|
||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||
supported_types = agent.llm.get_supported_attachment_types()
|
||||
|
||||
# Check if provider supports images but not PDF (synthetic PDF support)
|
||||
supports_images = any(t.startswith("image/") for t in supported_types)
|
||||
supports_pdf = "application/pdf" in supported_types
|
||||
|
||||
# Process attachments, converting PDFs to images if needed
|
||||
processed_attachments = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
# Synthetic PDF support: convert PDF to images if LLM supports images but not PDF
|
||||
if mime_type == "application/pdf" and supports_images and not supports_pdf:
|
||||
logger.info(
|
||||
f"Converting PDF to images for synthetic PDF support: {attachment.get('path', 'unknown')}"
|
||||
)
|
||||
try:
|
||||
converted_images = self._convert_pdf_to_images(attachment)
|
||||
processed_attachments.extend(converted_images)
|
||||
logger.info(
|
||||
f"Converted PDF to {len(converted_images)} images"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert PDF to images, falling back to text: {e}"
|
||||
)
|
||||
# Fall back to treating as unsupported (text extraction)
|
||||
processed_attachments.append(attachment)
|
||||
else:
|
||||
processed_attachments.append(attachment)
|
||||
|
||||
supported_attachments = [
|
||||
a for a in processed_attachments if a.get("mime_type") in supported_types
|
||||
a for a in attachments if a.get("mime_type") in supported_types
|
||||
]
|
||||
unsupported_attachments = [
|
||||
a for a in processed_attachments if a.get("mime_type") not in supported_types
|
||||
a for a in attachments if a.get("mime_type") not in supported_types
|
||||
]
|
||||
|
||||
# Process supported attachments with the LLM's custom method
|
||||
@@ -175,37 +143,6 @@ class LLMHandler(ABC):
|
||||
)
|
||||
return messages
|
||||
|
||||
def _convert_pdf_to_images(self, attachment: Dict) -> List[Dict]:
|
||||
"""
|
||||
Convert a PDF attachment to a list of image attachments.
|
||||
|
||||
This enables synthetic PDF support for LLMs that support images but not PDFs.
|
||||
|
||||
Args:
|
||||
attachment: PDF attachment dictionary with 'path' and optional 'content'
|
||||
|
||||
Returns:
|
||||
List of image attachment dictionaries with 'data', 'mime_type', and 'page'
|
||||
"""
|
||||
from application.utils import convert_pdf_to_images
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in PDF attachment")
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
# Convert PDF to images
|
||||
images_data = convert_pdf_to_images(
|
||||
file_path=file_path,
|
||||
storage=storage,
|
||||
max_pages=20,
|
||||
dpi=150,
|
||||
)
|
||||
|
||||
return images_data
|
||||
|
||||
def _append_unsupported_attachments(
|
||||
self, messages: List[Dict], attachments: List[Dict]
|
||||
) -> List[Dict]:
|
||||
@@ -241,406 +178,6 @@ class LLMHandler(ABC):
|
||||
system_msg["content"] += f"\n\n{combined_text}"
|
||||
return prepared_messages
|
||||
|
||||
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Build a minimal context: system prompt + latest user message only.
|
||||
Drops all tool/function messages to shrink context aggressively.
|
||||
"""
|
||||
system_message = next((m for m in messages if m.get("role") == "system"), None)
|
||||
if not system_message:
|
||||
logger.warning("Cannot prune messages minimally: missing system message.")
|
||||
return None
|
||||
last_non_system = None
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
last_non_system = m
|
||||
break
|
||||
if not last_non_system and m.get("role") not in ("system", None):
|
||||
last_non_system = m
|
||||
if not last_non_system:
|
||||
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
|
||||
return None
|
||||
logger.info("Pruning context to system + latest user/assistant message to proceed.")
|
||||
return [system_message, last_non_system]
|
||||
|
||||
def _extract_text_from_content(self, content: Any) -> str:
|
||||
"""
|
||||
Convert message content (str or list of parts) to plain text for compression.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts_text = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if "text" in item and item["text"] is not None:
|
||||
parts_text.append(str(item["text"]))
|
||||
elif "function_call" in item or "function_response" in item:
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
parts_text.append(str(item))
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
|
||||
"""
|
||||
Build a conversation-like dict from current messages so we can compress
|
||||
even when the conversation isn't persisted yet. Includes tool calls/results.
|
||||
"""
|
||||
queries = []
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
def _commit_query(response_text: str):
|
||||
nonlocal current_prompt, current_tool_calls
|
||||
if current_prompt is None and not response_text:
|
||||
return
|
||||
tool_calls_list = list(current_tool_calls.values())
|
||||
queries.append(
|
||||
{
|
||||
"prompt": current_prompt or "",
|
||||
"response": response_text,
|
||||
"tool_calls": tool_calls_list,
|
||||
}
|
||||
)
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "user":
|
||||
current_prompt = self._extract_text_from_content(content)
|
||||
|
||||
elif role in {"assistant", "model"}:
|
||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
fc = item["function_call"]
|
||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fc.get("name"),
|
||||
"arguments": fc.get("args"),
|
||||
"result": None,
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
elif "function_response" in item:
|
||||
fr = item["function_response"]
|
||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fr.get("name"),
|
||||
"arguments": None,
|
||||
"result": fr.get("response", {}).get("result"),
|
||||
"status": "completed",
|
||||
"call_id": call_id,
|
||||
}
|
||||
# No direct assistant text here; continue to next message
|
||||
continue
|
||||
|
||||
response_text = self._extract_text_from_content(content)
|
||||
_commit_query(response_text)
|
||||
|
||||
elif role == "tool":
|
||||
# Attach tool outputs to the latest pending tool call if possible
|
||||
tool_text = self._extract_text_from_content(content)
|
||||
# Attempt to parse function_response style
|
||||
call_id = None
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item and item["function_response"].get("call_id"):
|
||||
call_id = item["function_response"]["call_id"]
|
||||
break
|
||||
if call_id and call_id in current_tool_calls:
|
||||
current_tool_calls[call_id]["result"] = tool_text
|
||||
current_tool_calls[call_id]["status"] = "completed"
|
||||
elif queries:
|
||||
queries[-1].setdefault("tool_calls", []).append(
|
||||
{
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": "unknown_action",
|
||||
"arguments": {},
|
||||
"result": tool_text,
|
||||
"status": "completed",
|
||||
}
|
||||
)
|
||||
|
||||
# If there's an unfinished prompt with tool_calls but no response yet, commit it
|
||||
if current_prompt is not None or current_tool_calls:
|
||||
_commit_query(response_text="")
|
||||
|
||||
if not queries:
|
||||
return None
|
||||
|
||||
return {
|
||||
"queries": queries,
|
||||
"compression_metadata": {
|
||||
"is_compressed": False,
|
||||
"compression_points": [],
|
||||
},
|
||||
}
|
||||
|
||||
def _rebuild_messages_after_compression(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Delegates to MessageBuilder for the actual reconstruction.
|
||||
"""
|
||||
from application.api.answer.services.compression.message_builder import (
|
||||
MessageBuilder,
|
||||
)
|
||||
|
||||
return MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary=compressed_summary,
|
||||
recent_queries=recent_queries,
|
||||
include_current_execution=include_current_execution,
|
||||
include_tool_calls=include_tool_calls,
|
||||
)
|
||||
|
||||
def _perform_mid_execution_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Perform compression during tool execution and rebuild messages.
|
||||
|
||||
Uses the new orchestrator for simplified compression.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current conversation messages
|
||||
|
||||
Returns:
|
||||
(success: bool, rebuilt_messages: Optional[List[Dict]])
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
conversation_service = ConversationService()
|
||||
orchestrator = CompressionOrchestrator(conversation_service)
|
||||
|
||||
# Get conversation from database (may be None for new sessions)
|
||||
conversation = conversation_service.get_conversation(
|
||||
agent.conversation_id, agent.initial_user_id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
# Merge current in-flight messages (including tool calls)
|
||||
conversation_from_msgs = self._build_conversation_from_messages(messages)
|
||||
if conversation_from_msgs:
|
||||
conversation = conversation_from_msgs
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not load conversation for compression; attempting in-memory compression"
|
||||
)
|
||||
return self._perform_in_memory_compression(agent, messages)
|
||||
|
||||
# Use orchestrator to perform compression
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id=agent.conversation_id,
|
||||
user_id=agent.initial_user_id,
|
||||
model_id=agent.model_id,
|
||||
decoded_token=getattr(agent, "decoded_token", {}),
|
||||
current_conversation=conversation,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
logger.warning(f"Mid-execution compression failed: {result.error}")
|
||||
# Try minimal pruning as fallback
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
if not result.compression_performed:
|
||||
logger.warning("Compression not performed")
|
||||
return False, None
|
||||
|
||||
# Check if compression actually reduced tokens
|
||||
if result.metadata:
|
||||
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
|
||||
logger.warning(
|
||||
"Compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
|
||||
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Also store the compression summary as a visible message
|
||||
if result.metadata:
|
||||
conversation_service.append_compression_message(
|
||||
agent.conversation_id, result.metadata.to_dict()
|
||||
)
|
||||
|
||||
# Update agent's compressed summary for downstream persistence
|
||||
agent.compressed_summary = result.compressed_summary
|
||||
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
|
||||
agent.compression_saved = False
|
||||
|
||||
# Reset the context limit flag so tools can continue
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
# Rebuild messages
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
result.compressed_summary,
|
||||
result.recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _perform_in_memory_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Fallback compression path when the conversation is not yet persisted.
|
||||
|
||||
Uses CompressionService directly without DB persistence.
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression.service import (
|
||||
CompressionService,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
conversation = self._build_conversation_from_messages(messages)
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
"Cannot perform in-memory compression: no user/assistant turns found"
|
||||
)
|
||||
return False, None
|
||||
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else agent.model_id
|
||||
)
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key,
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=None, # No DB updates for in-memory
|
||||
)
|
||||
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0 or queries_count == 0:
|
||||
logger.warning("Not enough queries to compress in-memory context")
|
||||
return False, None
|
||||
|
||||
metadata = compression_service.compress_conversation(
|
||||
conversation,
|
||||
compress_up_to_index=compress_up_to,
|
||||
)
|
||||
|
||||
# If compression doesn't reduce tokens, fall back to minimal pruning
|
||||
if (
|
||||
metadata.compressed_token_count
|
||||
>= metadata.original_token_count
|
||||
):
|
||||
logger.warning(
|
||||
"In-memory compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
# Attach metadata to synthetic conversation
|
||||
conversation["compression_metadata"] = {
|
||||
"is_compressed": True,
|
||||
"compression_points": [metadata.to_dict()],
|
||||
}
|
||||
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
agent.compressed_summary = compressed_summary
|
||||
agent.compression_metadata = metadata.to_dict()
|
||||
agent.compression_saved = False
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
compressed_summary,
|
||||
recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing in-memory compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
@@ -658,110 +195,7 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
|
||||
# Context limit reached - attempt mid-execution compression
|
||||
compression_attempted = False
|
||||
compression_successful = False
|
||||
|
||||
try:
|
||||
from application.core.settings import settings
|
||||
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
|
||||
except Exception:
|
||||
compression_enabled = False
|
||||
|
||||
if compression_enabled:
|
||||
compression_attempted = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
|
||||
f"Attempting mid-execution compression..."
|
||||
)
|
||||
|
||||
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
|
||||
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
|
||||
agent, updated_messages
|
||||
)
|
||||
|
||||
if compression_successful and rebuilt_messages is not None:
|
||||
# Update the messages list with rebuilt compressed version
|
||||
updated_messages = rebuilt_messages
|
||||
|
||||
# Yield compression success message
|
||||
yield {
|
||||
"type": "info",
|
||||
"data": {
|
||||
"message": "Context window limit reached. Compressed conversation history to continue processing."
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
|
||||
)
|
||||
# Proceed to execute the current tool call with the reduced context
|
||||
else:
|
||||
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
|
||||
compression_attempted = True
|
||||
compression_successful = False
|
||||
|
||||
# If compression wasn't attempted or failed, skip remaining tools
|
||||
if not compression_successful:
|
||||
if i == 0:
|
||||
# Special case: limit reached before executing any tools
|
||||
# This can happen when previous tool responses pushed context over limit
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
else:
|
||||
# Normal case: executed some tools, now stopping
|
||||
tool_word = "tool call" if i == 1 else "tool calls"
|
||||
remaining = len(tool_calls) - i
|
||||
remaining_word = "tool call" if remaining == 1 else "tool calls"
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping remaining {remaining} {remaining_word}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Skipping remaining {remaining} {remaining_word}. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
|
||||
# Mark remaining tools as skipped
|
||||
for remaining_call in tool_calls[i:]:
|
||||
skip_message = {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": "system",
|
||||
"call_id": remaining_call.id,
|
||||
"action_name": remaining_call.name,
|
||||
"arguments": {},
|
||||
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
|
||||
"status": "skipped"
|
||||
}
|
||||
}
|
||||
yield skip_message
|
||||
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -771,26 +205,21 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
# Include thought_signature for Google Gemini 3 models
|
||||
# It should be at the same level as function_call, not inside it
|
||||
if call.thought_signature:
|
||||
function_call_content["thought_signature"] = call.thought_signature
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_content],
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
@@ -878,9 +307,6 @@ class LLMHandler(ABC):
|
||||
tool_calls = {}
|
||||
|
||||
for chunk in self._iterate_stream(response):
|
||||
if isinstance(chunk, dict) and chunk.get("type") == "thought":
|
||||
yield chunk
|
||||
continue
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
continue
|
||||
@@ -897,13 +323,7 @@ class LLMHandler(ABC):
|
||||
if call.name:
|
||||
existing.name = call.name
|
||||
if call.arguments:
|
||||
if existing.arguments is None:
|
||||
existing.arguments = call.arguments
|
||||
else:
|
||||
existing.arguments += call.arguments
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
if call.thought_signature:
|
||||
existing.thought_signature = call.thought_signature
|
||||
existing.arguments += call.arguments
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
@@ -916,21 +336,8 @@ class LLMHandler(ABC):
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": (
|
||||
"WARNING: Context window limit has been reached. "
|
||||
"Please provide a final response to the user without making additional tool calls. "
|
||||
"Summarize the work completed so far."
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
@@ -19,20 +19,15 @@ class GoogleLLMHandler(LLMHandler):
|
||||
)
|
||||
if hasattr(response, "candidates"):
|
||||
parts = response.candidates[0].content.parts if response.candidates else []
|
||||
tool_calls = []
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
index=idx,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
)
|
||||
for part in parts
|
||||
if hasattr(part, "function_call") and part.function_call is not None
|
||||
]
|
||||
|
||||
content = " ".join(
|
||||
part.text
|
||||
@@ -46,17 +41,13 @@ class GoogleLLMHandler(LLMHandler):
|
||||
raw_response=response,
|
||||
)
|
||||
else:
|
||||
# This branch handles individual Part objects from streaming responses
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
if hasattr(response, "function_call"):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=response.function_call.name,
|
||||
arguments=response.function_call.args,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
return LLMResponse(
|
||||
|
||||
68
application/llm/huggingface.py
Normal file
68
application/llm/huggingface.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class HuggingFaceLLM(BaseLLM):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
user_api_key=None,
|
||||
llm_name="Arc53/DocsGPT-7B",
|
||||
q=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
global hf
|
||||
|
||||
from langchain.llms import HuggingFacePipeline
|
||||
|
||||
if q:
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(llm_name)
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
llm_name, quantization_config=bnb_config
|
||||
)
|
||||
else:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(llm_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(llm_name)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_new_tokens=2000,
|
||||
device_map="auto",
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
||||
result = hf(prompt)
|
||||
|
||||
return result.content
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||
|
||||
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
||||
@@ -4,12 +4,12 @@ from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,6 +19,7 @@ class LLMCreator:
|
||||
"openai": OpenAILLM,
|
||||
"azure_openai": AzureOpenAILLM,
|
||||
"sagemaker": SagemakerAPILLM,
|
||||
"huggingface": HuggingFaceLLM,
|
||||
"llama.cpp": LlamaCpp,
|
||||
"anthropic": AnthropicLLM,
|
||||
"docsgpt": DocsGPTAPILLM,
|
||||
@@ -26,7 +27,6 @@ class LLMCreator:
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM,
|
||||
"openrouter": OpenRouterLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,15 +1,32 @@
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
|
||||
from application.llm.base import BaseLLM
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class NovitaLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or NOVITA_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
class NovitaLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
for line in response:
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
class OpenRouterLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or OPEN_ROUTER_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -9,57 +9,6 @@ from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
def _truncate_base64_for_logging(messages):
|
||||
"""
|
||||
Create a copy of messages with base64 data truncated for readable logging.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
|
||||
Returns:
|
||||
Copy of messages with truncated base64 content
|
||||
"""
|
||||
import copy
|
||||
|
||||
def truncate_content(content):
|
||||
if isinstance(content, str):
|
||||
# Check if it looks like a data URL with base64
|
||||
if content.startswith("data:") and ";base64," in content:
|
||||
prefix_end = content.index(";base64,") + len(";base64,")
|
||||
prefix = content[:prefix_end]
|
||||
return f"{prefix}[BASE64_DATA_TRUNCATED, length={len(content) - prefix_end}]"
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
return [truncate_item(item) for item in content]
|
||||
elif isinstance(content, dict):
|
||||
return {k: truncate_content(v) for k, v in content.items()}
|
||||
return content
|
||||
|
||||
def truncate_item(item):
|
||||
if isinstance(item, dict):
|
||||
result = {}
|
||||
for k, v in item.items():
|
||||
if k == "url" and isinstance(v, str) and ";base64," in v:
|
||||
prefix_end = v.index(";base64,") + len(";base64,")
|
||||
prefix = v[:prefix_end]
|
||||
result[k] = f"{prefix}[BASE64_DATA_TRUNCATED, length={len(v) - prefix_end}]"
|
||||
elif k == "data" and isinstance(v, str) and len(v) > 100:
|
||||
result[k] = f"[BASE64_DATA_TRUNCATED, length={len(v)}]"
|
||||
else:
|
||||
result[k] = truncate_content(v)
|
||||
return result
|
||||
return truncate_content(item)
|
||||
|
||||
truncated = []
|
||||
for msg in messages:
|
||||
msg_copy = copy.copy(msg)
|
||||
if "content" in msg_copy:
|
||||
msg_copy["content"] = truncate_content(msg_copy["content"])
|
||||
truncated.append(msg_copy)
|
||||
|
||||
return truncated
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
@@ -95,12 +44,12 @@ class OpenAILLM(BaseLLM):
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
elif isinstance(content, list):
|
||||
# Collect all content parts into a single message
|
||||
content_parts = []
|
||||
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
# Function calls need their own message
|
||||
if "text" in item:
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
@@ -120,7 +69,6 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
)
|
||||
elif "function_response" in item:
|
||||
# Function responses need their own message
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -133,69 +81,40 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
)
|
||||
elif isinstance(item, dict):
|
||||
# Collect content parts (text, images, files) into a single message
|
||||
if "type" in item and item["type"] == "text" and "text" in item:
|
||||
content_parts = []
|
||||
if "text" in item:
|
||||
content_parts.append(
|
||||
{"type": "text", "text": item["text"]}
|
||||
)
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "text"
|
||||
and "text" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "file"
|
||||
and "file" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "image_url"
|
||||
and "image_url" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "text" in item and "type" not in item:
|
||||
# Legacy format: {"text": "..."} without type
|
||||
content_parts.append({"type": "text", "text": item["text"]})
|
||||
|
||||
# Add the collected content parts as a single message
|
||||
if content_parts:
|
||||
cleaned_messages.append({"role": role, "content": content_parts})
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": content_parts}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
return cleaned_messages
|
||||
|
||||
@staticmethod
|
||||
def _normalize_reasoning_value(value):
|
||||
"""Normalize reasoning payloads from OpenAI-compatible stream chunks."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return "".join(
|
||||
OpenAILLM._normalize_reasoning_value(item) for item in value
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
for key in ("text", "content", "value", "reasoning_content", "reasoning"):
|
||||
normalized = OpenAILLM._normalize_reasoning_value(value.get(key))
|
||||
if normalized:
|
||||
return normalized
|
||||
return ""
|
||||
|
||||
for attr in ("text", "content", "value"):
|
||||
if hasattr(value, attr):
|
||||
normalized = OpenAILLM._normalize_reasoning_value(getattr(value, attr))
|
||||
if normalized:
|
||||
return normalized
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _extract_reasoning_text(cls, delta):
|
||||
"""Extract reasoning/thinking tokens from OpenAI-compatible delta chunks."""
|
||||
if delta is None:
|
||||
return ""
|
||||
|
||||
for key in (
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thinking",
|
||||
"thinking_content",
|
||||
):
|
||||
value = getattr(delta, key, None)
|
||||
if value is None and isinstance(delta, dict):
|
||||
value = delta.get(key)
|
||||
normalized = cls._normalize_reasoning_value(value)
|
||||
if normalized:
|
||||
return normalized
|
||||
return ""
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -208,11 +127,6 @@ class OpenAILLM(BaseLLM):
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
@@ -226,7 +140,7 @@ class OpenAILLM(BaseLLM):
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
logging.info(f"OpenAI response: {response}")
|
||||
|
||||
if tools:
|
||||
return response.choices[0]
|
||||
else:
|
||||
@@ -244,11 +158,6 @@ class OpenAILLM(BaseLLM):
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
@@ -265,27 +174,14 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
try:
|
||||
for line in response:
|
||||
logging.debug(f"OpenAI stream line: {line}")
|
||||
if not getattr(line, "choices", None):
|
||||
continue
|
||||
|
||||
choice = line.choices[0]
|
||||
delta = getattr(choice, "delta", None)
|
||||
reasoning_text = self._extract_reasoning_text(delta)
|
||||
if reasoning_text:
|
||||
yield {"type": "thought", "thought": reasoning_text}
|
||||
|
||||
content = getattr(delta, "content", None)
|
||||
if isinstance(content, str) and content:
|
||||
yield content
|
||||
continue
|
||||
|
||||
has_tool_calls = bool(getattr(delta, "tool_calls", None))
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
# Yield non-content chunks only when needed for tool-call handling.
|
||||
if has_tool_calls or finish_reason == "tool_calls":
|
||||
yield choice
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
@@ -354,14 +250,17 @@ class OpenAILLM(BaseLLM):
|
||||
"""
|
||||
Return a list of MIME types supported by OpenAI for file uploads.
|
||||
|
||||
This reads from the model config to ensure consistency.
|
||||
If no model config found, falls back to images only (safest default).
|
||||
|
||||
Returns:
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
from application.core.model_configs import OPENAI_ATTACHMENTS
|
||||
return OPENAI_ATTACHMENTS
|
||||
return [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
"""
|
||||
@@ -398,16 +297,10 @@ class OpenAILLM(BaseLLM):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
logging.info(f"Processing attachment with mime_type: {mime_type}, has_data: {'data' in attachment}, has_path: {'path' in attachment}")
|
||||
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
try:
|
||||
# Check if this is a pre-converted image (from PDF-to-image conversion)
|
||||
if "data" in attachment:
|
||||
base64_image = attachment["data"]
|
||||
else:
|
||||
base64_image = self._get_base64_image(attachment)
|
||||
|
||||
base64_image = self._get_base64_image(attachment)
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -416,7 +309,6 @@ class OpenAILLM(BaseLLM):
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error processing image attachment: {e}", exc_info=True
|
||||
@@ -431,7 +323,6 @@ class OpenAILLM(BaseLLM):
|
||||
# Handle PDFs using the file API
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
logging.info(f"Attempting to upload PDF to OpenAI: {attachment.get('path', 'unknown')}")
|
||||
try:
|
||||
file_id = self._upload_file_to_openai(attachment)
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
@@ -446,8 +337,6 @@ class OpenAILLM(BaseLLM):
|
||||
"text": f"File content:\n\n{attachment['content']}",
|
||||
}
|
||||
)
|
||||
else:
|
||||
logging.warning(f"Unsupported attachment type in OpenAI provider: {mime_type}")
|
||||
return prepared_messages
|
||||
|
||||
def _get_base64_image(self, attachment):
|
||||
|
||||
@@ -62,26 +62,15 @@ class BaseConnectorAuth(ABC):
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if a token is expired.
|
||||
|
||||
|
||||
Args:
|
||||
token_info: Token information dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
True if token is expired, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
    """Build the subset of token data that is safe to persist in the session store.

    Only ``access_token``, ``refresh_token``, ``token_uri`` and ``expiry`` are
    copied from ``token_info`` (missing keys become ``None``). Any
    ``extra_fields`` are merged on top and override copied keys of the same name.
    """
    persisted_keys = ("access_token", "refresh_token", "token_uri", "expiry")
    sanitized = {key: token_info.get(key) for key in persisted_keys}
    sanitized.update(extra_fields)
    return sanitized
|
||||
|
||||
|
||||
class BaseConnectorLoader(ABC):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
@@ -14,12 +12,10 @@ class ConnectorCreator:
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
"share_point": SharePointLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
"share_point": SharePointAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -232,6 +232,10 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required token fields: {missing_fields}")
|
||||
|
||||
if 'client_id' not in token_info:
|
||||
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
|
||||
if 'client_secret' not in token_info:
|
||||
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
|
||||
if 'token_uri' not in token_info:
|
||||
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
|
||||
|
||||
|
||||
@@ -327,10 +327,15 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
content_bytes = file_io.getvalue()
|
||||
|
||||
try:
|
||||
return content_bytes.decode('utf-8')
|
||||
content = content_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
try:
|
||||
content = content_bytes.decode('latin-1')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
|
||||
return content
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""
|
||||
Share Point connector package for DocsGPT.
|
||||
|
||||
This module provides authentication and document loading capabilities for Share Point.
|
||||
"""
|
||||
|
||||
from .auth import SharePointAuth
|
||||
from .loader import SharePointLoader
|
||||
|
||||
__all__ = ['SharePointAuth', 'SharePointLoader']
|
||||
@@ -1,152 +0,0 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from msal import ConfidentialClientApplication
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharePointAuth(BaseConnectorAuth):
    """
    Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.

    Note: Files.Read scope allows access to files the user has granted access to,
    similar to Google Drive's drive.file scope.
    """

    SCOPES = [
        "Files.Read",
        "Sites.Read.All",
        "User.Read",
    ]

    # `tid` claim value that marks an account as unable to use SharePoint
    # shared content (see _allows_shared_content).
    PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"

    def __init__(self):
        """Read Microsoft OAuth settings and build the MSAL confidential client.

        Raises:
            ValueError: If the client ID or client secret is not configured.
        """
        self.client_id = settings.MICROSOFT_CLIENT_ID
        self.client_secret = settings.MICROSOFT_CLIENT_SECRET

        if not self.client_id:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
            )

        if not self.client_secret:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
            )

        self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
        self.tenant_id = settings.MICROSOFT_TENANT_ID
        # Fall back to the tenant authority when MICROSOFT_AUTHORITY is absent
        # *or* present but empty/None; a plain getattr default only covers
        # the "absent" case.
        self.authority = (
            getattr(settings, "MICROSOFT_AUTHORITY", None)
            or f"https://login.microsoftonline.com/{self.tenant_id}"
        )

        self.auth_app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority
        )

    def get_authorization_url(self, state: Optional[str] = None) -> str:
        """Return the URL the user must visit to grant access for ``SCOPES``."""
        return self.auth_app.get_authorization_request_url(
            scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
        )

    def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
        """Exchange an authorization code for tokens.

        Returns:
            Token information in the shape produced by ``map_token_response``.

        Raises:
            ValueError: If the MSAL result contains an ``error`` entry.
        """
        result = self.auth_app.acquire_token_by_authorization_code(
            code=authorization_code,
            scopes=self.SCOPES,
            redirect_uri=self.redirect_uri
        )

        if "error" in result:
            logger.error("Token exchange failed: %s", result.get("error_description"))
            raise ValueError(f"Error acquiring token: {result.get('error_description')}")

        return self.map_token_response(result)

    def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
        """Obtain fresh tokens from a refresh token.

        Returns:
            Token information in the shape produced by ``map_token_response``.

        Raises:
            ValueError: If the MSAL result contains an ``error`` entry.
        """
        result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)

        if "error" in result:
            logger.error("Token refresh failed: %s", result.get("error_description"))
            raise ValueError(f"Error refreshing token: {result.get('error_description')}")

        return self.map_token_response(result)

    def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
        """Load the token information stored for a connector session.

        Looks up ``session_token`` in the ``connector_sessions`` collection and
        validates that both an access token and a refresh token are present.
        A missing ``token_uri`` is filled in from the configured tenant.

        Raises:
            ValueError: If the session is unknown, incomplete, or the lookup
                fails for any other reason (the original error is chained).
        """
        try:
            # Imported lazily, as in the original code; presumably to avoid an
            # import cycle -- TODO confirm.
            from application.core.mongo_db import MongoDB

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]

            sessions_collection = db["connector_sessions"]
            session = sessions_collection.find_one({"session_token": session_token})

            if not session:
                # The session token is a credential: keep it out of error
                # messages, which are logged below and may reach API clients.
                raise ValueError("Invalid session token")

            if "token_info" not in session:
                raise ValueError("Session missing token information")

            token_info = session["token_info"]
            if not token_info:
                raise ValueError("Invalid token information")

            required_fields = ["access_token", "refresh_token"]
            missing_fields = [field for field in required_fields if not token_info.get(field)]
            if missing_fields:
                raise ValueError(f"Missing required token fields: {missing_fields}")

            if 'token_uri' not in token_info:
                token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"

            return token_info

        except Exception as e:
            logger.error("Failed to retrieve token from session: %s", e)
            raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}") from e

    def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
        """Return True when the token is missing, has no expiry, or expires within 60 seconds."""
        if not token_info:
            return True

        expiry_timestamp = token_info.get("expiry")

        if expiry_timestamp is None:
            return True

        current_timestamp = int(datetime.datetime.now().timestamp())
        # Treat tokens about to lapse as expired so they are refreshed early.
        return (expiry_timestamp - current_timestamp) < 60

    def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
        """Extend the base sanitisation with the ``allows_shared_content`` flag."""
        return super().sanitize_token_info(
            token_info,
            allows_shared_content=token_info.get("allows_shared_content", False),
            **extra_fields,
        )

    def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
        """Return True when the account is a work/school tenant that can access SharePoint shared content."""
        tid = id_token_claims.get("tid", "")
        return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID

    def map_token_response(self, result) -> Dict[str, Any]:
        """Convert an MSAL token result into the connector's token-info dict.

        ``token_uri`` and ``expiry`` are taken from the ID token claims
        (``iss`` and ``exp``); user name and e-mail come from the same claims.
        """
        # `or {}` also covers a result whose id_token_claims is present but None.
        claims = result.get("id_token_claims") or {}
        return {
            "access_token": result.get("access_token"),
            "refresh_token": result.get("refresh_token"),
            "token_uri": claims.get("iss"),
            "scopes": result.get("scope"),
            "expiry": claims.get("exp"),
            "allows_shared_content": self._allows_shared_content(claims),
            "user_info": {
                "name": claims.get("name"),
                "email": claims.get("preferred_username"),
            },
        }
|
||||
@@ -1,649 +0,0 @@
|
||||
"""
|
||||
SharePoint/OneDrive loader for DocsGPT.
|
||||
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
def _retry_on_auth_failure(func):
|
||||
"""Retry once after refreshing the access token on 401/403 responses."""
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code in (401, 403):
|
||||
logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info.get('access_token')
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(
|
||||
f"Authentication failed and could not be refreshed: {refresh_error}"
|
||||
) from e
|
||||
return func(self, *args, **kwargs)
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
|
||||
class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'application/pdf': '.pdf',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/msword': '.doc',
|
||||
'application/vnd.ms-powerpoint': '.ppt',
|
||||
'application/vnd.ms-excel': '.xls',
|
||||
'text/plain': '.txt',
|
||||
'text/csv': '.csv',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
'text/x-rst': '.rst',
|
||||
'application/json': '.json',
|
||||
'application/epub+zip': '.epub',
|
||||
'application/rtf': '.rtf',
|
||||
'image/jpeg': '.jpg',
|
||||
'image/png': '.png',
|
||||
}
|
||||
|
||||
EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}
|
||||
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
def __init__(self, session_token: str):
    """Create a loader bound to the tokens stored for ``session_token``.

    Args:
        session_token: Connector session identifier used to look up the
            stored token information.

    Raises:
        ValueError: If the session cannot be resolved or holds no access token.
    """
    self.auth = SharePointAuth()
    self.session_token = session_token

    token_info = self.auth.get_token_info_from_session(session_token)
    self.access_token = token_info.get('access_token')
    self.refresh_token = token_info.get('refresh_token')
    # Retain the expiry recorded with the session instead of discarding it,
    # so the token's remaining lifetime is known to this instance.
    self.token_expiry = token_info.get('expiry')
    self.allows_shared_content = token_info.get('allows_shared_content', False)

    if not self.access_token:
        raise ValueError("No access token found in session")

    self.next_page_token = None
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
'Authorization': f'Bearer {self.access_token}',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
def _ensure_valid_token(self):
|
||||
if not self.access_token:
|
||||
raise ValueError("No access token available")
|
||||
|
||||
token_info = {'access_token': self.access_token, 'expiry': None}
|
||||
if self.auth.is_token_expired(token_info):
|
||||
logging.info("Token expired, attempting refresh")
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info.get('access_token')
|
||||
except Exception:
|
||||
raise ValueError("Failed to refresh access token")
|
||||
|
||||
def _get_item_url(self, item_ref: str) -> str:
|
||||
if ':' in item_ref:
|
||||
drive_id, item_id = item_ref.split(':', 1)
|
||||
return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
|
||||
return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"
|
||||
|
||||
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
|
||||
try:
|
||||
drive_item_id = file_metadata.get('id')
|
||||
file_name = file_metadata.get('name', 'Unknown')
|
||||
file_data = file_metadata.get('file', {})
|
||||
mime_type = file_data.get('mimeType', 'application/octet-stream')
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES:
|
||||
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
|
||||
return None
|
||||
|
||||
doc_metadata = {
|
||||
'file_name': file_name,
|
||||
'mime_type': mime_type,
|
||||
'size': file_metadata.get('size'),
|
||||
'created_time': file_metadata.get('createdDateTime'),
|
||||
'modified_time': file_metadata.get('lastModifiedDateTime'),
|
||||
'source': 'share_point'
|
||||
}
|
||||
|
||||
if not load_content:
|
||||
return Document(
|
||||
text="",
|
||||
doc_id=drive_item_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
content = self._download_file_content(drive_item_id)
|
||||
if content is None:
|
||||
logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
|
||||
return None
|
||||
|
||||
return Document(
|
||||
text=content,
|
||||
doc_id=drive_item_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file: {e}")
|
||||
return None
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
    """Load documents according to ``inputs``.

    Recognised keys: ``file_ids`` (load exactly these items), ``shared``
    (list shared items), ``folder_id`` (list a folder, default ``'root'``),
    ``limit``, ``list_only``, ``page_token`` and ``search_query``.
    ``file_ids`` takes precedence over ``shared``, which takes precedence
    over folder listing. ``self.next_page_token`` is reset before loading.

    Raises:
        Exception: Any unexpected error is logged and re-raised.
    """
    try:
        search_query = inputs.get('search_query')
        load_content = not inputs.get('list_only', False)
        listing_options = {
            'limit': inputs.get('limit', 100),
            'load_content': load_content,
            'page_token': inputs.get('page_token'),
            'search_query': search_query,
        }
        file_ids = inputs.get('file_ids', [])
        folder_id = inputs.get('folder_id')
        wants_shared = inputs.get('shared', False)
        self.next_page_token = None

        documents: List[Document] = []
        if file_ids:
            for file_id in file_ids:
                try:
                    doc = self._load_file_by_id(file_id, load_content=load_content)
                    if not doc:
                        continue
                    # Case-insensitive name filter, applied only when a query is given.
                    if search_query and search_query.lower() not in doc.extra_info.get('file_name', '').lower():
                        continue
                    documents.append(doc)
                except Exception as e:
                    logging.error(f"Error loading file {file_id}: {e}")
                    continue
        elif wants_shared:
            if not self.allows_shared_content:
                logging.warning("Shared content is only available for work/school Microsoft accounts")
                return []
            documents = self._list_shared_items(**listing_options)
        else:
            documents = self._list_items_in_parent(folder_id or 'root', **listing_options)

        logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
        return documents

    except Exception as e:
        logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
        raise
|
||||
|
||||
@_retry_on_auth_failure
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
    """Fetch one drive item's metadata and convert it to a ``Document``.

    Args:
        file_id: Item reference accepted by ``_get_item_url``.
        load_content: Forwarded to ``_process_file``.

    Returns:
        The document, or None when the item is skipped or a non-HTTP error
        (including a timeout) occurs.

    Raises:
        requests.exceptions.HTTPError: Propagated so the retry decorator can
            react to authentication failures.
    """
    self._ensure_valid_token()

    try:
        url = self._get_item_url(file_id)
        params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
        # Without a timeout, requests waits indefinitely on a stalled connection.
        response = requests.get(url, headers=self._get_headers(), params=params, timeout=30)
        response.raise_for_status()

        file_metadata = response.json()
        return self._process_file(file_metadata, load_content=load_content)

    except requests.exceptions.HTTPError:
        raise
    except Exception as e:
        logging.error(f"Error loading file {file_id}: {e}")
        return None
|
||||
|
||||
@_retry_on_auth_failure
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
    """List the children of a folder, or search a drive, as ``Document`` objects.

    Folders are returned as empty-text documents flagged ``is_folder``; files
    go through ``_process_file``. ``self.next_page_token`` is set from the
    response's ``@odata.nextLink`` (or cleared when there is none).

    Args:
        parent_id: Folder reference accepted by ``_get_item_url``.
        limit: Maximum number of documents to return (page size is capped at 100).
        load_content: Forwarded to ``_process_file`` for files.
        page_token: Skip token of the page to fetch, if any.
        search_query: When given, a drive search replaces the children listing.

    Returns:
        The documents collected so far; on a non-auth error this may be a
        partial or empty list.

    Raises:
        requests.exceptions.HTTPError: Only for 401/403 responses, so the
            retry decorator can refresh the token and retry.
    """
    self._ensure_valid_token()

    documents: List[Document] = []

    try:
        params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
        if page_token:
            params['$skipToken'] = page_token

        if search_query:
            encoded_query = quote(search_query, safe='')
            if ':' in parent_id:
                drive_id = parent_id.split(':', 1)[0]
                request_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
            else:
                request_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
        else:
            request_url = f"{self._get_item_url(parent_id)}/children"

        # Without a timeout, requests waits indefinitely on a stalled connection.
        response = requests.get(request_url, headers=self._get_headers(), params=params, timeout=30)
        response.raise_for_status()

        results = response.json()

        items = results.get('value', [])
        for item in items:
            if 'folder' in item:
                doc_metadata = {
                    'file_name': item.get('name', 'Unknown'),
                    'mime_type': 'folder',
                    'size': item.get('size'),
                    'created_time': item.get('createdDateTime'),
                    'modified_time': item.get('lastModifiedDateTime'),
                    'source': 'share_point',
                    'is_folder': True
                }
                documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
            else:
                doc = self._process_file(item, load_content=load_content)
                if doc:
                    documents.append(doc)

            if limit and len(documents) >= limit:
                break

        self.next_page_token = None
        next_link = results.get('@odata.nextLink')
        if next_link:
            from urllib.parse import urlparse, parse_qs
            query_params = parse_qs(urlparse(next_link).query)
            skiptoken_list = query_params.get('$skiptoken')
            if skiptoken_list:
                self.next_page_token = skiptoken_list[0]
        return documents

    except requests.exceptions.HTTPError as e:
        # Auth failures must reach the retry decorator; the blanket handler
        # below used to swallow them, so the refresh-and-retry never ran.
        if e.response is not None and e.response.status_code in (401, 403):
            raise
        logging.error(f"Error listing items under parent {parent_id}: {e}")
        return documents
    except Exception as e:
        logging.error(f"Error listing items under parent {parent_id}: {e}")
        return documents
|
||||
|
||||
|
||||
|
||||
|
||||
def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
|
||||
"""Resolve mime type from resource, falling back to file extension."""
|
||||
file_data = resource.get('file', {})
|
||||
mime_type = file_data.get('mimeType') if file_data else None
|
||||
|
||||
if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
|
||||
return mime_type, True
|
||||
|
||||
name = resource.get('name', '')
|
||||
ext = os.path.splitext(name)[1].lower()
|
||||
if ext in self.EXTENSION_TO_MIME:
|
||||
return self.EXTENSION_TO_MIME[ext], True
|
||||
|
||||
return mime_type or 'application/octet-stream', False
|
||||
|
||||
def _get_user_drive_web_url(self) -> Optional[str]:
    """Fetch the current user's OneDrive web URL for KQL path exclusion.

    Returns:
        The drive's ``webUrl``, or None when the request fails for any
        reason (the failure is logged as a warning).
    """
    try:
        response = requests.get(
            f"{self.GRAPH_API_BASE}/me/drive",
            headers=self._get_headers(),
            params={'$select': 'webUrl'},
            # Without a timeout, requests waits indefinitely on a stalled connection.
            timeout=30,
        )
        response.raise_for_status()
        return response.json().get('webUrl')
    except Exception as e:
        # Best effort: callers cope with a missing URL.
        logging.warning(f"Could not fetch user drive web URL: {e}")
        return None
|
||||
|
||||
def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
|
||||
"""Build KQL query string that excludes the user's own drive items."""
|
||||
base_query = search_query if search_query else "*"
|
||||
if user_drive_url:
|
||||
return f'{base_query} AND -path:"{user_drive_url}"'
|
||||
return base_query
|
||||
|
||||
def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
    """Fetch shared drive items using Microsoft Graph Search API with local offset paging.

    A single search request asks Graph for up to 500 ``driveItem`` hits. The
    hits are de-duplicated and then paged locally: ``page_token`` is read as
    an integer offset into the de-duplicated list, which avoids relying on
    remote ``from``/``size`` semantics.

    Args:
        limit: Maximum number of hits consumed for this page; a falsy value
            consumes everything from the offset onwards.
        load_content: When true, each file's text is fetched through
            ``_download_file_content`` and stored on the document.
        page_token: Stringified integer offset. Missing, invalid or negative
            values are treated as ``0``.
        search_query: Optional user query; ``*`` is used when it is empty.

    Returns:
        Documents for the folders and supported files inside the requested
        window. Unsupported files are skipped, so the list may hold fewer
        than ``limit`` entries. On any error the documents gathered so far
        are returned and the error is logged.

    Side effects:
        Sets ``self.next_page_token`` to the next offset (as a string), or to
        ``None`` when the window reaches the end of the hits. It is not
        updated when an exception interrupts processing.
    """
    self._ensure_valid_token()
    documents: List[Document] = []

    try:
        user_drive_url = self._get_user_drive_web_url()
        query_text = self._build_shared_kql_query(search_query, user_drive_url)

        url = f"{self.GRAPH_API_BASE}/search/query"
        page_size = 500  # maximum number of hits we care about for selection

        body = {
            "requests": [
                {
                    "entityTypes": ["driveItem"],
                    "query": {"queryString": query_text},
                    "from": 0,
                    "size": page_size,
                }
            ]
        }

        headers = self._get_headers()
        headers["Content-Type"] = "application/json"
        # requests waits forever by default; bound the wait so a stalled
        # search cannot hang the listing indefinitely.
        response = requests.post(url, headers=headers, json=body, timeout=30)
        response.raise_for_status()
        results = response.json()

        search_response = results.get("value", [])
        if not search_response:
            logging.warning("Search API returned empty value array")
            self.next_page_token = None
            return documents

        hits_containers = search_response[0].get("hitsContainers", [])
        if not hits_containers:
            logging.warning("Search API returned no hitsContainers")
            self.next_page_token = None
            return documents

        container = hits_containers[0]
        total = container.get("total", 0)
        raw_hits = container.get("hits", [])

        # Deduplicate by effective item ID (driveId:itemId) to avoid the same
        # resource appearing multiple times across the result set.
        deduped_hits = []
        seen_ids = set()
        for hit in raw_hits:
            resource = hit.get("resource", {})
            item_id = resource.get("id")
            drive_id = resource.get("parentReference", {}).get("driveId")
            effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
            if not effective_id or effective_id in seen_ids:
                continue
            seen_ids.add(effective_id)
            deduped_hits.append(hit)

        hits = deduped_hits
        logging.info(
            f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
        )

        # The page token is a plain integer offset into the de-duplicated hits.
        try:
            offset = int(page_token) if page_token is not None else 0
        except (TypeError, ValueError):
            logging.warning(
                f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
            )
            offset = 0

        if offset < 0:
            offset = 0
        if offset >= len(hits):
            self.next_page_token = None
            return documents

        end_index = offset + limit if limit else len(hits)
        end_index = min(end_index, len(hits))

        for hit in hits[offset:end_index]:
            resource = hit.get("resource", {})
            item_name = resource.get("name", "Unknown")
            item_id = resource.get("id")
            drive_id = resource.get("parentReference", {}).get("driveId")

            effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id

            is_folder = "folder" in resource

            if is_folder:
                doc_metadata = {
                    "file_name": item_name,
                    "mime_type": "folder",
                    "size": resource.get("size"),
                    "created_time": resource.get("createdDateTime"),
                    "modified_time": resource.get("lastModifiedDateTime"),
                    "source": "share_point",
                    "is_folder": True,
                }
                documents.append(
                    Document(text="", doc_id=effective_id, extra_info=doc_metadata)
                )
            else:
                mime_type, supported = self._resolve_mime_type(resource)
                if not supported:
                    logging.info(
                        f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
                    )
                    continue

                doc_metadata = {
                    "file_name": item_name,
                    "mime_type": mime_type,
                    "size": resource.get("size"),
                    "created_time": resource.get("createdDateTime"),
                    "modified_time": resource.get("lastModifiedDateTime"),
                    "source": "share_point",
                }

                content = ""
                if load_content:
                    content = self._download_file_content(effective_id) or ""

                documents.append(
                    Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
                )

        # Advertise a next page only while hits remain beyond this window.
        if limit and end_index < len(hits):
            self.next_page_token = str(end_index)
        else:
            self.next_page_token = None

        return documents

    except Exception as e:
        logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
        return documents
|
||||
|
||||
@_retry_on_auth_failure
def _download_file_content(self, file_id: str) -> Optional[str]:
    """Download a drive item's content and return it decoded as UTF-8 text.

    Args:
        file_id: Item identifier accepted by ``self._get_item_url``.

    Returns:
        The decoded text, or ``None`` when the bytes are not valid UTF-8 or a
        non-HTTP error occurs (both cases are logged).

    Raises:
        requests.exceptions.HTTPError: Re-raised unchanged instead of being
            swallowed. NOTE(review): presumably so the
            ``_retry_on_auth_failure`` decorator can react to it; the
            decorator is defined outside this view, so confirm there.
    """
    self._ensure_valid_token()

    try:
        url = f"{self._get_item_url(file_id)}/content"
        # requests waits forever by default; bound the wait so a stalled
        # download cannot hang the caller indefinitely.
        response = requests.get(url, headers=self._get_headers(), timeout=30)
        response.raise_for_status()

        try:
            return response.content.decode('utf-8')
        except UnicodeDecodeError:
            logging.error(f"Could not decode file {file_id} as text")
            return None

    except requests.exceptions.HTTPError:
        raise
    except Exception as e:
        logging.error(f"Error downloading file {file_id}: {e}")
        return None
|
||||
|
||||
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
    """Download one drive item into ``local_dir`` under its remote name.

    The item's metadata is fetched first to learn its name and type; the
    content is downloaded only when the type is supported.

    Args:
        file_id: Item identifier accepted by ``self._get_item_url``.
        local_dir: Target directory; created if it does not exist.

    Returns:
        ``True`` when the file was written, ``False`` when the type is
        unsupported or any error occurred (errors are logged).
    """
    try:
        url = self._get_item_url(file_id)
        params = {'$select': 'id,name,file'}
        # requests waits forever by default; bound the wait on both calls.
        response = requests.get(url, headers=self._get_headers(), params=params, timeout=30)
        response.raise_for_status()

        metadata = response.json()
        file_name = metadata.get('name', 'unknown')

        # Use the same resolution as the listing code (reported MIME type
        # first, then file extension) so that an item shown as supported when
        # listed is not silently skipped when it is downloaded.
        mime_type, supported = self._resolve_mime_type(metadata)
        if not supported:
            logging.info(f"Skipping unsupported file type: {mime_type}")
            return False

        os.makedirs(local_dir, exist_ok=True)
        full_path = os.path.join(local_dir, file_name)

        download_url = f"{self._get_item_url(file_id)}/content"
        download_response = requests.get(download_url, headers=self._get_headers(), timeout=30)
        download_response.raise_for_status()

        with open(full_path, 'wb') as f:
            f.write(download_response.content)

        return True
    except Exception as e:
        logging.error(f"Error in _download_single_file: {e}")
        return False
|
||||
|
||||
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
    """Download the files of a drive folder into ``local_dir``.

    The folder's children are listed page by page. Files are downloaded with
    ``_download_single_file``; subfolders are mirrored as local
    subdirectories and descended into when ``recursive`` is true, and ignored
    otherwise.

    Args:
        folder_id: Folder identifier accepted by ``self._get_item_url``.
        local_dir: Target directory; created if it does not exist.
        recursive: Whether to descend into subfolders.

    Returns:
        Number of files downloaded. If an error interrupts the listing, the
        count reached so far is returned and the error is logged.
    """
    files_downloaded = 0
    try:
        os.makedirs(local_dir, exist_ok=True)

        url = f"{self._get_item_url(folder_id)}/children"
        params = {'$top': 1000}

        while url:
            # requests waits forever by default; bound the wait per page.
            response = requests.get(url, headers=self._get_headers(), params=params, timeout=30)
            response.raise_for_status()

            results = response.json()
            items = results.get('value', [])
            logging.info(f"Found {len(items)} items in folder {folder_id}")

            for item in items:
                item_name = item.get('name', 'unknown')
                item_id = item.get('id')

                if 'folder' in item:
                    if recursive:
                        subfolder_path = os.path.join(local_dir, item_name)
                        os.makedirs(subfolder_path, exist_ok=True)
                        subfolder_files = self._download_folder_recursive(
                            item_id,
                            subfolder_path,
                            recursive
                        )
                        files_downloaded += subfolder_files
                        logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
                else:
                    success = self._download_single_file(item_id, local_dir)
                    if success:
                        files_downloaded += 1
                        logging.info(f"Downloaded file: {item_name}")
                    else:
                        logging.warning(f"Failed to download file: {item_name}")

            url = results.get('@odata.nextLink')
            # The nextLink is a complete URL that already carries its own
            # query string; it must be requested as-is, so stop sending the
            # initial paging parameters after the first page.
            params = None

        return files_downloaded

    except Exception as e:
        logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
        return files_downloaded
|
||||
|
||||
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
try:
|
||||
self._ensure_valid_token()
|
||||
return self._download_folder_recursive(folder_id, local_dir, recursive)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
self._ensure_valid_token()
|
||||
return self._download_single_file(file_id, local_dir)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Download the configured files and folders into ``local_dir``.

    Args:
        local_dir: Target directory for the downloads.
        source_config: Download configuration. When it is ``None`` or empty,
            ``self.config`` is used instead (``{}`` if that attribute is
            missing). Recognised keys:

            * ``file_ids`` - one id or a list of ids of files to download.
            * ``folder_ids`` - one id or a list of ids of folders to download;
              each folder is placed in a subdirectory named after it.
            * ``recursive`` - whether to descend into subfolders
              (default ``True``).

    Returns:
        A summary dictionary with ``files_downloaded``, ``directory_path``,
        ``empty_result``, ``source_type`` and ``config_used``. When the run
        fails - including when neither ``file_ids`` nor ``folder_ids`` is
        given - ``empty_result`` is ``True`` and an ``error`` message is
        included. A failure inside one folder is logged and does not abort
        the remaining folders.
    """
    config = source_config or getattr(self, 'config', {})
    files_downloaded = 0

    try:
        folder_ids = config.get('folder_ids', [])
        file_ids = config.get('file_ids', [])
        recursive = config.get('recursive', True)

        if file_ids:
            if isinstance(file_ids, str):
                file_ids = [file_ids]

            for file_id in file_ids:
                if self._download_file_to_directory(file_id, local_dir):
                    files_downloaded += 1

        if folder_ids:
            if isinstance(folder_ids, str):
                folder_ids = [folder_ids]

            for folder_id in folder_ids:
                try:
                    # The file branch validates the token through
                    # _download_file_to_directory; do the same before the
                    # folder requests, which are otherwise sent unchecked.
                    self._ensure_valid_token()

                    url = self._get_item_url(folder_id)
                    params = {'$select': 'id,name'}
                    # requests waits forever by default; bound the wait.
                    response = requests.get(url, headers=self._get_headers(), params=params, timeout=30)
                    response.raise_for_status()

                    folder_metadata = response.json()
                    folder_name = folder_metadata.get('name', '')
                    folder_path = os.path.join(local_dir, folder_name)
                    os.makedirs(folder_path, exist_ok=True)

                    folder_files = self._download_folder_recursive(
                        folder_id,
                        folder_path,
                        recursive
                    )
                    files_downloaded += folder_files
                    logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
                except Exception as e:
                    logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)

        if not file_ids and not folder_ids:
            raise ValueError("No folder_ids or file_ids provided for download")

        return {
            "files_downloaded": files_downloaded,
            "directory_path": local_dir,
            "empty_result": files_downloaded == 0,
            "source_type": "share_point",
            "config_used": config
        }

    except Exception as e:
        # Report the failure in the summary instead of raising to the caller.
        return {
            "files_downloaded": files_downloaded,
            "directory_path": local_dir,
            "empty_result": True,
            "source_type": "share_point",
            "config_used": config,
            "error": str(e)
        }
|
||||
@@ -65,10 +65,6 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
if not os.path.exists(folder_name):
|
||||
os.makedirs(folder_name)
|
||||
|
||||
# Validate docs is not empty
|
||||
if not docs:
|
||||
raise ValueError("No documents to embed - check file format and extension")
|
||||
|
||||
# Initialize vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
docs_init = [docs.pop(0)]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
|
||||
@@ -10,97 +10,29 @@ from application.parser.file.epub_parser import EpubParser
|
||||
from application.parser.file.html_parser import HTMLParser
|
||||
from application.parser.file.markdown_parser import MarkdownParser
|
||||
from application.parser.file.rst_parser import RstParser
|
||||
from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
|
||||
from application.parser.file.tabular_parser import PandasCSVParser,ExcelParser
|
||||
from application.parser.file.json_parser import JSONParser
|
||||
from application.parser.file.pptx_parser import PPTXParser
|
||||
from application.parser.file.image_parser import ImageParser
|
||||
from application.parser.schema.base import Document
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def get_default_file_extractor(
    ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
    """Build the default mapping from file extension to parser instance.

    The docling-based parsers are preferred. If importing them raises
    ``ImportError``, a warning is logged and the standard parsers are
    returned instead.

    Args:
        ocr_enabled: Whether OCR should be used for PDFs and images. ``None``
            means "use ``settings.DOCLING_OCR_ENABLED``".

    Returns:
        A new dictionary keyed by lower-case extension (with leading dot).
    """
    try:
        from application.parser.file.docling_parser import (
            DoclingPDFParser,
            DoclingDocxParser,
            DoclingPPTXParser,
            DoclingXLSXParser,
            DoclingHTMLParser,
            DoclingImageParser,
            DoclingCSVParser,
            DoclingAsciiDocParser,
            DoclingVTTParser,
            DoclingXMLParser,
        )

        use_ocr = settings.DOCLING_OCR_ENABLED if ocr_enabled is None else ocr_enabled

        extractors: Dict[str, BaseParser] = {
            # Documents
            ".pdf": DoclingPDFParser(ocr_enabled=use_ocr),
            ".docx": DoclingDocxParser(),
            ".pptx": DoclingPPTXParser(),
            ".xlsx": DoclingXLSXParser(),
            # Web formats
            ".html": DoclingHTMLParser(),
            ".xhtml": DoclingHTMLParser(),
            # Data formats (JSON keeps its specialised parser)
            ".csv": DoclingCSVParser(),
            ".json": JSONParser(),
            # Text/markup formats (markdown keeps its specialised parser)
            ".md": MarkdownParser(),
            ".mdx": MarkdownParser(),
            ".rst": RstParser(),
            ".adoc": DoclingAsciiDocParser(),
            ".asciidoc": DoclingAsciiDocParser(),
        }

        # Images: docling is only used when OCR is enabled; every extension
        # receives its own parser instance.
        for image_ext in (".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", ".webp"):
            if use_ocr:
                extractors[image_ext] = DoclingImageParser(ocr_enabled=use_ocr)
            else:
                extractors[image_ext] = ImageParser()

        # Media/subtitles
        extractors[".vtt"] = DoclingVTTParser()
        # Specialized XML formats
        extractors[".xml"] = DoclingXMLParser()
        # Formats docling doesn't support - use standard parsers
        extractors[".epub"] = EpubParser()
        return extractors

    except ImportError:
        logging.warning(
            "docling is not installed. Using standard parsers. "
            "For advanced document parsing, install with: pip install docling"
        )
        # Fallback to standard parsers
        fallback: Dict[str, BaseParser] = {
            ".pdf": PDFParser(),
            ".docx": DocxParser(),
            ".csv": PandasCSVParser(),
            ".xlsx": ExcelParser(),
            ".epub": EpubParser(),
            ".md": MarkdownParser(),
            ".rst": RstParser(),
            ".html": HTMLParser(),
            ".mdx": MarkdownParser(),
            ".json": JSONParser(),
            ".pptx": PPTXParser(),
            ".png": ImageParser(),
            ".jpg": ImageParser(),
            ".jpeg": ImageParser(),
        }
        return fallback
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = get_default_file_extractor()
|
||||
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
|
||||
".pdf": PDFParser(),
|
||||
".docx": DocxParser(),
|
||||
".csv": PandasCSVParser(),
|
||||
".xlsx":ExcelParser(),
|
||||
".epub": EpubParser(),
|
||||
".md": MarkdownParser(),
|
||||
".rst": RstParser(),
|
||||
".html": HTMLParser(),
|
||||
".mdx": MarkdownParser(),
|
||||
".json":JSONParser(),
|
||||
".pptx":PPTXParser(),
|
||||
".png": ImageParser(),
|
||||
".jpg": ImageParser(),
|
||||
".jpeg": ImageParser(),
|
||||
}
|
||||
|
||||
|
||||
class SimpleDirectoryReader(BaseReader):
|
||||
@@ -151,10 +83,7 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
self.recursive = recursive
|
||||
self.exclude_hidden = exclude_hidden
|
||||
# Normalize extensions to lowercase for case-insensitive matching
|
||||
self.required_exts = (
|
||||
[ext.lower() for ext in required_exts] if required_exts else None
|
||||
)
|
||||
self.required_exts = required_exts
|
||||
self.num_files_limit = num_files_limit
|
||||
|
||||
if input_files:
|
||||
@@ -183,7 +112,7 @@ class SimpleDirectoryReader(BaseReader):
|
||||
continue
|
||||
elif (
|
||||
self.required_exts is not None
|
||||
and input_file.suffix.lower() not in self.required_exts
|
||||
and input_file.suffix not in self.required_exts
|
||||
):
|
||||
continue
|
||||
else:
|
||||
@@ -220,9 +149,8 @@ class SimpleDirectoryReader(BaseReader):
|
||||
self.file_token_counts = {}
|
||||
|
||||
for input_file in self.input_files:
|
||||
suffix_lower = input_file.suffix.lower()
|
||||
if suffix_lower in self.file_extractor:
|
||||
parser = self.file_extractor[suffix_lower]
|
||||
if input_file.suffix in self.file_extractor:
|
||||
parser = self.file_extractor[input_file.suffix]
|
||||
if not parser.parser_config_set:
|
||||
parser.init_parser()
|
||||
data = parser.parse_file(input_file, errors=self.errors)
|
||||
@@ -304,7 +232,7 @@ class SimpleDirectoryReader(BaseReader):
|
||||
if subtree:
|
||||
result[item.name] = subtree
|
||||
else:
|
||||
if self.required_exts is not None and item.suffix.lower() not in self.required_exts:
|
||||
if self.required_exts is not None and item.suffix not in self.required_exts:
|
||||
continue
|
||||
|
||||
full_path = str(item.resolve())
|
||||
@@ -323,4 +251,4 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
return result
|
||||
|
||||
return build_tree(Path(base_path))
|
||||
return build_tree(Path(base_path))
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Docling parser.
|
||||
|
||||
Uses docling library for advanced document parsing with layout detection,
|
||||
table structure recognition, and unified document representation.
|
||||
|
||||
Supports: PDF, DOCX, PPTX, XLSX, HTML, XHTML, CSV, Markdown, AsciiDoc,
|
||||
images (PNG, JPEG, TIFF, BMP, WEBP), WebVTT, and specialized XML formats.
|
||||
"""
|
||||
import importlib.util
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from application.parser.file.base_parser import BaseParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DoclingParser(BaseParser):
    """General-purpose file parser built on the docling library.

    The parser lazily builds a docling ``DocumentConverter`` configured for
    PDF and image inputs, converts a file, and exports the resulting document
    as markdown, HTML or text.

    OCR is controlled by three options: ``ocr_enabled`` switches it on or
    off, ``use_rapidocr`` requests docling's RapidOCR options, and
    ``force_full_page_ocr`` is forwarded to those options (``False`` is the
    hybrid mode in which only bitmap regions are meant to be OCR'd, ``True``
    is meant for fully scanned pages).
    """

    def __init__(
        self,
        ocr_enabled: bool = True,
        table_structure: bool = True,
        export_format: str = "markdown",
        use_rapidocr: bool = True,
        ocr_languages: Optional[List[str]] = None,
        force_full_page_ocr: bool = False,
    ):
        """Store the parsing options; the converter is created on first use.

        Args:
            ocr_enabled: Enable OCR in the docling pipeline.
            table_structure: Enable table structure recognition.
            export_format: ``'markdown'`` or ``'html'``; any other value
                selects the plain-text export.
            use_rapidocr: Configure the pipeline with RapidOCR options.
            ocr_languages: OCR languages; ``['english']`` when omitted/empty.
            force_full_page_ocr: Forwarded to the RapidOCR options.
        """
        super().__init__()
        self.ocr_enabled = ocr_enabled
        self.table_structure = table_structure
        self.export_format = export_format
        self.use_rapidocr = use_rapidocr
        self.ocr_languages = ocr_languages or ["english"]
        self.force_full_page_ocr = force_full_page_ocr
        self._converter = None

    def _create_converter(self):
        """Build a ``DocumentConverter`` that applies one pipeline to PDFs and images.

        The pipeline enables OCR and table-structure recognition according to
        the instance options. When OCR is enabled and ``_get_ocr_options``
        yields options, they are attached to the pipeline.

        Returns:
            DocumentConverter instance
        """
        from docling.document_converter import (
            DocumentConverter,
            ImageFormatOption,
            InputFormat,
            PdfFormatOption,
        )
        from docling.datamodel.pipeline_options import PdfPipelineOptions

        pipeline = PdfPipelineOptions(
            do_ocr=self.ocr_enabled,
            do_table_structure=self.table_structure,
        )

        if self.ocr_enabled:
            rapid_options = self._get_ocr_options()
            if rapid_options is not None:
                pipeline.ocr_options = rapid_options

        per_format = {
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline),
            InputFormat.IMAGE: ImageFormatOption(pipeline_options=pipeline),
        }
        return DocumentConverter(format_options=per_format)

    def _init_parser(self) -> Dict:
        """Create the converter and return the active configuration.

        Raises:
            ImportError: If ``docling.document_converter`` cannot be located.
        """
        logger.info("Initializing DoclingParser...")
        for option in ("ocr_enabled", "force_full_page_ocr", "use_rapidocr"):
            logger.info(f" {option}={getattr(self, option)}")

        if importlib.util.find_spec("docling.document_converter") is None:
            raise ImportError(
                "docling is required for DoclingParser. "
                "Install it with: pip install docling"
            )

        self._converter = self._create_converter()
        logger.info("DoclingParser initialized successfully")

        config = {
            "ocr_enabled": self.ocr_enabled,
            "table_structure": self.table_structure,
            "export_format": self.export_format,
            "use_rapidocr": self.use_rapidocr,
            "ocr_languages": self.ocr_languages,
            "force_full_page_ocr": self.force_full_page_ocr,
        }
        return config

    def _get_ocr_options(self):
        """Return RapidOCR options for the pipeline, or ``None``.

        ``None`` is returned when RapidOCR is not requested, when
        ``RapidOcrOptions`` cannot be imported, or when building the options
        fails; the pipeline then keeps docling's default OCR options.
        """
        if not self.use_rapidocr:
            return None

        try:
            from docling.datamodel.pipeline_options import RapidOcrOptions

            options = RapidOcrOptions(
                lang=self.ocr_languages,
                force_full_page_ocr=self.force_full_page_ocr,
            )
        except ImportError as e:
            logger.warning(f"Failed to import RapidOcrOptions: {e}")
            return None
        except Exception as e:
            logger.error(f"Error creating RapidOcrOptions: {e}")
            return None
        return options

    def _export_content(self, document) -> str:
        """Export ``document`` in the configured format.

        If the export is nearly empty (fewer than 50 characters after
        stripping, e.g. just ``<!-- image -->``) while ``document.texts``
        holds text items - which happens when text is nested under picture
        elements - the non-empty item texts are joined with blank lines and
        returned instead.
        """
        if self.export_format == "markdown":
            exported = document.export_to_markdown()
        elif self.export_format == "html":
            exported = document.export_to_html()
        else:
            exported = document.export_to_text()

        trimmed = exported.strip()
        looks_empty = len(trimmed) < 50 or trimmed == "<!-- image -->"
        if not looks_empty or not getattr(document, "texts", None):
            return exported

        # Pull the text straight from the document's text items.
        pieces = [item.text for item in document.texts if item.text]
        if not pieces:
            return exported

        logger.info(
            f"Standard export minimal ({len(trimmed)} chars), "
            f"extracting {len(pieces)} texts directly"
        )
        return "\n\n".join(pieces)

    def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
        """Convert ``file`` with docling and return the exported content.

        Args:
            file: Path to the file to parse
            errors: ``"ignore"`` returns a bracketed error message instead of
                raising when conversion fails; any other value re-raises.

        Returns:
            The exported document content as a string.
        """
        logger.info(f"parse_file called for: {file}")

        # Build the converter lazily on first use.
        if self._converter is None:
            self._init_parser()

        try:
            logger.info(f"Converting file with hybrid OCR: {file}")
            conversion = self._converter.convert(str(file))
            content = self._export_content(conversion.document)
            logger.info(f"Parse complete, content length: {len(content)} chars")
            return content
        except Exception as e:
            logger.error(f"Error parsing file with docling: {e}", exc_info=True)
            if errors != "ignore":
                raise
            return f"[Error parsing file with docling: {str(e)}]"
|
||||
|
||||
|
||||
class DoclingPDFParser(DoclingParser):
    """PDF parser backed by docling, exporting markdown.

    All OCR and table options are forwarded to ``DoclingParser``; only the
    export format is fixed. ``force_full_page_ocr`` defaults to ``False``
    (hybrid OCR) and should be set to ``True`` only for fully scanned
    documents.
    """

    def __init__(
        self,
        ocr_enabled: bool = True,
        table_structure: bool = True,
        use_rapidocr: bool = True,
        ocr_languages: Optional[List[str]] = None,
        force_full_page_ocr: bool = False,
    ):
        """Forward the options to ``DoclingParser`` with markdown export."""
        super().__init__(
            export_format="markdown",
            ocr_enabled=ocr_enabled,
            table_structure=table_structure,
            use_rapidocr=use_rapidocr,
            ocr_languages=ocr_languages,
            force_full_page_ocr=force_full_page_ocr,
        )
|
||||
|
||||
|
||||
class DoclingDocxParser(DoclingParser):
    """DOCX parser backed by docling; content is exported as markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingPPTXParser(DoclingParser):
    """PPTX parser backed by docling; content is exported as markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingXLSXParser(DoclingParser):
    """XLSX parser backed by docling, with table structure recognition on."""

    def __init__(self):
        # Table structure is requested explicitly; output is markdown.
        super().__init__(export_format="markdown", table_structure=True)
|
||||
|
||||
|
||||
class DoclingHTMLParser(DoclingParser):
    """HTML parser backed by docling; content is exported as markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingImageParser(DoclingParser):
    """Image parser backed by docling, exporting markdown.

    Unlike the base class, ``force_full_page_ocr`` defaults to ``True`` here:
    an image is entirely visual, so any text in it has to come from OCR over
    the whole page.
    """

    def __init__(
        self,
        ocr_enabled: bool = True,
        use_rapidocr: bool = True,
        ocr_languages: Optional[List[str]] = None,
        force_full_page_ocr: bool = True,
    ):
        """Forward the OCR options to ``DoclingParser`` with markdown export."""
        super().__init__(
            export_format="markdown",
            ocr_enabled=ocr_enabled,
            use_rapidocr=use_rapidocr,
            ocr_languages=ocr_languages,
            force_full_page_ocr=force_full_page_ocr,
        )
|
||||
|
||||
|
||||
class DoclingCSVParser(DoclingParser):
    """CSV parser backed by docling, with table structure recognition on."""

    def __init__(self):
        # Table structure is requested explicitly; output is markdown.
        super().__init__(export_format="markdown", table_structure=True)
|
||||
|
||||
|
||||
class DoclingMarkdownParser(DoclingParser):
    """Markdown parser backed by docling; content is exported as markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingAsciiDocParser(DoclingParser):
    """AsciiDoc parser backed by docling; content is exported as markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingVTTParser(DoclingParser):
    """WebVTT (video text tracks) parser backed by docling; exports markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
|
||||
|
||||
class DoclingXMLParser(DoclingParser):
    """XML parser (USPTO, JATS) backed by docling; exports markdown."""

    def __init__(self):
        # Only the export format is pinned; every other option keeps the
        # DoclingParser default.
        super().__init__(export_format="markdown")
|
||||
@@ -7,8 +7,8 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from application.parser.file.base_parser import BaseParser
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class MarkdownParser(BaseParser):
|
||||
@@ -38,7 +38,7 @@ class MarkdownParser(BaseParser):
|
||||
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str],
|
||||
current_text: str):
|
||||
"""Append to tups chunk."""
|
||||
num_tokens = num_tokens_from_string(current_text)
|
||||
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
|
||||
if num_tokens > self._max_tokens:
|
||||
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
@@ -18,12 +16,9 @@ class CrawlerLoader(BaseRemote):
|
||||
if isinstance(url, list) and url:
|
||||
url = url[0]
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
logging.error(f"URL validation failed: {e}")
|
||||
return []
|
||||
# Check if the URL scheme is provided, if not, assume http
|
||||
if not urlparse(url).scheme:
|
||||
url = "http://" + url
|
||||
|
||||
visited_urls = set()
|
||||
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
|
||||
@@ -35,26 +30,16 @@ class CrawlerLoader(BaseRemote):
|
||||
visited_urls.add(current_url)
|
||||
|
||||
try:
|
||||
# Validate each URL before making requests
|
||||
try:
|
||||
validate_url(current_url)
|
||||
except SSRFError as e:
|
||||
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
|
||||
continue
|
||||
|
||||
response = requests.get(current_url, timeout=30)
|
||||
response = requests.get(current_url)
|
||||
response.raise_for_status()
|
||||
loader = self.loader([current_url])
|
||||
docs = loader.load()
|
||||
# Convert the loaded documents to your Document schema
|
||||
for doc in docs:
|
||||
metadata = dict(doc.metadata or {})
|
||||
source_url = metadata.get("source") or current_url
|
||||
metadata["file_path"] = self._url_to_virtual_path(source_url)
|
||||
loaded_content.append(
|
||||
Document(
|
||||
doc.page_content,
|
||||
extra_info=metadata
|
||||
extra_info=doc.metadata
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -78,29 +63,3 @@ class CrawlerLoader(BaseRemote):
|
||||
break
|
||||
|
||||
return loaded_content
|
||||
|
||||
def _url_to_virtual_path(self, url):
    """
    Map a page URL onto a virtual, file-like path that ends in ``.md``.

    Only the path component of the URL is used; leading and trailing
    slashes are dropped, a known web-page extension is removed, and
    ``.md`` is appended unless the path already ends with it.

    Examples:
        https://docs.docsgpt.cloud/ -> index.md
        https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
        https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
        https://example.com/page.html -> page.md
    """
    web_suffixes = (".html", ".htm", ".php", ".asp", ".aspx", ".jsp")

    trimmed = urlparse(url).path.strip("/")
    if trimmed == "":
        # Site root has no path of its own.
        return "index.md"

    # Swap a web-page extension for the markdown one.
    stem, suffix = os.path.splitext(trimmed)
    if suffix.lower() in web_suffixes:
        trimmed = stem

    return trimmed if trimmed.endswith(".md") else f"{trimmed}.md"
|
||||
|
||||
@@ -2,12 +2,10 @@ import requests
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
import re
|
||||
from markdownify import markdownify
|
||||
from application.parser.schema.base import Document
|
||||
import tldextract
|
||||
import os
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
def __init__(self, limit=10, allow_subdomains=False):
|
||||
@@ -27,12 +25,9 @@ class CrawlerLoader(BaseRemote):
|
||||
if isinstance(url, list) and url:
|
||||
url = url[0]
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
print(f"URL validation failed: {e}")
|
||||
return []
|
||||
# Ensure the URL has a scheme (if not, default to http)
|
||||
if not urlparse(url).scheme:
|
||||
url = "http://" + url
|
||||
|
||||
# Keep track of visited URLs to avoid revisiting the same page
|
||||
visited_urls = set()
|
||||
@@ -58,21 +53,13 @@ class CrawlerLoader(BaseRemote):
|
||||
# Convert the HTML to Markdown for cleaner text formatting
|
||||
title, language, processed_markdown = self._process_html_to_markdown(html_content, current_url)
|
||||
if processed_markdown:
|
||||
# Generate virtual file path from URL for consistent file-like matching
|
||||
virtual_path = self._url_to_virtual_path(current_url)
|
||||
|
||||
# Create a Document for each visited page
|
||||
documents.append(
|
||||
Document(
|
||||
processed_markdown, # content
|
||||
None, # doc_id
|
||||
None, # embedding
|
||||
{
|
||||
"source": current_url,
|
||||
"title": title,
|
||||
"language": language,
|
||||
"file_path": virtual_path,
|
||||
}, # extra_info
|
||||
{"source": current_url, "title": title, "language": language} # extra_info
|
||||
)
|
||||
)
|
||||
|
||||
@@ -91,14 +78,9 @@ class CrawlerLoader(BaseRemote):
|
||||
|
||||
def _fetch_page(self, url):
    """
    Download a page and return its body text.

    The URL is passed through ``validate_url`` before the request is made.
    Validation failures and request errors are printed and reported to the
    caller as ``None``.
    """
    page_text = None
    try:
        # Validate first so a rejected URL is never requested.
        validate_url(url)
        reply = self.session.get(url, timeout=10)
        reply.raise_for_status()
        page_text = reply.text
    except SSRFError as e:
        print(f"URL validation failed for {url}: {e}")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching URL {url}: {e}")
    return page_text
|
||||
@@ -154,31 +136,4 @@ class CrawlerLoader(BaseRemote):
|
||||
# Exact domain match
|
||||
if link_base == base_domain:
|
||||
filtered.append(link)
|
||||
return filtered
|
||||
|
||||
def _url_to_virtual_path(self, url):
    """
    Derive a virtual ``.md`` file path from the path component of a URL.

    Examples:
        https://docs.docsgpt.cloud/ -> index.md
        https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
        https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
        https://example.com/page.html -> page.md
    """
    relative = urlparse(url).path.strip("/")

    if relative:
        # Drop a trailing web-page extension before adding the markdown one.
        root, extension = os.path.splitext(relative)
        if extension.lower() in (".html", ".htm", ".php", ".asp", ".aspx", ".jsp"):
            relative = root

        # Paths that already end in .md are kept as they are.
        if not relative.endswith(".md"):
            relative += ".md"
        return relative

    # An empty path means the site root.
    return "index.md"
|
||||
return filtered
|
||||
@@ -3,7 +3,6 @@ from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.remote.web_loader import WebLoader
|
||||
from application.parser.remote.reddit_loader import RedditPostsLoaderRemote
|
||||
from application.parser.remote.github_loader import GitHubLoader
|
||||
from application.parser.remote.s3_loader import S3Loader
|
||||
|
||||
|
||||
class RemoteCreator:
|
||||
@@ -23,7 +22,6 @@ class RemoteCreator:
|
||||
"crawler": CrawlerLoader,
|
||||
"reddit": RedditPostsLoaderRemote,
|
||||
"github": GitHubLoader,
|
||||
"s3": S3Loader,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,427 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import mimetypes
|
||||
from typing import List, Optional
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S3Loader(BaseRemote):
|
||||
"""Load documents from an AWS S3 bucket."""
|
||||
|
||||
def __init__(self):
    """Create a loader with no S3 client yet; fail early if boto3 is missing."""
    boto3_missing = boto3 is None
    if boto3_missing:
        raise ImportError(
            "boto3 is required for S3Loader. Install it with: pip install boto3"
        )
    # The client is created later, once credentials are known.
    self.s3_client = None
|
||||
|
||||
def _normalize_endpoint_url(self, endpoint_url: str, bucket: str) -> tuple[str, str]:
    """
    Normalize endpoint URL for S3-compatible services.

    Detects a bucket-prefixed DigitalOcean Spaces URL
    (``<bucket>.<region>.digitaloceanspaces.com``) and splits it into the
    regional endpoint and the bucket name. Matching is done on the host
    name only: a missing scheme, credentials, a port, or a trailing path
    do not prevent detection, and the region/domain are matched
    case-insensitively. Any other endpoint is returned untouched.

    Args:
        endpoint_url: The provided endpoint URL
        bucket: The provided bucket name

    Returns:
        Tuple of (normalized_endpoint_url, bucket_name)
    """
    import re
    from urllib.parse import urlparse

    if not endpoint_url:
        return endpoint_url, bucket

    netloc = urlparse(endpoint_url).netloc
    if not netloc:
        # Without a scheme, urlparse puts the host into `path` (or mistakes
        # "host:port" for "scheme:path"); re-parse as a network location.
        try:
            netloc = urlparse(f"//{endpoint_url}").netloc
        except ValueError:
            netloc = ""

    # Drop "user:pass@" and ":port" so only the host name is matched.
    host = netloc.rpartition("@")[2].partition(":")[0]

    # Check for DigitalOcean Spaces bucket-prefixed URL pattern
    # e.g., https://mybucket.nyc3.digitaloceanspaces.com
    do_match = re.match(
        r"^([^.]+)\.([a-z0-9]+)\.digitaloceanspaces\.com$", host, re.IGNORECASE
    )
    if do_match:
        extracted_bucket = do_match.group(1)
        region = do_match.group(2).lower()
        correct_endpoint = f"https://{region}.digitaloceanspaces.com"
        logger.warning(
            f"Detected bucket-prefixed DigitalOcean Spaces URL. "
            f"Extracted bucket '{extracted_bucket}' from endpoint. "
            f"Using endpoint: {correct_endpoint}"
        )
        # If bucket wasn't provided or differs, use extracted one
        if not bucket or bucket != extracted_bucket:
            logger.info(f"Using extracted bucket name: '{extracted_bucket}' (was: '{bucket}')")
            bucket = extracted_bucket
        return correct_endpoint, bucket

    # Check for just "digitaloceanspaces.com" without region
    if host.lower() == "digitaloceanspaces.com":
        logger.error(
            "Invalid DigitalOcean Spaces endpoint: missing region. "
            "Use format: https://<region>.digitaloceanspaces.com (e.g., https://lon1.digitaloceanspaces.com)"
        )

    return endpoint_url, bucket
|
||||
|
||||
def _init_client(
    self,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    region_name: str = "us-east-1",
    endpoint_url: Optional[str] = None,
    bucket: Optional[str] = None,
) -> Optional[str]:
    """
    Build the boto3 S3 client from explicit credentials and store it on
    ``self.s3_client``.

    When a custom endpoint is given it is first passed through
    ``_normalize_endpoint_url`` and the client is configured for
    path-style addressing.

    Returns:
        The bucket name to use, which may differ from ``bucket`` if the
        endpoint URL was normalized
    """
    from botocore.config import Config

    logger.info(f"Initializing S3 client with region: {region_name}")

    options = dict(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=region_name,
    )
    corrected_bucket = bucket

    if not endpoint_url:
        logger.info("Using default AWS S3 endpoint")
    else:
        # The endpoint may carry the bucket name; split the two apart.
        normalized_endpoint, corrected_bucket = self._normalize_endpoint_url(endpoint_url, bucket)
        logger.info(f"Original endpoint URL: {endpoint_url}")
        logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
        logger.info(f"Bucket name: '{corrected_bucket}'")

        # S3-compatible services (DigitalOcean Spaces, MinIO, etc.) are
        # addressed path-style.
        options.update(
            endpoint_url=normalized_endpoint,
            config=Config(s3={"addressing_style": "path"}),
        )

    self.s3_client = boto3.client("s3", **options)
    logger.info("S3 client initialized successfully")
    return corrected_bucket
|
||||
|
||||
def is_text_file(self, file_path: str) -> bool:
    """
    Report whether a path looks like a plain-text file.

    The lower-cased path is first checked against a fixed list of
    suffixes; if none matches, the MIME type guessed from the path decides
    (any ``text*`` type, ``application/json`` or ``application/xml``).
    """
    known_suffixes = (
        # prose and markup
        ".txt", ".md", ".markdown", ".rst",
        # structured data
        ".json", ".xml", ".yaml", ".yml",
        # source code
        ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".c", ".cpp", ".h",
        ".hpp", ".cs", ".go", ".rs", ".rb", ".php", ".swift", ".kt", ".scala",
        # web assets
        ".html", ".css", ".scss", ".sass", ".less",
        # shell and query scripts
        ".sh", ".bash", ".zsh", ".fish", ".sql",
        ".r", ".m", ".mat",
        # configuration
        ".ini", ".cfg", ".conf", ".config", ".env",
        ".gitignore", ".dockerignore", ".editorconfig",
        # logs and tabular text
        ".log", ".csv", ".tsv",
    )

    if file_path.lower().endswith(known_suffixes):
        return True

    # Fall back to the MIME type guessed from the file name.
    guessed_type, _encoding = mimetypes.guess_type(file_path)
    if not guessed_type:
        return False
    return guessed_type.startswith("text") or guessed_type in (
        "application/json",
        "application/xml",
    )
|
||||
|
||||
def is_supported_document(self, file_path: str) -> bool:
    """
    Report whether a path ends with one of the document suffixes that are
    handed to the document parser (comparison is case-insensitive).
    """
    parseable_suffixes = (
        ".pdf",
        ".docx", ".doc",
        ".xlsx", ".xls",
        ".pptx", ".ppt",
        ".epub",
        ".odt",
        ".rtf",
    )
    return file_path.lower().endswith(parseable_suffixes)
|
||||
|
||||
def list_objects(self, bucket: str, prefix: str = "") -> List[str]:
    """
    Collect the keys of all objects under a prefix, following pagination.

    Keys ending in "/" are left out.

    Args:
        bucket: S3 bucket name
        prefix: Optional path prefix to filter objects

    Returns:
        List of object keys

    Raises:
        Exception: if the listing fails (missing bucket, access denied,
            other client errors) or no credentials are available
    """
    keys_found = []
    paginator = self.s3_client.get_paginator("list_objects_v2")

    logger.info(f"Listing objects in bucket: '{bucket}' with prefix: '{prefix}'")
    logger.debug(f"S3 client endpoint: {self.s3_client.meta.endpoint_url}")

    try:
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
        for page_number, page in enumerate(pages, start=1):
            logger.debug(f"Processing page {page_number}, keys in response: {list(page.keys())}")

            if "Contents" not in page:
                logger.info(f"Page {page_number} has no 'Contents' key - bucket may be empty or prefix not found")
                continue

            for entry in page["Contents"]:
                object_key = entry["Key"]
                if object_key.endswith("/"):
                    continue
                keys_found.append(object_key)
                logger.debug(f"Found object: {object_key}")

        logger.info(f"Found {len(keys_found)} objects in bucket '{bucket}'")

    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "")
        error_message = e.response.get("Error", {}).get("Message", "")
        logger.error(f"ClientError listing objects - Code: {error_code}, Message: {error_message}")
        logger.error(f"Full error response: {e.response}")
        logger.error(f"Bucket: '{bucket}', Prefix: '{prefix}', Endpoint: {self.s3_client.meta.endpoint_url}")

        if error_code == "NoSuchBucket":
            raise Exception(f"S3 bucket '{bucket}' does not exist")

        if error_code == "AccessDenied":
            raise Exception(
                f"Access denied to S3 bucket '{bucket}'. Check your credentials and permissions."
            )

        if error_code == "NoSuchKey":
            # Unexpected for a listing call; hints at a wrong endpoint or
            # bucket configuration.
            logger.error(
                "NoSuchKey error on ListObjectsV2 - this may indicate the bucket name "
                "is incorrect or the endpoint URL format is wrong. "
                "For DigitalOcean Spaces, the endpoint should be like: "
                "https://<region>.digitaloceanspaces.com and bucket should be just the space name."
            )
            raise Exception(
                f"S3 error: {e}. For S3-compatible services, verify: "
                "1) Endpoint URL format (e.g., https://nyc3.digitaloceanspaces.com), "
                "2) Bucket name is just the space/bucket name without region prefix"
            )

        raise Exception(f"S3 error: {e}")

    except NoCredentialsError:
        raise Exception(
            "AWS credentials not found. Please provide valid credentials."
        )

    return keys_found
|
||||
|
||||
def get_object_content(self, bucket: str, key: str) -> Optional[str]:
    """
    Get the content of an S3 object as text.

    Text files are decoded as UTF-8; supported documents are handed to
    ``_process_document``. Problems with a single object are logged and
    reported as ``None`` so that one bad object does not abort a load.

    Args:
        bucket: S3 bucket name
        key: Object key

    Returns:
        File content as string, or None if file should be skipped
        (unsupported type, empty or undecodable text, fetch error)
    """
    # Classify once; the result decides both the skip and the decode path.
    is_text = self.is_text_file(key)
    if not is_text and not self.is_supported_document(key):
        return None

    # Only the download can raise ClientError, so only it sits in the try.
    try:
        response = self.s3_client.get_object(Bucket=bucket, Key=key)
        content = response["Body"].read()
    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "")
        if error_code == "NoSuchKey":
            return None
        if error_code == "AccessDenied":
            logger.warning(f"Access denied to object: {key}")
            return None
        logger.error(f"Error fetching object {key}: {e}")
        return None

    if not is_text:
        return self._process_document(content, key)

    try:
        decoded_content = content.decode("utf-8").strip()
    except UnicodeDecodeError:
        # Text extension but not UTF-8; skip it, but leave a trace.
        logger.warning(f"Skipping object that is not valid UTF-8 text: {key}")
        return None

    # Whitespace-only files carry nothing worth indexing.
    return decoded_content or None
|
||||
|
||||
def _process_document(self, content: bytes, key: str) -> Optional[str]:
    """
    Process a document file (PDF, DOCX, etc.) and extract text.

    The bytes are written to a temporary file that keeps the key's
    extension, parsed with ``SimpleDirectoryReader``, and the temporary
    file is removed afterwards, including when writing or parsing fails.

    Args:
        content: File content as bytes
        key: Object key (filename)

    Returns:
        Extracted text content, or None if nothing was extracted or
        parsing failed
    """
    ext = os.path.splitext(key)[1].lower()

    # delete=False: the reader reopens the file by path after it is closed.
    tmp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
    tmp_path = tmp_file.name

    try:
        # Inside the try so a failed write still removes the file.
        with tmp_file:
            tmp_file.write(content)

        try:
            from application.parser.file.bulk import SimpleDirectoryReader

            reader = SimpleDirectoryReader(input_files=[tmp_path])
            documents = reader.load_data()
            if documents:
                return "\n\n".join(doc.text for doc in documents if doc.text)
            return None
        except Exception as e:
            # Best effort: an unparseable document is skipped, not fatal.
            logger.error(f"Error processing document {key}: {e}")
            return None
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
def load_data(self, inputs) -> List[Document]:
    """
    Read every loadable object from an S3 bucket into Document objects.

    Args:
        inputs: JSON string or dict containing:
            - aws_access_key_id: AWS access key ID
            - aws_secret_access_key: AWS secret access key
            - bucket: S3 bucket name
            - prefix: Optional path prefix to filter objects
            - region: AWS region (default: us-east-1)
            - endpoint_url: Custom S3 endpoint URL (for MinIO, R2, etc.)

    Returns:
        List of Document objects

    Raises:
        ValueError: if ``inputs`` is an invalid JSON string or a required
            field is missing or empty
    """
    data = inputs
    if isinstance(inputs, str):
        try:
            data = json.loads(inputs)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON input: {e}")

    missing_fields = [
        name
        for name in ("aws_access_key_id", "aws_secret_access_key", "bucket")
        if not data.get(name)
    ]
    if missing_fields:
        raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

    bucket = data["bucket"]
    prefix = data.get("prefix", "")
    region = data.get("region", "us-east-1")
    endpoint_url = data.get("endpoint_url", "")

    logger.info(f"Loading data from S3 - Bucket: '{bucket}', Prefix: '{prefix}', Region: '{region}'")
    if endpoint_url:
        logger.info(f"Custom endpoint URL provided: '{endpoint_url}'")

    corrected_bucket = self._init_client(
        data["aws_access_key_id"],
        data["aws_secret_access_key"],
        region,
        endpoint_url or None,
        bucket,
    )

    # Endpoint normalization may have found the real bucket name.
    if corrected_bucket and corrected_bucket != bucket:
        logger.info(f"Using corrected bucket name: '{corrected_bucket}' (original: '{bucket}')")
        bucket = corrected_bucket

    # Objects whose content comes back as None are skipped.
    documents = [
        Document(
            text=content,
            doc_id=key,
            extra_info={
                "title": os.path.basename(key),
                "source": f"s3://{bucket}/{key}",
                "bucket": bucket,
                "key": key,
            },
        )
        for key in self.list_objects(bucket, prefix)
        if (content := self.get_object_content(bucket, key)) is not None
    ]

    logger.info(f"Loaded {len(documents)} documents from S3 bucket '{bucket}'")
    return documents
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import requests
|
||||
import re # Import regular expression library
|
||||
import defusedxml.ElementTree as ET
|
||||
import xml.etree.ElementTree as ET
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
class SitemapLoader(BaseRemote):
|
||||
def __init__(self, limit=20):
|
||||
@@ -15,14 +14,7 @@ class SitemapLoader(BaseRemote):
|
||||
sitemap_url= inputs
|
||||
# Check if the input is a list and if it is, use the first element
|
||||
if isinstance(sitemap_url, list) and sitemap_url:
|
||||
sitemap_url = sitemap_url[0]
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
sitemap_url = validate_url(sitemap_url)
|
||||
except SSRFError as e:
|
||||
logging.error(f"URL validation failed: {e}")
|
||||
return []
|
||||
url = sitemap_url[0]
|
||||
|
||||
urls = self._extract_urls(sitemap_url)
|
||||
if not urls:
|
||||
@@ -48,13 +40,8 @@ class SitemapLoader(BaseRemote):
|
||||
|
||||
def _extract_urls(self, sitemap_url):
|
||||
try:
|
||||
# Validate URL before fetching to prevent SSRF
|
||||
validate_url(sitemap_url)
|
||||
response = requests.get(sitemap_url, timeout=30)
|
||||
response = requests.get(sitemap_url)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
except SSRFError as e:
|
||||
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
|
||||
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
|
||||
return []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user