Compare commits

1 Commits

Author SHA1 Message Date
dependabot[bot]
5d0eed6084 chore(deps): bump mpmath from 1.3.0 to 1.4.1 in /application
Bumps [mpmath](https://github.com/mpmath/mpmath) from 1.3.0 to 1.4.1.
- [Release notes](https://github.com/mpmath/mpmath/releases)
- [Changelog](https://github.com/mpmath/mpmath/blob/1.4.1/CHANGES)
- [Commits](https://github.com/mpmath/mpmath/compare/1.3.0...1.4.1)

---
updated-dependencies:
- dependency-name: mpmath
  dependency-version: 1.4.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 19:43:29 +00:00
660 changed files with 45958 additions and 822192 deletions

@@ -35,5 +35,8 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
# Alternatively, use "https://login.microsoftonline.com/common" for a multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
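
The comment above describes a fallback: user data stays on MongoDB until `POSTGRES_URI` is set, and both URI schemes are accepted. A minimal Python sketch of that decision follows. The helper name `pick_user_data_backend` and its return shape are assumptions made for illustration and are not taken from the DocsGPT codebase.

```python
from typing import Mapping, Optional, Tuple
from urllib.parse import urlsplit

# Both spellings are accepted, as the env comment states.
ACCEPTED_SCHEMES = ("postgres", "postgresql")


def pick_user_data_backend(env: Mapping[str, str]) -> Tuple[str, Optional[str]]:
    """Hypothetical helper: choose the user-data store from the environment."""
    uri = (env.get("POSTGRES_URI") or "").strip()
    if not uri:
        # Unset or empty: keep using MongoDB for user data.
        return ("mongodb", None)
    scheme = urlsplit(uri).scheme
    if scheme not in ACCEPTED_SCHEMES:
        raise ValueError(f"Unsupported POSTGRES_URI scheme: {scheme!r}")
    return ("postgres", uri)
```

Called with `os.environ`, this would return `("mongodb", None)` while the variable is commented out, and `("postgres", <uri>)` once it is configured.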

@@ -1,144 +0,0 @@
# DocsGPT Public Threat Model
**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments
## 1) Overview
DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.
Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)
## 2) Scope and assumptions
In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.
Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.
Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).
## 3) Security objectives
- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.
## 4) Assets
- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).
## 5) Trust boundaries and untrusted input
Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs
Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.
## 6) Main attack surfaces
1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).
## 7) Key threats and expected mitigations
### A. Auth/authz misconfiguration
- Threat: weak/no auth or leaked tokens lead to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.
### B. Untrusted file ingestion
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.
### C. SSRF/outbound abuse
- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
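
The URL checks listed in the mitigations above could look roughly like the following Python sketch. It is illustrative only: `is_url_allowed`, `is_public_ip`, and the injectable `resolve` callable are hypothetical names, not DocsGPT APIs, and a real deployment would still need egress controls at the network layer.

```python
import ipaddress
import socket
from typing import Callable, List
from urllib.parse import urlsplit


def resolve_host(host: str) -> List[str]:
    """Resolve a hostname to IP address strings using the system resolver."""
    return [info[4][0] for info in socket.getaddrinfo(host, None)]


def is_public_ip(address: str) -> bool:
    """Return True only for addresses outside private/special-purpose ranges."""
    try:
        ip = ipaddress.ip_address(address)
    except ValueError:
        return False
    return not (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local  # covers 169.254.0.0/16 metadata endpoints
        or ip.is_reserved
        or ip.is_multicast
        or ip.is_unspecified
    )


def is_url_allowed(
    url: str, resolve: Callable[[str], List[str]] = resolve_host
) -> bool:
    """Allow only http(s) URLs whose every resolved address is public."""
    try:
        parts = urlsplit(url)
    except ValueError:
        return False
    if parts.scheme not in ("http", "https") or not parts.hostname:
        return False
    try:
        addresses = resolve(parts.hostname)
    except OSError:
        return False
    # Reject if any resolved address is internal, not just the first one.
    return bool(addresses) and all(is_public_ip(a) for a in addresses)
```

To cover redirects, the same check would have to run again on every hop before following it, instead of only on the initial URL.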
### D. Prompt injection + tool abuse
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Principle: never rely on the model to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.
### E. Dangerous tool capability chaining (SQL/API/MCP)
- Threat: write-capable SQL credentials allow destructive queries.
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.
### F. Frontend/XSS + token theft
- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.
### G. Internal endpoint exposure
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.
### H. DoS and cost abuse
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.
## 8) Example attacker stories
- Internet-exposed deployment runs with weak/no auth and suffers unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.
## 9) Severity calibration
- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.
## 10) Baseline controls for public deployments
1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
3. Restrict CORS and front the API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.
## 11) Maintenance
Review this model after major auth, ingestion, connector, tool, or workflow changes.
## References
- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)

@@ -11,6 +11,7 @@ on:
permissions:
contents: read
pull-requests: write
jobs:
vale:
@@ -19,16 +20,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Vale
run: |
curl -fsSL -o vale.tar.gz \
https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
tar -xzf vale.tar.gz
sudo mv vale /usr/local/bin/vale
vale --version
- name: Sync Vale packages
run: vale sync
- name: Run Vale
run: vale --minAlertLevel=error docs
- name: Vale linter
uses: errata-ai/vale-action@v2
with:
files: docs
fail_on_error: false
version: 3.0.5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

@@ -186,11 +186,3 @@ node_modules/
.vscode/sftp.json
/models/
model/
# E2E test artifacts
.e2e-tmp/
/tmp/docsgpt-e2e/
tests/e2e/node_modules/
tests/e2e/playwright-report/
tests/e2e/test-results/
tests/e2e/.e2e-last-run.json

@@ -10,15 +10,9 @@
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether MongoDB is already running.
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.
> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
## Normal local development commands
@@ -37,22 +31,6 @@ Run the Flask API (if needed):
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
or to match the production runtime exactly — run the ASGI composition under
uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.
Run the Celery worker in a separate terminal (if needed):
```bash
@@ -115,7 +93,7 @@ vale .
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules

@@ -47,13 +47,11 @@
</ul>
## Roadmap
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
- [x] SharePoint & Confluence connectors ( March/April 2026 )
- [x] Research mode ( March 2026 )
- [x] Postgres migration for user data ( April 2026 )
- [x] OpenTelemetry observability ( April 2026 )
- [x] Bring Your Own Model (BYOM) ( April 2026 )
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full API tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues; it helps us improve DocsGPT!

@@ -8,7 +8,7 @@ RUN apt-get update && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.12 ]; then \
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
COPY . /app/application
# Change the ownership of the /app directory to the appuser
RUN mkdir -p /app/application/inputs/local
RUN chown -R appuser:appuser /app
@@ -82,26 +82,11 @@ ENV FLASK_APP=app.py \
FLASK_DEBUG=true \
PATH="/venv/bin:$PATH"
ENV MALLOC_ARENA_MAX=2 \
OMP_NUM_THREADS=4 \
MKL_NUM_THREADS=4 \
OPENBLAS_NUM_THREADS=4
# Expose the port the app runs on
EXPOSE 7091
# Switch to non-root user
USER appuser
CMD ["gunicorn", \
"-w", "1", \
"-k", "uvicorn_worker.UvicornWorker", \
"--bind", "0.0.0.0:7091", \
"--timeout", "180", \
"--graceful-timeout", "120", \
"--keep-alive", "5", \
"--worker-tmp-dir", "/dev/shm", \
"--max-requests", "1000", \
"--max-requests-jitter", "100", \
"--config", "application/gunicorn_conf.py", \
"application.asgi:asgi_app"]
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]

@@ -42,7 +42,6 @@ class BaseAgent(ABC):
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
model_user_id: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -53,13 +52,10 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
# BYOM-resolution scope: owner for shared agents, caller for
# caller-owned BYOM, None for built-ins. Falls back to self.user
# for worker/legacy callers that don't thread model_user_id.
self.model_user_id = model_user_id
self.tools: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
@@ -71,16 +67,8 @@ class BaseAgent(ABC):
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
model_user_id=model_user_id,
)
# For BYOM, registry id (UUID) differs from upstream model id
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
# the LLM instance; cache it for subsequent gen calls.
self.upstream_model_id = (
getattr(self.llm, "model_id", None) or model_id
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
@@ -98,7 +86,6 @@ class BaseAgent(ABC):
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.attachments = attachments or []
@@ -115,8 +102,6 @@ class BaseAgent(ABC):
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
self.conversation_id: Optional[str] = None
self.initial_user_id: Optional[str] = None
@log_activity()
def gen(
@@ -321,9 +306,7 @@ class BaseAgent(ABC):
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
context_limit = get_token_limit(self.model_id)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
@@ -342,9 +325,7 @@ class BaseAgent(ABC):
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
@@ -406,9 +387,7 @@ class BaseAgent(ABC):
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
@@ -518,10 +497,7 @@ class BaseAgent(ABC):
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
# Use the upstream id resolved by LLMCreator (see __init__).
# Built-in models: same as self.model_id. BYOM: the user's
# typed model name, not the internal UUID.
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments
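
The hunks above all repeat the same arithmetic around `context_limit`: a compression threshold, a usage percentage, and a 10% safety buffer. The standalone Python sketch below restates that arithmetic. The function name `context_budget` and the explicit `threshold_pct` argument are assumptions for illustration; in the diff the fraction comes from `settings.COMPRESSION_THRESHOLD_PERCENTAGE`, whose value is not shown here.

```python
from typing import Any, Dict


def context_budget(
    context_limit: int, current_tokens: int, threshold_pct: float
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the token-limit checks in the diff."""
    threshold = int(context_limit * threshold_pct)
    return {
        # Compression is triggered once usage reaches the threshold.
        "threshold": threshold,
        "should_compress": current_tokens >= threshold,
        # The hard limit is the model's full context window.
        "limit_reached": current_tokens >= context_limit,
        "percentage": (current_tokens / context_limit) * 100,
        # Tokens held back when sizing the system prompt (10% of the limit).
        "safety_buffer": int(context_limit * 0.1),
    }
```

For example, with an 8000-token limit and a 0.75 threshold fraction, compression would start at 6000 tokens and 800 tokens would be held back as a buffer.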

@@ -1,356 +0,0 @@
"""Default chat tools — config-free tools on by default in chats."""
from __future__ import annotations
import importlib
import inspect
import logging
import uuid
from typing import Any, Dict, List, Optional
from application.core.settings import settings
logger = logging.getLogger(__name__)
# Fixed namespace — never regenerate; produced ids are persisted.
_DEFAULT_TOOL_NAMESPACE = uuid.UUID("6b1d3f2a-9c84-4d17-bf6e-2a0c5e8d4471")
# Tool names whose storage tables FK ``tool_id`` to ``user_tools.id``;
# a synthetic id has no row, so a write would FK-violate. Schema-rot
# guard: ``tests.agents.test_default_tools.TestFkBoundToolsIsInSync``.
_FK_BOUND_TOOLS = frozenset({"notes", "todo_list"})
# Tools that should NEVER appear in a headless run (scheduled or webhook).
# ``scheduler`` only makes sense from an interactive chat — letting an LLM
# call ``schedule_task`` from a scheduled run chains new schedules each fire,
# bounded only by ``SCHEDULE_MAX_PER_USER`` (cost foot-gun, confusing UX).
_HEADLESS_EXCLUDED_TOOLS = frozenset({"scheduler"})
# Agent-selectable builtins: hidden from the Add-Tool catalog (internal=True)
# and exposed to the agent picker via the same synthetic-id machinery as
# default tools. Names may overlap with DEFAULT_CHAT_TOOLS (e.g. ``scheduler``)
# — both registries share ``_DEFAULT_TOOL_NAMESPACE`` so the same uuid5
# resolves either way (the dual-flag row carries ``default`` AND ``builtin``).
BUILTIN_AGENT_TOOLS: tuple = ("scheduler",)
_tool_cache: Dict[str, Optional[Any]] = {}
_ids_cache: Dict[tuple, Dict[str, str]] = {}
_loaded_cache: Dict[tuple, List[str]] = {}
_builtin_ids_cache: Dict[tuple, Dict[str, str]] = {}
_builtin_loaded_cache: Dict[tuple, List[str]] = {}
def _load_tool(tool_name: str) -> Optional[Any]:
"""Return a metadata-only instance of a tool, or None if it has no class."""
# Imports just the named module (not the whole package) — avoids the
# circular import via ``mcp_tool`` → ``application.api.user``.
if tool_name in _tool_cache:
return _tool_cache[tool_name]
from application.agents.tools.base import Tool
instance: Optional[Any] = None
try:
module = importlib.import_module(f"application.agents.tools.{tool_name}")
except ModuleNotFoundError:
_tool_cache[tool_name] = None
return None
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
try:
instance = obj({})
except Exception:
logger.warning(
"DEFAULT_CHAT_TOOLS entry %r failed to instantiate; skipping.",
tool_name,
)
instance = None
break
_tool_cache[tool_name] = instance
return instance
def default_tool_id(tool_name: str) -> str:
"""Return the deterministic synthetic id for a default tool name."""
return str(uuid.uuid5(_DEFAULT_TOOL_NAMESPACE, tool_name))
def default_tool_ids() -> Dict[str, str]:
"""Map each configured default-tool name to its synthetic id (memoized)."""
key = tuple(settings.DEFAULT_CHAT_TOOLS)
cached = _ids_cache.get(key)
if cached is None:
cached = {name: default_tool_id(name) for name in key}
_ids_cache[key] = cached
return cached
def is_default_tool_id(tool_id: Any) -> bool:
"""Return True if ``tool_id`` is a synthetic default-tool id."""
if not tool_id:
return False
return str(tool_id) in set(default_tool_ids().values())
def default_tool_name_for_id(tool_id: Any) -> Optional[str]:
"""Return the default-tool name for a synthetic id, or None."""
target = str(tool_id) if tool_id else ""
for name, synthetic_id in default_tool_ids().items():
if synthetic_id == target:
return name
return None
def builtin_agent_tool_ids() -> Dict[str, str]:
"""Map each agent-selectable builtin to its synthetic id (memoized)."""
key = tuple(BUILTIN_AGENT_TOOLS)
cached = _builtin_ids_cache.get(key)
if cached is None:
cached = {name: default_tool_id(name) for name in key}
_builtin_ids_cache[key] = cached
return cached
def is_builtin_agent_tool_id(tool_id: Any) -> bool:
"""Return True if ``tool_id`` is an agent-selectable builtin synthetic id."""
if not tool_id:
return False
return str(tool_id) in set(builtin_agent_tool_ids().values())
def builtin_agent_tool_name_for_id(tool_id: Any) -> Optional[str]:
"""Return the builtin tool name for a synthetic id, or None."""
target = str(tool_id) if tool_id else ""
for name, synthetic_id in builtin_agent_tool_ids().items():
if synthetic_id == target:
return name
return None
def synthesized_tool_name_for_id(tool_id: Any) -> Optional[str]:
"""Return the tool name for any synthetic id (default or builtin), or None."""
return default_tool_name_for_id(tool_id) or builtin_agent_tool_name_for_id(tool_id)
def is_synthesized_tool_id(tool_id: Any) -> bool:
"""Return True for any synthetic id (default chat or agent-builtin)."""
return is_default_tool_id(tool_id) or is_builtin_agent_tool_id(tool_id)
def loaded_default_tools() -> List[str]:
"""Return configured default-tool names that resolve to a loaded tool."""
# Silent + memoized — runs per request; the one-time skip notice
# for unimplemented names lives in ``validate_default_chat_tools``.
key = tuple(settings.DEFAULT_CHAT_TOOLS)
cached = _loaded_cache.get(key)
if cached is None:
cached = [name for name in key if _load_tool(name) is not None]
_loaded_cache[key] = cached
return cached
def loaded_builtin_agent_tools() -> List[str]:
"""Return builtin agent-tool names that resolve to a loaded tool."""
key = tuple(BUILTIN_AGENT_TOOLS)
cached = _builtin_loaded_cache.get(key)
if cached is None:
cached = [name for name in key if _load_tool(name) is not None]
_builtin_loaded_cache[key] = cached
return cached
def validate_default_chat_tools() -> List[str]:
"""Validate ``DEFAULT_CHAT_TOOLS`` at startup; return the usable names."""
skipped = [
name for name in settings.DEFAULT_CHAT_TOOLS if _load_tool(name) is None
]
if skipped:
logger.debug(
"DEFAULT_CHAT_TOOLS entries with no loaded tool, skipped: %s. "
"Each activates automatically once its tool exists.",
", ".join(skipped),
)
usable = loaded_default_tools()
for name in usable:
if name in _FK_BOUND_TOOLS:
raise ValueError(
f"DEFAULT_CHAT_TOOLS entry {name!r} has a storage table "
f"that foreign-keys tool_id to user_tools; a default tool "
f"has a synthetic id with no user_tools row, so it would "
f"fail at write time. It cannot be defaulted on."
)
requirements = _load_tool(name).get_config_requirements() or {}
required = [
key for key, spec in requirements.items()
if isinstance(spec, dict) and spec.get("required")
]
if required:
raise ValueError(
f"DEFAULT_CHAT_TOOLS entry {name!r} requires config "
f"fields {required}; only config-free tools may be "
"defaulted on."
)
if usable:
logger.info("Default chat tools active: %s", ", ".join(usable))
return usable
def _tool_display(tool_name: str) -> str:
"""Return the human-readable display name from the tool docstring."""
tool = _load_tool(tool_name)
doc = (tool.__doc__ or "").strip() if tool else ""
first_line = doc.split("\n", 1)[0].strip() if doc else ""
return first_line or tool_name
def _tool_description(tool_name: str) -> str:
"""Return the tool description (docstring lines after the first)."""
tool = _load_tool(tool_name)
doc = (tool.__doc__ or "").strip() if tool else ""
parts = doc.split("\n", 1)
return parts[1].strip() if len(parts) > 1 else ""
def synthesize_default_tool(tool_name: str) -> Optional[Dict[str, Any]]:
"""Build an in-memory ``user_tools``-shaped row for a default tool."""
tool = _load_tool(tool_name)
if tool is None:
return None
synthetic_id = default_tool_id(tool_name)
return {
"id": synthetic_id,
"_id": synthetic_id,
"name": tool_name,
"display_name": _tool_display(tool_name),
"custom_name": "",
"description": _tool_description(tool_name),
"config": {},
"config_requirements": {},
"actions": tool.get_actions_metadata() or [],
"status": True,
"default": True,
}
def synthesize_builtin_agent_tool(tool_name: str) -> Optional[Dict[str, Any]]:
"""Build an in-memory ``user_tools``-shaped row for a builtin agent tool."""
tool = _load_tool(tool_name)
if tool is None:
return None
synthetic_id = default_tool_id(tool_name)
return {
"id": synthetic_id,
"_id": synthetic_id,
"name": tool_name,
"display_name": _tool_display(tool_name),
"custom_name": "",
"description": _tool_description(tool_name),
"config": {},
"config_requirements": {},
"actions": tool.get_actions_metadata() or [],
"status": True,
"default": False,
"builtin": True,
}
def synthesize_tool_by_name(tool_name: str) -> Optional[Dict[str, Any]]:
"""Synthesize the row for any default or builtin tool name."""
if tool_name in BUILTIN_AGENT_TOOLS:
return synthesize_builtin_agent_tool(tool_name)
return synthesize_default_tool(tool_name)
def disabled_default_tools(user_doc: Optional[Dict[str, Any]]) -> List[str]:
"""Return the user's opt-out list from ``tool_preferences``."""
if not isinstance(user_doc, dict):
return []
prefs = user_doc.get("tool_preferences") or {}
if not isinstance(prefs, dict):
return []
disabled = prefs.get("disabled_default_tools") or []
if not isinstance(disabled, list):
return []
return [str(name) for name in disabled]
def synthesized_default_tools(
user_doc: Optional[Dict[str, Any]] = None,
*,
headless: bool = False,
) -> List[Dict[str, Any]]:
"""Return synthesized default-tool rows for an agentless chat."""
# Agent-bound chats must NOT call this — they resolve exactly
# ``agents.tools``. Disabled defaults are dropped. ``headless=True``
# additionally drops chat-only tools (e.g. ``scheduler``) so a scheduled
# / webhook LLM can't re-schedule itself.
disabled = set(disabled_default_tools(user_doc))
rows: List[Dict[str, Any]] = []
for name in loaded_default_tools():
if name in disabled:
continue
if headless and name in _HEADLESS_EXCLUDED_TOOLS:
continue
row = synthesize_default_tool(name)
if row is not None:
rows.append(row)
return rows
def is_headless_excluded_tool(tool_name: Optional[str]) -> bool:
"""Return True if ``tool_name`` must be hidden from headless runs."""
return bool(tool_name) and tool_name in _HEADLESS_EXCLUDED_TOOLS
def default_tools_for_management(
user_doc: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Return every loaded default tool with its on/off ``status``."""
# Unlike ``synthesized_default_tools`` (chat toolset), this keeps
# disabled tools so the management UI can render their toggle.
disabled = set(disabled_default_tools(user_doc))
rows: List[Dict[str, Any]] = []
for name in loaded_default_tools():
row = synthesize_default_tool(name)
if row is None:
continue
row["status"] = name not in disabled
rows.append(row)
return rows
def builtin_agent_tools_for_management() -> List[Dict[str, Any]]:
"""Return every loaded agent-builtin tool for the agent picker (no per-user state)."""
rows: List[Dict[str, Any]] = []
for name in loaded_builtin_agent_tools():
row = synthesize_builtin_agent_tool(name)
if row is None:
continue
rows.append(row)
return rows
def resolve_tool_by_id(
tool_id: Any,
user: Optional[str],
*,
user_tools_repo: Any = None,
) -> Optional[Dict[str, Any]]:
"""Resolve a tool by id: default/builtin synthetic id, else user_tools row.
Dual-registered tools (e.g. ``scheduler``) get both flags on the resolved
row so callers can branch on either path without losing the discriminator.
"""
default_name = default_tool_name_for_id(tool_id)
builtin_name = builtin_agent_tool_name_for_id(tool_id)
if default_name is not None and builtin_name is not None:
row = synthesize_default_tool(default_name) or {}
row["builtin"] = True
return row or None
if default_name is not None:
return synthesize_default_tool(default_name)
if builtin_name is not None:
return synthesize_builtin_agent_tool(builtin_name)
if user_tools_repo is None or not user:
return None
return user_tools_repo.get_any(str(tool_id), user)
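
The synthetic-id scheme in this module rests on `uuid.uuid5` being deterministic: the same namespace and name always give the same id, so nothing has to be stored to map an id back to a tool name. The sketch below shows that property on its own. It reuses the namespace value from the file above, while `synthetic_id`, `name_for_id`, and the example tool list are stand-ins, not the module's API.

```python
import uuid
from typing import Iterable, Optional

# Namespace copied from the module above; changing it would change every id.
NAMESPACE = uuid.UUID("6b1d3f2a-9c84-4d17-bf6e-2a0c5e8d4471")


def synthetic_id(tool_name: str) -> str:
    """Derive a stable id from the tool name (same input, same output)."""
    return str(uuid.uuid5(NAMESPACE, tool_name))


def name_for_id(tool_id: str, known_names: Iterable[str]) -> Optional[str]:
    """Reverse lookup by recomputing ids for the known names."""
    for name in known_names:
        if synthetic_id(name) == str(tool_id):
            return name
    return None


if __name__ == "__main__":
    names = ["scheduler", "notes"]  # example names only
    sid = synthetic_id("scheduler")
    print(sid == synthetic_id("scheduler"))  # stable across calls
    print(name_for_id(sid, names))
```

Because two registries share one namespace, a name listed in both resolves to the same id either way, which is why the resolver in the file above has to set both flags for dual-registered tools.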

@@ -1,173 +0,0 @@
"""Shared headless agent runner used by webhooks and scheduled runs."""
from __future__ import annotations
import logging
from typing import Any, Dict, Iterable, List, Optional
from application.agents.agent_creator import AgentCreator
from application.agents.tool_executor import ToolExecutor
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly
logger = logging.getLogger(__name__)
def _resolve_owner(agent_config: Dict[str, Any]) -> Optional[str]:
return agent_config.get("user_id") or agent_config.get("user")
def _resolve_agent_id(agent_config: Dict[str, Any]) -> Optional[str]:
raw = agent_config.get("id") or agent_config.get("_id")
return str(raw) if raw else None
def run_agent_headless(
agent_config: Dict[str, Any],
query: str,
*,
tool_allowlist: Optional[Iterable[str]] = None,
model_id_override: Optional[str] = None,
endpoint: str = "headless",
chat_history: Optional[List[Dict[str, Any]]] = None,
conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Run an agent with no live client; returns a structured outcome dict."""
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.utils import calculate_doc_token_budget
owner = _resolve_owner(agent_config)
if not owner:
raise ValueError("Agent config is missing user_id; cannot run headless.")
decoded_token = {"sub": owner}
retriever_kind = agent_config.get("retriever", "classic")
source_id = agent_config.get("source_id") or agent_config.get("source")
source_active: Any = {}
if source_id:
with db_readonly() as conn:
src_row = SourcesRepository(conn).get(str(source_id), owner)
if src_row:
source_active = str(src_row["id"])
retriever_kind = src_row.get("retriever", retriever_kind)
source = {"active_docs": source_active}
chunks = int(agent_config.get("chunks", 2) or 2)
prompt_id = agent_config.get("prompt_id", "default")
user_api_key = agent_config.get("key")
agent_id = _resolve_agent_id(agent_config)
agent_type = agent_config.get("agent_type", "classic")
json_schema = agent_config.get("json_schema")
prompt = get_prompt(prompt_id)
candidate_model = model_id_override or agent_config.get("default_model_id") or ""
if candidate_model and validate_model_id(candidate_model, user_id=owner):
model_id = candidate_model
else:
model_id = get_default_model_id()
if candidate_model:
logger.warning(
"Agent %s references unknown model_id %r; falling back to %r",
agent_id, candidate_model, model_id,
)
provider = (
get_provider_from_model_id(model_id, user_id=owner)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
doc_token_limit = calculate_doc_token_budget(model_id=model_id, user_id=owner)
retriever = RetrieverCreator.create_retriever(
retriever_kind,
source=source,
chat_history=chat_history or [],
prompt=prompt,
chunks=chunks,
doc_token_limit=doc_token_limit,
model_id=model_id,
user_api_key=user_api_key,
agent_id=agent_id,
decoded_token=decoded_token,
)
retrieved_docs: List[Dict[str, Any]] = []
try:
docs = retriever.search(query)
if docs:
retrieved_docs = docs
except Exception as exc:
logger.warning("Headless retrieve failed: %s", exc)
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=owner,
decoded_token=decoded_token,
agent_id=agent_id,
headless=True,
tool_allowlist=list(tool_allowlist or []),
)
if conversation_id:
tool_executor.conversation_id = str(conversation_id)
agent = AgentCreator.create_agent(
agent_type,
endpoint=endpoint,
llm_name=provider or settings.LLM_PROVIDER,
model_id=model_id,
api_key=system_api_key,
agent_id=agent_id,
user_api_key=user_api_key,
prompt=prompt,
chat_history=chat_history or [],
retrieved_docs=retrieved_docs,
decoded_token=decoded_token,
attachments=[],
json_schema=json_schema,
tool_executor=tool_executor,
)
if conversation_id:
agent.conversation_id = str(conversation_id)
answer_full = ""
thought = ""
sources_log: List[Dict[str, Any]] = []
tool_calls: List[Dict[str, Any]] = []
for event in agent.gen(query=query):
if not isinstance(event, dict):
continue
if "answer" in event:
answer_full += str(event["answer"])
elif "sources" in event:
sources_log.extend(event["sources"])
elif "tool_calls" in event:
tool_calls.extend(event["tool_calls"])
elif "thought" in event:
thought += str(event["thought"])
denied = list(getattr(tool_executor, "headless_denials", []))
error_type = "tool_not_allowed" if denied and not answer_full.strip() else None
# Use the LLM accumulator (gen_token_usage / stream_token_usage decorators);
# current_token_count is a context-size sentinel, not a usage tally.
llm_usage = getattr(getattr(agent, "llm", None), "token_usage", None) or {}
prompt_tokens = int(llm_usage.get("prompt_tokens", 0) or 0)
generated_tokens = int(llm_usage.get("generated_tokens", 0) or 0)
return {
"answer": answer_full,
"thought": thought,
"sources": sources_log,
"tool_calls": tool_calls,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"denied": denied,
"error_type": error_type,
"model_id": model_id,
}
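
The event loop above folds a stream of single-key dict events into four accumulators. A minimal, self-contained sketch of that folding pattern follows. The name `fold_events` and the sample events are hypothetical, chosen only for illustration; the real generator is `agent.gen(query=query)`.

```python
from typing import Any, Dict, Iterable, List


def fold_events(events: Iterable[Any]) -> Dict[str, Any]:
    """Fold streamed agent events into one result dict (illustrative only)."""
    answer = ""
    thought = ""
    sources: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []
    for event in events:
        if not isinstance(event, dict):
            continue  # non-dict items are skipped, as in the loop above
        if "answer" in event:
            answer += str(event["answer"])
        elif "sources" in event:
            sources.extend(event["sources"])
        elif "tool_calls" in event:
            tool_calls.extend(event["tool_calls"])
        elif "thought" in event:
            thought += str(event["thought"])
    return {
        "answer": answer,
        "thought": thought,
        "sources": sources,
        "tool_calls": tool_calls,
    }


# Hypothetical event stream standing in for agent.gen(...)
sample = [
    {"thought": "look up docs"},
    {"answer": "Hello, "},
    "not-a-dict",
    {"sources": [{"title": "README"}]},
    {"answer": "world"},
]
result = fold_events(sample)
```

Because the branches are chained with `elif`, an event carrying more than one of these keys contributes only its first matching key.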

View File

@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
-model=self.upstream_model_id,
+model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
-model=self.upstream_model_id,
+model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
-model=self.upstream_model_id,
+model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
)
try:
response = self.llm.gen(
-model=self.upstream_model_id, messages=messages, tools=None
+model=self.model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
]
llm_response = self.llm.gen_stream(
-model=self.upstream_model_id, messages=messages, tools=None
+model=self.model_id, messages=messages, tools=None
)
if log_context:

View File

@@ -1,131 +0,0 @@
"""Cron/tz computations for the scheduler (shared by dispatcher, routes, and tool)."""
from __future__ import annotations
import re
from datetime import datetime, timedelta, timezone
from typing import Optional
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from croniter import croniter
_DELAY_RE = re.compile(r"^\s*(\d+)\s*(s|m|h|d)\s*$", re.IGNORECASE)
_DELAY_MULTIPLIERS = {"s": 1, "m": 60, "h": 3600, "d": 86_400}
class ScheduleValidationError(ValueError):
"""Raised when a schedule's cron, run_at, or delay is invalid."""
def resolve_timezone(tz_name: Optional[str]) -> ZoneInfo:
"""Return a ``ZoneInfo`` for ``tz_name`` (default UTC)."""
name = (tz_name or "UTC").strip() or "UTC"
try:
return ZoneInfo(name)
except ZoneInfoNotFoundError as exc:
raise ScheduleValidationError(f"Unknown timezone: {name}") from exc
def parse_cron(expression: str) -> None:
"""Validate a 5-field cron expression; raise on bad input."""
# croniter defers some malformed inputs until get_next, so force one here.
if not expression or not isinstance(expression, str):
raise ScheduleValidationError("Cron expression is required.")
fields = expression.strip().split()
if len(fields) != 5:
raise ScheduleValidationError("Cron expression must have 5 fields.")
try:
itr = croniter(expression, datetime.now(timezone.utc))
itr.get_next(datetime)
except (ValueError, KeyError) as exc:
raise ScheduleValidationError(f"Invalid cron expression: {exc}") from exc
_CRON_INTERVAL_WINDOW = 64
def cron_interval_seconds(expression: str, tz_name: Optional[str]) -> int:
"""Return the smallest gap between ticks in a rolling window (enforces SCHEDULE_MIN_INTERVAL).
Walks _CRON_INTERVAL_WINDOW ticks because bursty expressions like
``* 9 * * *`` have tiny within-burst gaps and huge between-burst gaps;
sampling only two adjacent ticks would miss the small gap.
"""
parse_cron(expression)
tz = resolve_timezone(tz_name)
anchor_local = datetime.now(timezone.utc).astimezone(tz)
itr = croniter(expression, anchor_local)
prev = itr.get_next(datetime)
smallest: Optional[int] = None
for _ in range(_CRON_INTERVAL_WINDOW - 1):
nxt = itr.get_next(datetime)
gap = int((nxt - prev).total_seconds())
if gap > 0 and (smallest is None or gap < smallest):
smallest = gap
prev = nxt
return smallest if smallest is not None else 0
def next_cron_run(
expression: str,
tz_name: Optional[str],
after: Optional[datetime] = None,
) -> datetime:
"""Return the next fire time strictly after ``after`` (UTC, tz-aware).
Evaluates the cadence in the schedule's IANA tz so DST boundaries land on
the intended local clock-time (e.g. 9 AM Warsaw stays 9 AM across the jump).
"""
parse_cron(expression)
tz = resolve_timezone(tz_name)
anchor_utc = after if after is not None else datetime.now(timezone.utc)
if anchor_utc.tzinfo is None:
anchor_utc = anchor_utc.replace(tzinfo=timezone.utc)
anchor_local = anchor_utc.astimezone(tz)
itr = croniter(expression, anchor_local)
nxt_local = itr.get_next(datetime)
return nxt_local.astimezone(timezone.utc)
def parse_delay(delay: str) -> timedelta:
"""Parse a duration like ``30m`` / ``2h`` / ``1d`` into a timedelta."""
if not isinstance(delay, str):
raise ScheduleValidationError("delay must be a string like '30m' or '2h'.")
match = _DELAY_RE.match(delay)
if not match:
raise ScheduleValidationError(
"delay must look like '30s', '15m', '2h', or '1d'."
)
amount, unit = int(match.group(1)), match.group(2).lower()
if amount <= 0:
raise ScheduleValidationError("delay must be positive.")
return timedelta(seconds=amount * _DELAY_MULTIPLIERS[unit])
def parse_run_at(run_at: str, tz_name: Optional[str] = None) -> datetime:
"""Parse an ISO 8601 timestamp; naive values resolve in ``tz_name``.
Naive values inside the DST "fall back" hour resolve to the earlier instance
(zoneinfo default fold=0); pass an explicit offset to select the later one.
"""
if not isinstance(run_at, str) or not run_at.strip():
raise ScheduleValidationError("run_at must be an ISO 8601 string.")
try:
parsed = datetime.fromisoformat(run_at.strip().replace("Z", "+00:00"))
except ValueError as exc:
raise ScheduleValidationError(f"Invalid run_at: {exc}") from exc
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=resolve_timezone(tz_name))
return parsed.astimezone(timezone.utc)
def clamp_once_horizon(run_at: datetime, max_horizon_seconds: int) -> None:
"""Raise when ``run_at`` is in the past or beyond the once-task horizon."""
now = datetime.now(timezone.utc)
if run_at <= now:
raise ScheduleValidationError("run_at is in the past.")
if max_horizon_seconds > 0 and run_at - now > timedelta(seconds=max_horizon_seconds):
raise ScheduleValidationError(
"run_at is beyond the maximum allowed scheduling horizon."
)

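
The docstring of `cron_interval_seconds` explains why a rolling window of ticks is walked instead of sampling two adjacent ticks. The sketch below illustrates that reasoning with the standard library only. The ticks are hand-built to mimic `* 9 * * *` rather than produced by croniter, and `smallest_gap_seconds` is a hypothetical helper, not the removed function itself.

```python
from datetime import datetime, timedelta, timezone
from typing import List, Optional


def smallest_gap_seconds(ticks: List[datetime]) -> int:
    """Return the smallest positive gap between consecutive ticks, or 0."""
    smallest: Optional[int] = None
    for prev, nxt in zip(ticks, ticks[1:]):
        gap = int((nxt - prev).total_seconds())
        if gap > 0 and (smallest is None or gap < smallest):
            smallest = gap
    return smallest if smallest is not None else 0


# Hand-built ticks imitating "* 9 * * *": the last two minutes of the
# 09:xx burst on one day, then the first two minutes of the next day's burst.
day1 = datetime(2026, 1, 1, 9, 58, tzinfo=timezone.utc)
day2 = datetime(2026, 1, 2, 9, 0, tzinfo=timezone.utc)
ticks = [day1, day1 + timedelta(minutes=1), day2, day2 + timedelta(minutes=1)]

# Sampling only the two ticks that straddle the bursts sees the large gap.
two_tick_gap = smallest_gap_seconds(ticks[1:3])
# Walking the whole window finds the within-burst gap.
window_gap = smallest_gap_seconds(ticks)
```

Here the two-tick sample reports a gap of 23 hours and 1 minute, while the window reports 60 seconds, which is the value a minimum-interval check needs to see.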
View File

@@ -1,113 +1,19 @@
import logging
import uuid
from collections import Counter
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
from bson.objectid import ObjectId
from application.agents.default_tools import (
is_headless_excluded_tool,
resolve_tool_by_id,
synthesized_default_tools,
)
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
def _record_proposed(
call_id: str,
tool_name: str,
action_name: str,
arguments: Any,
*,
tool_id: Optional[str] = None,
) -> bool:
"""Insert a ``proposed`` row; swallow infra failures so tool calls
still run when the journal is unreachable. Returns True iff the row
is now journaled (newly created or already present).
"""
try:
with db_session() as conn:
inserted = ToolCallAttemptsRepository(conn).record_proposed(
call_id,
tool_name,
action_name,
arguments,
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
)
if not inserted:
logger.warning(
"tool_call_attempts duplicate call_id=%s; existing row left in place",
call_id,
extra={"alert": "tool_call_id_collision", "call_id": call_id},
)
return True
except Exception:
logger.exception("tool_call_attempts proposed write failed for %s", call_id)
return False
def _mark_executed(
call_id: str,
result: Any,
*,
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
proposed_ok: bool = True,
tool_name: Optional[str] = None,
action_name: Optional[str] = None,
arguments: Any = None,
tool_id: Optional[str] = None,
) -> None:
"""Flip the row to ``executed``. If ``proposed_ok`` is False (the
proposed write failed earlier), upsert a fresh row in ``executed`` so
the reconciler can still see the attempt — without this, the side
effect would be invisible to the journal.
"""
try:
with db_session() as conn:
repo = ToolCallAttemptsRepository(conn)
if proposed_ok:
updated = repo.mark_executed(
call_id,
result,
message_id=message_id,
artifact_id=artifact_id,
)
if updated:
return
# Fallback synthesizes the row so the journal isn't lost.
repo.upsert_executed(
call_id,
tool_name=tool_name or "unknown",
action_name=action_name or "",
arguments=arguments if arguments is not None else {},
result=result,
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
message_id=message_id,
artifact_id=artifact_id,
)
except Exception:
logger.exception("tool_call_attempts executed write failed for %s", call_id)
def _mark_failed(call_id: str, error: str) -> None:
try:
with db_session() as conn:
ToolCallAttemptsRepository(conn).mark_failed(call_id, error)
except Exception:
logger.exception("tool_call_attempts failed-write failed for %s", call_id)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
@@ -119,31 +25,16 @@ class ToolExecutor:
user_api_key: Optional[str] = None,
user: Optional[str] = None,
decoded_token: Optional[Dict] = None,
agent_id: Optional[str] = None,
*,
headless: bool = False,
tool_allowlist: Optional[List[str]] = None,
):
self.user_api_key = user_api_key
self.user = user
self.decoded_token = decoded_token
self.agent_id = agent_id
# Headless mode (scheduled / webhook): no human to resolve a pause,
# so check_pause returns headless_denied sentinels instead.
self.headless = bool(headless)
# Tool-instance ids pre-authorized for headless approval-gated execution.
self.tool_allowlist: set = (
{str(x) for x in tool_allowlist} if tool_allowlist else set()
)
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.message_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
# Filled by the LLMHandler.handle_tool_calls headless loop.
self.headless_denials: List[Dict] = []
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context.
@@ -160,54 +51,31 @@ class ToolExecutor:
return tools
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
"""Resolve an agent's toolset — exactly ``agents.tools``, no defaults."""
# Per-operation session: the answer pipeline spans a long-lived
# generator; wrapping it in a single connection would pin a PG
# conn for the whole stream. Open, fetch, close.
with db_readonly() as conn:
agent_data = AgentsRepository(conn).find_by_key(api_key)
tool_ids = agent_data.get("tools", []) if agent_data else []
tools_repo = UserToolsRepository(conn)
owner = (
(agent_data.get("user_id") or agent_data.get("user"))
if agent_data
else None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
tools: List[Dict] = []
for tid in tool_ids:
row = resolve_tool_by_id(tid, owner, user_tools_repo=tools_repo)
if row is None:
continue
# Headless runs (scheduled / webhook) drop chat-only tools
# like ``scheduler`` so a fire-time LLM can't chain schedules.
if self.headless and is_headless_excluded_tool(row.get("name")):
continue
tools.append(row)
return {str(tool["id"]): tool for tool in tools}
if tool_ids
else []
)
tools = list(tools)
return {str(tool["_id"]): tool for tool in tools} if tools else {}
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
"""Resolve an agentless chat's toolset: explicit user tools plus defaults."""
with db_readonly() as conn:
user_tools = UserToolsRepository(conn).list_active_for_user(user)
user_doc = (
UsersRepository(conn).get(user) if self.agent_id is None else None
)
# Headless agentless runs (e.g. scheduled fire) drop chat-only
# tools (``scheduler``) from explicit user_tools too.
filtered_user_tools = [
t for t in user_tools
if not (self.headless and is_headless_excluded_tool(t.get("name")))
]
# Index keys (ints) and synthetic uuid5 keys can't collide.
tools: Dict[str, Dict] = {
str(i): tool for i, tool in enumerate(filtered_user_tools)
}
if self.agent_id is None:
for default_row in synthesized_default_tools(
user_doc, headless=self.headless,
):
tools[str(default_row["id"])] = default_row
return tools
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def merge_client_tools(
self, tools_dict: Dict, client_tools: List[Dict]
@@ -345,11 +213,9 @@ class ToolExecutor:
def check_pause(
self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
-"""Return a pending-action dict (approval / client / headless_denied) or None.
+"""Check if a tool call requires pausing for approval or client execution.
-In headless mode the dict's pause_type is ``headless_denied`` so the
-upstream loop synthesizes a tool result instead of pausing (nothing can
-resume a scheduled / webhook run).
+Returns a dict describing the pending action if pause is needed, None otherwise.
"""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
@@ -360,26 +226,9 @@ class ToolExecutor:
return None # Will be handled as error by execute()
tool_data = tools_dict[tool_id]
arguments = call_args if isinstance(call_args, dict) else {}
# Client-side tools
if tool_data.get("client_side"):
if self.headless:
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": arguments,
"pause_type": "headless_denied",
"deny_reason": (
"Client-side tools cannot run in headless / scheduled runs."
),
"error_type": "tool_not_allowed",
"thought_signature": getattr(call, "thought_signature", None),
}
return {
"call_id": call_id,
"name": llm_name,
@@ -387,7 +236,7 @@ class ToolExecutor:
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
-"arguments": arguments,
+"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "requires_client_execution",
"thought_signature": getattr(call, "thought_signature", None),
}
@@ -404,27 +253,6 @@ class ToolExecutor:
)
if action_data.get("require_approval"):
if self.headless:
tool_row_id = str(tool_data.get("id") or tool_id)
if tool_row_id in self.tool_allowlist:
# Pre-authorized for headless execution — fall through.
return None
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": arguments,
"pause_type": "headless_denied",
"deny_reason": (
"This tool requires approval and is not in the run's "
"tool_allowlist."
),
"error_type": "tool_not_allowed",
"thought_signature": getattr(call, "thought_signature", None),
}
return {
"call_id": call_id,
"name": llm_name,
@@ -432,7 +260,7 @@ class ToolExecutor:
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
-"arguments": arguments,
+"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "awaiting_approval",
"thought_signature": getattr(call, "thought_signature", None),
}
@@ -449,14 +277,7 @@ class ToolExecutor:
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(
"tool_call_parse_failed",
extra={
"llm_class_name": llm_class_name,
"llm_tool_name": llm_name,
"call_id": call_id,
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -471,15 +292,7 @@ class ToolExecutor:
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(
"tool_id_not_found",
extra={
"tool_id": tool_id,
"llm_tool_name": llm_name,
"call_id": call_id,
"available_tool_count": len(tools_dict),
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -498,36 +311,9 @@ class ToolExecutor:
"action_name": llm_name,
"arguments": call_args,
}
tool_data = tools_dict[tool_id]
# Journal first so the reconciler sees malformed calls and any
# subsequent ``_mark_failed`` actually updates a real row.
proposed_ok = _record_proposed(
call_id,
tool_data["name"],
action_name,
call_args if isinstance(call_args, dict) else {},
tool_id=tool_data.get("id"),
)
# Defensive guard: a non-dict ``call_args`` (e.g. malformed
# JSON on the resume path) would crash the param walk below
# with AttributeError on ``.items()``. Surface a clean error
# event and flip the journal row to ``failed`` instead of
# killing the stream.
if not isinstance(call_args, dict):
error_message = (
f"Tool call arguments must be a JSON object, got "
f"{type(call_args).__name__}."
)
tool_call_data["result"] = error_message
tool_call_data["arguments"] = {}
_mark_failed(call_id, error_message)
yield {
"type": "tool_call",
"data": {**tool_call_data, "status": "error"},
}
self.tool_calls.append(tool_call_data)
return error_message, call_id
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
@@ -568,43 +354,19 @@ class ToolExecutor:
headers=headers, query_params=query_params,
)
if tool is None:
error_message = (
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(
"tool_load_failed",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
"call_id": call_id,
},
)
tool_call_data["result"] = error_message
_mark_failed(call_id, error_message)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return error_message, call_id
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
try:
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
except Exception as exc:
_mark_failed(call_id, str(exc))
raise
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
@@ -633,22 +395,6 @@ class ToolExecutor:
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
# Tool side effect has run; flip the journal row so the
# message-finalize path can later confirm it. If the proposed
# write failed (DB outage), upsert a fresh row in ``executed`` so
# the reconciler still sees the side effect.
_mark_executed(
call_id,
result_full,
message_id=self.message_id,
artifact_id=artifact_id or None,
proposed_ok=proposed_ok,
tool_name=tool_data["name"],
action_name=action_name,
arguments=call_args,
tool_id=tool_data.get("id"),
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
@@ -694,27 +440,9 @@ class ToolExecutor:
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
row_id = tool_data.get("id")
if not row_id:
logger.error(
"tool_missing_row_id",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
},
)
return None
tool_config["tool_id"] = str(row_id)
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "scheduler":
# Agent-bound: stamp schedules.agent_id. Agentless: the tool
# falls back to ``origin_conversation_id`` as the schedule's
# conversation home.
tool_config["agent_id"] = (
str(self.agent_id) if self.agent_id else None
)
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True

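
The headless branches in the `check_pause` hunks above amount to a small decision table: client-side tools are always denied in headless runs, and approval-gated tools run only when their row id is in the allowlist. A minimal sketch of that decision follows. `pause_decision` and its flat arguments are hypothetical simplifications; the real method parses the LLM call and returns a full pending-action dict.

```python
from typing import Dict, Optional, Set


def pause_decision(
    tool: Dict,
    action: Dict,
    *,
    headless: bool,
    allowlist: Set[str],
) -> Optional[str]:
    """Return the pause_type for a tool call, or None to execute directly."""
    if tool.get("client_side"):
        # Nothing can execute a client-side tool in a scheduled/webhook run.
        return "headless_denied" if headless else "requires_client_execution"
    if action.get("require_approval"):
        if headless:
            if str(tool.get("id")) in allowlist:
                return None  # pre-authorized for headless execution
            return "headless_denied"
        return "awaiting_approval"
    return None


# Hypothetical tool rows for illustration
gated = {"id": "t-1", "name": "api_tool"}
client = {"id": "t-2", "name": "browser", "client_side": True}
needs_ok = {"require_approval": True}

interactive = pause_decision(gated, needs_ok, headless=False, allowlist=set())
denied = pause_decision(gated, needs_ok, headless=True, allowlist=set())
allowed = pause_decision(gated, needs_ok, headless=True, allowlist={"t-1"})
client_headless = pause_decision(client, {}, headless=True, allowlist={"t-2"})
```

Note that the allowlist does not rescue client-side tools: in this sketch, as in the hunk above, that check comes first and ignores the allowlist.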
View File

@@ -39,7 +39,6 @@ class InternalSearchTool(Tool):
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
model_user_id=self.config.get("model_user_id"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
@@ -49,7 +48,7 @@ class InternalSearchTool(Tool):
return self._retriever
def _get_directory_structure(self) -> Optional[Dict]:
-"""Load directory structure from Postgres for the configured sources."""
+"""Load directory structure from MongoDB for the configured sources."""
if self._dir_structure_loaded:
return self._directory_structure
@@ -60,39 +59,35 @@ class InternalSearchTool(Tool):
return None
try:
# Per-operation session: this tool runs inside the answer
# generator hot path, so we open a short-lived read
# connection for the batch lookup and release immediately.
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.session import db_readonly
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
decoded_token = self.config.get("decoded_token") or {}
user_id = decoded_token.get("sub") if decoded_token else None
merged_structure = {}
with db_readonly() as conn:
repo = SourcesRepository(conn)
for doc_id in active_docs:
try:
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)}
)
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
self._directory_structure = merged_structure if merged_structure else None
except Exception as e:
@@ -362,48 +357,32 @@ INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
-"""Check if any of the active sources have a ``directory_structure`` row."""
+"""Check if any of the active sources have directory_structure in MongoDB."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
# scoping, but callers in the agent build path don't always
# thread the decoded token through here. Use a direct
# short-lived SQL lookup instead of the repo until the call
# sites are updated to propagate user context.
from sqlalchemy import text as _text
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
from application.storage.db.session import db_readonly
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
with db_readonly() as conn:
for doc_id in active_docs:
try:
value = str(doc_id)
if len(value) == 36 and "-" in value:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE id = CAST(:id AS uuid)"
),
{"id": value},
).fetchone()
else:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE legacy_mongo_id = :lid"
),
{"lid": value},
).fetchone()
if row is not None and row[0]:
return True
except Exception:
continue
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)},
{"directory_structure": 1},
)
if source_doc and source_doc.get("directory_structure"):
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
@@ -436,7 +415,6 @@ def build_internal_tool_config(
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
model_user_id: Optional[str] = None,
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
@@ -451,7 +429,6 @@ def build_internal_tool_config(
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"model_user_id": model_user_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,

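
The Postgres side of `sources_have_directory_structure` picks between the `id` and `legacy_mongo_id` columns with a length-and-hyphen check on the source id. The sketch below shows one alternative way to make the same choice by parsing the value with the standard `uuid` module. This is a swapped-in technique for illustration, not what the diff does, and `id_column` is a hypothetical helper.

```python
import uuid


def id_column(value: str) -> str:
    """Return the column a source id would be matched against."""
    try:
        parsed = uuid.UUID(value)
    except ValueError:
        # Not parseable as a UUID, e.g. a 24-character Mongo ObjectId string.
        return "legacy_mongo_id"
    # uuid.UUID also accepts un-hyphenated and braced forms; only the
    # canonical hyphenated form is treated as a primary-key UUID here.
    return "id" if str(parsed) == value.lower() else "legacy_mongo_id"


uuid_kind = id_column("123e4567-e89b-12d3-a456-426614174000")
objectid_kind = id_column("507f1f77bcf86cd799439011")
bare_hex_kind = id_column("123e4567e89b12d3a456426614174000")
```

Parsing is stricter than the length check: a 36-character string that merely contains a hyphen is routed to the legacy column instead of failing inside the `CAST(:id AS uuid)` query.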
View File

@@ -20,15 +20,18 @@ from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
-from application.api.user.tasks import mcp_oauth_task
+from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.events.keys import stream_key
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
@@ -77,12 +80,6 @@ class MCPTool(Tool):
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
# Pulled out of ``config`` (rather than left in ``self.config``)
# because it is a callable supplied by the OAuth worker — not
# something the rest of the tool plumbing should marshal or
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
# so the SSE envelope can carry ``authorization_url``.
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
self.available_tools = []
self._cache_key = self._generate_cache_key()
@@ -164,6 +161,7 @@ class MCPTool(Tool):
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
db=db,
user_id=self.user_id,
)
else:
@@ -173,8 +171,8 @@ class MCPTool(Tool):
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
redirect_publish=self.oauth_redirect_publish,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
@@ -493,7 +491,7 @@ class MCPTool(Tool):
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
-server_url=self.server_url, user_id=self.user_id,
+server_url=self.server_url, user_id=self.user_id, db_client=db
)
loop = asyncio.new_event_loop()
try:
@@ -685,19 +683,16 @@ class DocsGPTOAuth(OAuthClientProvider):
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
redirect_publish=None,
):
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
# Worker-supplied callback. Invoked from ``redirect_handler``
# once the authorization URL is known so the SSE envelope can
# carry it. ``None`` for any non-worker entrypoint.
self.redirect_publish = redirect_publish
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -716,6 +711,7 @@ class DocsGPTOAuth(OAuthClientProvider):
storage = DBTokenStorage(
server_url=self.server_base_url,
user_id=self.user_id,
db_client=self.db,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
)
@@ -757,19 +753,17 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
if self.redirect_publish is not None:
# Best-effort: a publish failure must not abort the OAuth
# handshake — the user can still authorize via the popup
# opened from the legacy polling fallback if the SSE
# envelope is lost.
try:
self.redirect_publish(auth_url)
except Exception:
logger.warning(
"redirect_publish callback raised for task_id=%s",
self.task_id,
exc_info=True,
)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
@@ -779,6 +773,17 @@ class DocsGPTOAuth(OAuthClientProvider):
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
@@ -793,6 +798,14 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
@@ -840,95 +853,54 @@ class DBTokenStorage(TokenStorage):
self,
server_url: str,
user_id: str,
db_client,
expected_redirect_uri: Optional[str] = None,
):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.expected_redirect_uri = expected_redirect_uri
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def _pg_provider(self) -> str:
return f"mcp:{self.get_base_url(self.server_url)}"
def _fetch_session_data(self) -> dict:
"""Read the JSONB ``session_data`` blob for this MCP server row."""
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_readonly
base_url = self.get_base_url(self.server_url)
with db_readonly() as conn:
row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
self.user_id, base_url,
)
if not row:
return {}
data = row.get("session_data") or {}
if isinstance(data, str):
try:
data = json.loads(data)
except ValueError:
return {}
return data if isinstance(data, dict) else {}
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
data = await asyncio.to_thread(self._fetch_session_data)
if not data or "tokens" not in data:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
return OAuthToken.model_validate(data["tokens"])
return OAuthToken.model_validate(doc["tokens"])
except ValidationError as e:
logger.error("Could not load tokens: %s", e)
return None
def _merge(self, patch: dict) -> None:
"""Shallow-merge ``patch`` into this row's ``session_data``.
Threads ``server_url`` through to the repository so it lands in
the scalar column — ``get_by_user_and_server_url`` needs that to
resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
"""
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_session
base_url = self.get_base_url(self.server_url)
with db_session() as conn:
ConnectorSessionsRepository(conn).merge_session_data(
self.user_id, self._pg_provider(), base_url, patch,
)
def _delete(self) -> None:
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_session
with db_session() as conn:
ConnectorSessionsRepository(conn).delete(
self.user_id, self._pg_provider(),
)
async def set_tokens(self, tokens: OAuthToken) -> None:
base_url = self.get_base_url(self.server_url)
token_dump = tokens.model_dump()
await asyncio.to_thread(self._merge, {"tokens": token_dump})
logger.info("Saved tokens for %s", base_url)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
async def get_client_info(self) -> OAuthClientInformationFull | None:
data = await asyncio.to_thread(self._fetch_session_data)
base_url = self.get_base_url(self.server_url)
if not data or "client_info" not in data:
logger.debug("No client_info in DB for %s", base_url)
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
logger.debug(
"No client_info in DB for %s", self.get_base_url(self.server_url)
)
return None
try:
client_info = OAuthClientInformationFull.model_validate(data["client_info"])
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
@@ -937,16 +909,14 @@ class DBTokenStorage(TokenStorage):
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
base_url,
self.get_base_url(self.server_url),
expected_uri,
stored_uris,
)
# Drop ``tokens`` and ``client_info`` from the JSONB
# blob via merge_session_data's ``None``-drops-key
# semantics — preserves the row + any other keys.
await asyncio.to_thread(
self._merge,
{"tokens": None, "client_info": None},
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": "", "tokens": ""}},
)
return None
return client_info
@@ -961,37 +931,22 @@ class DBTokenStorage(TokenStorage):
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
base_url = self.get_base_url(self.server_url)
await asyncio.to_thread(
self._merge, {"client_info": serialized_info},
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logger.info("Saved client info for %s", base_url)
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
async def clear(self) -> None:
await asyncio.to_thread(self._delete)
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
@classmethod
async def clear_all(cls, db_client=None) -> None:
"""Delete every MCP-tagged connector session row.
``db_client`` retained for call-site compatibility but unused —
storage is Postgres-only now.
"""
from sqlalchemy import text
from application.storage.db.session import db_session
def _delete_all() -> None:
with db_session() as conn:
conn.execute(
text(
"DELETE FROM connector_sessions "
"WHERE provider LIKE 'mcp:%'"
)
)
await asyncio.to_thread(_delete_all)
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logger.info("Cleared all OAuth client cache data.")
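The Postgres side of this hunk relies on `merge_session_data` treating a `None` value as "drop this key" while leaving the row and its other keys in place. The repository code is not part of this diff, so the following is only a minimal in-memory sketch of that shallow-merge contract; `shallow_merge` and the sample `session` dict are hypothetical and not the repository's implementation.

```python
from typing import Any, Dict


def shallow_merge(current: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``current`` with ``patch`` applied at the top level.

    Assumption: a ``None`` value in ``patch`` removes the key, mirroring the
    "None-drops-key" semantics described in the diff comments.
    """
    merged = dict(current)
    for key, value in patch.items():
        if value is None:
            merged.pop(key, None)  # drop the key if present, ignore if absent
        else:
            merged[key] = value    # top-level overwrite, no deep merge
    return merged


# Hypothetical session blob for one MCP server row.
session = {
    "tokens": {"access_token": "abc"},
    "client_info": {"client_id": "x"},
    "other": 1,
}
cleared = shallow_merge(session, {"tokens": None, "client_info": None})
```

Under this contract, clearing stale credentials after a redirect URI mismatch keeps unrelated keys such as `other`, whereas the MongoDB variant reaches the same result with `$unset`.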
@@ -1034,73 +989,8 @@ class MCPOAuthManager:
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
Mirrors the legacy polling contract: ``status`` derived from the
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
(e.g. ``tools``/``tools_count`` on ``completed``).
"""
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
if not user_id:
return {"status": "not_found", "message": "User not provided"}
if self.redis_client is None:
return {"status": "not_found", "message": "Redis unavailable"}
try:
# OAuth flows are short-lived but a concurrent source
# ingest can flood the user channel between the OAuth
# popup completing and the user clicking Save, pushing the
# completion envelope outside the read window. Bound the
# scan by the configured stream cap so we cover the full
# journal — XADD MAXLEN keeps that bounded too.
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
entries = self.redis_client.xrevrange(
stream_key(user_id), count=scan_count
)
except Exception:
logger.exception(
"xrevrange failed for oauth status: user_id=%s task_id=%s",
user_id,
task_id,
)
return {"status": "not_found", "message": "Status unavailable"}
for _entry_id, fields in entries:
if not isinstance(fields, dict):
continue
# decode_responses=False ⇒ bytes keys; the string-key fallback
# covers a future flip of that default without a forced refactor.
event_raw = fields.get(b"event")
if event_raw is None:
event_raw = fields.get("event")
if event_raw is None:
continue
if isinstance(event_raw, bytes):
try:
event_raw = event_raw.decode("utf-8")
except Exception:
continue
try:
envelope = json.loads(event_raw)
except Exception:
continue
if not isinstance(envelope, dict):
continue
event_type = envelope.get("type", "")
if not isinstance(event_type, str) or not event_type.startswith(
"mcp.oauth."
):
continue
scope = envelope.get("scope") or {}
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
continue
payload = envelope.get("payload") or {}
return {
"status": event_type[len("mcp.oauth."):],
"task_id": task_id,
**payload,
}
return {"status": "not_found", "message": "Status not found"}
return mcp_oauth_status_task(task_id)
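The journal-based variant of `get_oauth_status` walks stream entries newest-first and returns the first `mcp.oauth.*` envelope whose scope matches the task. The sketch below reproduces that filtering on plain Python data so it can be read without Redis. `latest_oauth_status` is a hypothetical helper, and the sample entries (including the `source.ingest.progress` event type) are made up for illustration.

```python
import json
from typing import Any, Dict, Iterable, Tuple

PREFIX = "mcp.oauth."


def latest_oauth_status(
    entries: Iterable[Tuple[Any, Dict[Any, Any]]], task_id: str
) -> Dict[str, Any]:
    """Scan entries (assumed newest-first, as XREVRANGE returns them)."""
    for _entry_id, fields in entries:
        raw = fields.get(b"event")
        if raw is None:
            raw = fields.get("event")  # string keys if responses are decoded
        if raw is None:
            continue
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", errors="replace")
        try:
            envelope = json.loads(raw)
        except ValueError:
            continue  # skip malformed JSON instead of failing the lookup
        if not isinstance(envelope, dict):
            continue
        event_type = envelope.get("type")
        if not isinstance(event_type, str) or not event_type.startswith(PREFIX):
            continue
        scope = envelope.get("scope")
        if not isinstance(scope, dict):
            continue
        if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
            continue
        payload = envelope.get("payload") or {}
        return {"status": event_type[len(PREFIX):], "task_id": task_id, **payload}
    return {"status": "not_found", "message": "Status not found"}


entries = [
    ("2-0", {b"event": json.dumps({
        "type": "source.ingest.progress",
        "scope": {"kind": "source", "id": "s1"},
    }).encode()}),
    ("1-0", {b"event": json.dumps({
        "type": "mcp.oauth.completed",
        "scope": {"kind": "mcp_oauth", "id": "task-1"},
        "payload": {"tools_count": 3},
    }).encode()}),
]
result = latest_oauth_status(entries, "task-1")
```

Unrelated events in the same user stream are skipped rather than treated as errors, which is why the scan window has to be large enough to reach past them.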

View File

@@ -1,14 +1,12 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import logging
import re
import uuid
from .base import Tool
from application.storage.db.repositories.memories import MemoriesRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class MemoryTool(Tool):
@@ -29,7 +27,7 @@ class MemoryTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the UUID string from user_tools.id.
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -39,35 +37,8 @@ class MemoryTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
def _pg_enabled(self) -> bool:
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
has no row in ``user_tools``, so any storage operation would fail
the foreign-key check. After the Postgres cutover Postgres is the
only store, so for the sentinel case there is nowhere to read or
write — operations become no-ops and the tool returns an
explanatory error to the caller.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
logger.debug(
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
tool_id,
)
return False
from application.storage.db.base_repository import looks_like_uuid
if not looks_like_uuid(tool_id):
logger.debug(
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
tool_id,
)
return False
return True
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["memories"]
# -----------------------------
# Action implementations
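The `_pg_enabled` guard in this hunk rejects the sentinel `default_{uid}` tool_id and anything that is not a UUID before touching Postgres. `looks_like_uuid` lives in `application.storage.db.base_repository` and is not shown in this diff, so the version below is a stand-in built on the standard `uuid` module; it may be stricter or looser than the real helper. `storage_enabled` is a hypothetical free-function form of the guard.

```python
import uuid
from typing import Any


def looks_like_uuid(value: Any) -> bool:
    """Stand-in for the project's helper (assumption, not the real code)."""
    if not isinstance(value, str):
        return False
    try:
        uuid.UUID(value)
    except ValueError:
        return False
    return True


def storage_enabled(tool_id: Any) -> bool:
    """Mirror the guard: only real UUID tool ids may reach the database."""
    if not tool_id or not isinstance(tool_id, str):
        return False
    if tool_id.startswith("default_"):
        return False  # sentinel fallback, has no user_tools row
    return looks_like_uuid(tool_id)
```

A 24-character MongoDB ObjectId string such as `507f1f77bcf86cd799439011` fails this check, which is one reason ids carried over from the MongoDB path cannot be used against a UUID foreign key.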
@@ -85,12 +56,6 @@ class MemoryTool(Tool):
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: MemoryTool is not configured with a persistent tool_id; "
"memory storage is unavailable for this session."
)
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
@@ -317,10 +282,14 @@ class MemoryTool(Tool):
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
with db_readonly() as conn:
docs = MemoriesRepository(conn).list_by_prefix(
self.user_id, self.tool_id, search_path
)
# Find all files that start with this directory path
query = {
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
}
docs = list(self.collection.find(query, {"path": 1}))
if not docs:
return f"Directory: {path}\n(empty)"
@@ -341,10 +310,7 @@ class MemoryTool(Tool):
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
with db_readonly() as conn:
doc = MemoriesRepository(conn).get_by_path(
self.user_id, self.tool_id, path
)
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
@@ -378,10 +344,16 @@ class MemoryTool(Tool):
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
with db_session() as conn:
MemoriesRepository(conn).upsert(
self.user_id, self.tool_id, validated_path, file_text
)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": file_text,
"updated_at": datetime.now()
}
},
upsert=True
)
return f"File created: {validated_path}"
@@ -394,29 +366,30 @@ class MemoryTool(Tool):
if not old_str:
return "Error: old_str is required."
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Case-insensitive replace
import re as regex_module
updated_content = regex_module.sub(
regex_module.escape(old_str),
new_str,
current_content,
flags=regex_module.IGNORECASE,
)
# Replace the string (case-insensitive)
import re as regex_module
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"File updated: {validated_path}"
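Both sides of this hunk pass `new_str` to `re.sub` as a replacement string. In Python's `re` module a replacement string is processed for backslash escapes and group references, so text containing sequences like `\n` or `\1` is not inserted literally. One way to keep the replacement literal is a callable replacement. The sketch below is an illustration under that assumption, not the project's code; `replace_ci` is a hypothetical helper.

```python
import re


def replace_ci(text: str, old: str, new: str) -> str:
    """Case-insensitive literal replace.

    ``re.escape`` keeps ``old`` literal in the pattern; the lambda keeps
    ``new`` literal in the output because callables bypass template parsing.
    """
    return re.sub(re.escape(old), lambda _match: new, text, flags=re.IGNORECASE)


updated = replace_ci("Path: TODO", "todo", r"C:\new\1")

# For comparison: the same replacement passed as a plain string is parsed as
# a template, and ``\1`` refers to a group the pattern does not define.
try:
    re.sub(re.escape("todo"), r"C:\new\1", "Path: TODO", flags=re.IGNORECASE)
    template_failed = False
except re.error:
    template_failed = True
```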
@@ -429,25 +402,31 @@ class MemoryTool(Tool):
if not insert_text:
return "Error: insert_text is required."
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"Text inserted at line {insert_line} in {validated_path}"
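The insert logic in this hunk converts a 1-indexed line number to a list index and allows one position past the last line, which appends. A minimal standalone sketch of the same arithmetic follows; `insert_line` is a hypothetical helper that raises instead of returning an error string.

```python
def insert_line(content: str, line_number: int, text: str) -> str:
    """Insert ``text`` so that it becomes line ``line_number`` (1-indexed)."""
    lines = content.split("\n")
    index = line_number - 1  # convert to 0-indexed
    # index == len(lines) is valid and appends after the last line
    if index < 0 or index > len(lines):
        raise ValueError(f"Invalid line number. File has {len(lines)} lines.")
    lines.insert(index, text)
    return "\n".join(lines)


at_top = insert_line("a\nb", 1, "x")
at_end = insert_line("a\nb", 3, "x")
```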
@@ -459,36 +438,39 @@ class MemoryTool(Tool):
if validated_path == "/":
# Delete all files for this user and tool
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_all(
self.user_id, self.tool_id
)
return f"Deleted {deleted} file(s) from memory."
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
return f"Deleted {result.deleted_count} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_by_prefix(
self.user_id, self.tool_id, validated_path
)
return f"Deleted directory and {deleted} file(s)."
# Delete all files in directory
result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_path)}"}
})
return f"Deleted directory and {result.deleted_count} file(s)."
# Try as directory first (without trailing slash)
# Try to delete as directory first (without trailing slash)
# Check if any files start with this path + /
search_path = validated_path + "/"
with db_session() as conn:
repo = MemoriesRepository(conn)
directory_deleted = repo.delete_by_prefix(
self.user_id, self.tool_id, search_path
)
if directory_deleted > 0:
return f"Deleted directory and {directory_deleted} file(s)."
directory_result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
})
# Otherwise delete a single file
file_deleted = repo.delete_by_path(
self.user_id, self.tool_id, validated_path
)
if directory_result.deleted_count > 0:
return f"Deleted directory and {directory_result.deleted_count} file(s)."
if file_deleted:
# Delete single file
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_path
})
if result.deleted_count:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
@@ -503,46 +485,62 @@ class MemoryTool(Tool):
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Directory rename: do all path updates inside one transaction so
# the rename is atomic from the caller's perspective.
# Check if renaming a directory
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
with db_session() as conn:
repo = MemoriesRepository(conn)
docs = repo.list_by_prefix(
self.user_id, self.tool_id, validated_old
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_old)}"}
}))
if not docs:
return f"Error: Directory not found: {validated_old}"
# Update paths for all files
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
self.collection.update_one(
{"_id": doc["_id"]},
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
)
if not docs:
return f"Error: Directory not found: {validated_old}"
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(
validated_old, validated_new, 1
)
repo.update_path(
self.user_id, self.tool_id, old_file_path, new_file_path
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Single-file rename: lookup, collision check, and update in one txn.
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
if not doc:
return f"Error: File not found: {validated_old}"
# Rename single file
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_old
})
existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
if existing:
return f"Error: File already exists at {validated_new}"
if not doc:
return f"Error: File not found: {validated_old}"
repo.update_path(
self.user_id, self.tool_id, validated_old, validated_new
)
# Check if new path already exists
existing = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new
})
if existing:
return f"Error: File already exists at {validated_new}"
# Delete the old document and create a new one with the new path
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
self.collection.insert_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new,
"content": doc.get("content", ""),
"updated_at": datetime.now()
})
return f"Renamed: {validated_old} -> {validated_new}"
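The directory rename above normalizes both paths to end with `/` and then rewrites only the leading prefix of each stored path. The sketch below shows that mapping on a plain list of paths; `rename_prefix` and the sample paths are hypothetical, and no storage is involved.

```python
from typing import Dict, List


def rename_prefix(paths: List[str], old_dir: str, new_dir: str) -> Dict[str, str]:
    """Map each path under ``old_dir`` to its location under ``new_dir``."""
    if not old_dir.endswith("/"):
        old_dir += "/"
    if not new_dir.endswith("/"):
        new_dir += "/"
    # Only the leading prefix is rewritten; later occurrences stay untouched.
    return {
        path: new_dir + path[len(old_dir):]
        for path in paths
        if path.startswith(old_dir)
    }


moved = rename_prefix(
    ["/notes/a.txt", "/notes/sub/b.txt", "/notes2/c.txt"],
    "/notes/",
    "/archive",
)
```

The trailing slash is what keeps `/notes2/c.txt` out of the rename: matching on `/notes` without it would also catch sibling directories that merely share the prefix.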

View File

@@ -1,16 +1,10 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.session import db_readonly, db_session
# Stable synthetic title used in the Postgres ``notes.title`` column.
# The notes tool stores one note per (user_id, tool_id); there is no
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
# constant alongside the actual note body in ``content``.
_NOTE_TITLE = "note"
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class NotesTool(Tool):
@@ -31,6 +25,7 @@ class NotesTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -40,25 +35,11 @@ class NotesTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
fallback is neither a UUID nor a ``user_tools`` row, so any DB
operation would crash. Mirror MemoryTool's guard and no-op.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
@@ -73,13 +54,7 @@ class NotesTool(Tool):
A human-readable string result.
"""
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: NotesTool is not configured with a persistent "
"tool_id; note storage is unavailable for this session."
)
return "Error: NotesTool requires a valid user_id."
self._last_artifact_id = None
@@ -160,45 +135,37 @@ class NotesTool(Tool):
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
def _fetch_note(self) -> Optional[dict]:
"""Read the note row for this (user, tool) from Postgres."""
with db_readonly() as conn:
return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
def _get_note(self) -> str:
doc = self._fetch_note()
# ``content`` is the PG column; expose as ``note`` to callers via the
# textual return value. Frontends that read the artifact via the
# repo dict get ``content`` (PG-native) plus the artifact id below.
body = (doc or {}).get("content")
if not doc or not body:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
return str(body)
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, content
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True,
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
if not old_str:
return "old_str is required."
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(existing)
current_note = str(doc["note"])
# Case-insensitive search
if old_str.lower() not in current_note.lower():
@@ -208,24 +175,24 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
if not text:
return "Text is required."
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(existing)
current_note = str(doc["note"])
lines = current_note.split("\n")
# Convert to 0-indexed and validate
@@ -236,23 +203,21 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Text inserted."
def _delete_note(self) -> str:
# Capture the id (for artifact tracking) before deleting.
existing = self._fetch_note()
if not existing:
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
return "No note found to delete."
with db_session() as conn:
deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
if not deleted:
return "No note found to delete."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return "Note deleted."

View File

@@ -177,4 +177,3 @@ class PostgresTool(Tool):
"order": 1,
},
}

View File

@@ -1,339 +0,0 @@
"""Scheduler tool: one-time agent tasks in agent-bound or agentless chats."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.agents.scheduler_utils import (
ScheduleValidationError,
clamp_once_horizon,
parse_delay,
parse_run_at,
)
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.schedules import SchedulesRepository
from application.storage.db.session import db_readonly, db_session
from .base import Tool
logger = logging.getLogger(__name__)
class SchedulerTool(Tool):
"""Schedule, list, and cancel one-time agent tasks."""
# internal=True keeps scheduler out of /api/available_tools and the
# agentless Add-Tool modal; tool_manager.load_tool still lazy-loads it
# per-user at execute time (same as memory/notes/todo_list).
internal: bool = True
def __init__(
self,
tool_config: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None,
) -> None:
cfg = tool_config or {}
self.user_id: Optional[str] = user_id
self.agent_id: Optional[str] = cfg.get("agent_id")
self.conversation_id: Optional[str] = cfg.get("conversation_id")
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Dispatch on the LLM-supplied action name."""
if not self.user_id:
return "Error: SchedulerTool requires a valid user_id."
# Agent-bound: agent_id must look like a UUID. Agentless: agent_id is
# absent; an originating conversation is then mandatory (the schedule's
# conversation home, used for history + output append).
if self.agent_id and not looks_like_uuid(str(self.agent_id)):
return "Error: SchedulerTool received an invalid agent_id."
if not self.agent_id and not self.conversation_id:
return (
"Error: SchedulerTool requires an agent_id or a "
"conversation_id (no conversation home)."
)
if action_name == "schedule_task":
return self._schedule_task(
instruction=kwargs.get("instruction", ""),
delay=kwargs.get("delay"),
run_at=kwargs.get("run_at"),
tz=kwargs.get("timezone"),
)
if action_name == "list_scheduled_tasks":
return self._list_scheduled_tasks()
if action_name == "cancel_scheduled_task":
return self._cancel_scheduled_task(kwargs.get("task_id", ""))
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Action schemas for the LLM tool catalogue."""
return [
{
"name": "schedule_task",
"description": (
"Schedule a one-time task. Provide either a `delay` "
"(e.g. '30m', '2h', '1d') from now, or a `run_at` ISO-8601 "
"absolute time. Optionally pass an IANA `timezone` to resolve "
"naive run_at values. The instruction is the task that will "
"execute at fire time (including delivery, e.g. 'send to my "
"Telegram'). For recurring schedules in an agent chat, point "
"the user to the agent's Schedules tab."
),
"parameters": {
"type": "object",
"properties": {
"instruction": {
"type": "string",
"description": "What the agent should do at fire time.",
},
"delay": {
"type": "string",
"description": "Duration like '30m', '2h', '1d'.",
},
"run_at": {
"type": "string",
"description": "Absolute ISO 8601 timestamp.",
},
"timezone": {
"type": "string",
"description": (
"IANA timezone (e.g. Europe/Warsaw) for naive run_at."
),
},
},
"required": ["instruction"],
},
},
{
"name": "list_scheduled_tasks",
"description": (
"List pending one-time tasks for the current chat. "
"Agent-bound chats scope to user+agent; agentless chats "
"scope to user+originating conversation."
),
"parameters": {"type": "object", "properties": {}},
},
{
"name": "cancel_scheduled_task",
"description": "Cancel a pending one-time task by its task_id.",
"parameters": {
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "The schedule id returned by schedule_task.",
},
},
"required": ["task_id"],
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
return {}
def _schedule_task(
self,
instruction: str,
delay: Optional[str],
run_at: Optional[str],
tz: Optional[str],
) -> str:
if not instruction or not isinstance(instruction, str):
return "Error: instruction is required."
if not delay and not run_at:
return "Error: provide either `delay` or `run_at`."
if delay and run_at:
return "Error: provide only one of `delay` or `run_at`."
try:
if delay:
fire = datetime.now(timezone.utc) + parse_delay(delay)
else:
fire = parse_run_at(run_at, tz)
clamp_once_horizon(fire, settings.SCHEDULE_ONCE_MAX_HORIZON)
except ScheduleValidationError as exc:
return f"Error: {exc}"
with db_readonly() as conn:
count = SchedulesRepository(conn).count_active_for_user(self.user_id)
if (
settings.SCHEDULE_MAX_PER_USER > 0
and count >= settings.SCHEDULE_MAX_PER_USER
):
return (
"Error: you have reached the maximum number of active schedules."
)
# Chat-created tasks default to the user's non-approval tools (for the
# agent's toolset when agent-bound, or the user's defaults+user_tools
# when agentless).
allowlist = _safe_default_allowlist(self.agent_id, self.user_id)
auto_name = _name_from_instruction(instruction)
try:
with db_session() as conn:
created = SchedulesRepository(conn).create(
user_id=self.user_id,
agent_id=self.agent_id,
trigger_type="once",
instruction=instruction.strip(),
name=auto_name,
run_at=fire,
next_run_at=fire,
timezone=tz or "UTC",
tool_allowlist=allowlist,
origin_conversation_id=self.conversation_id,
created_via="chat",
)
except Exception as exc:
logger.exception("schedule_task create failed: %s", exc)
return "Error: failed to create scheduled task."
return json.dumps(
{
"task_id": str(created["id"]),
"resolved_run_at": _iso_utc(fire),
"timezone": tz or "UTC",
"instruction": instruction.strip(),
"name": auto_name,
}
)
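The `delay` strings accepted by `schedule_task` ('30m', '2h', '1d') are resolved by `parse_delay`, whose body is not part of this diff. A minimal sketch of how such a parser could work is below; `parse_delay_sketch` and its unit table are hypothetical names, and the real helper may accept other units or raise `ScheduleValidationError` rather than `ValueError`.

```python
import re
from datetime import timedelta

# Hypothetical stand-in for parse_delay; only the units named in the
# tool description (m, h, d) are handled here.
_UNIT_SECONDS = {"m": 60, "h": 3600, "d": 86400}
_DELAY_RE = re.compile(r"^\s*(\d+)\s*([mhd])\s*$", re.IGNORECASE)


def parse_delay_sketch(text: str) -> timedelta:
    """Turn a string such as '30m', '2h' or '1d' into a timedelta."""
    match = _DELAY_RE.match(text or "")
    if not match:
        raise ValueError(f"unrecognised delay: {text!r}")
    amount = int(match.group(1))
    unit = match.group(2).lower()
    if amount <= 0:
        raise ValueError("delay must be positive")
    return timedelta(seconds=amount * _UNIT_SECONDS[unit])
```

Adding the result to `datetime.now(timezone.utc)` gives the fire time in the same way the `delay` branch of `_schedule_task` does.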
def _list_scheduled_tasks(self) -> str:
"""Pending one-time tasks for this user, oldest fire first.
Agent-bound chats scope to user+agent. Agentless chats scope to user+
origin_conversation_id so a user only sees tasks created from this chat.
"""
with db_readonly() as conn:
repo = SchedulesRepository(conn)
if self.agent_id:
rows = repo.list_for_agent(
self.agent_id,
self.user_id,
statuses=["active"],
trigger_type="once",
)
else:
rows = repo.list_for_conversation(
self.user_id,
self.conversation_id,
statuses=["active"],
trigger_type="once",
)
# Values arrive as ISO strings (coerce_pg_native); string sentinel keeps types uniform.
rows.sort(key=lambda r: r.get("next_run_at") or "9999-12-31T23:59:59Z")
items = [
{
"task_id": str(r["id"]),
"instruction": r.get("instruction"),
"name": r.get("name"),
"resolved_run_at": _iso_utc(r.get("next_run_at")),
"timezone": r.get("timezone"),
"status": r.get("status"),
}
for r in rows
]
return json.dumps({"tasks": items})
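The sort in `_list_scheduled_tasks` relies on every `next_run_at` being an ISO string, with a far-future string standing in for missing values. The sketch below illustrates that idea with made-up rows; `sort_by_next_run` and `FAR_FUTURE` are illustrative names, not part of the diff.

```python
from typing import Any, Dict, List

# Same sentinel value as in the diff: a string, so every sort key has one type.
FAR_FUTURE = "9999-12-31T23:59:59Z"


def sort_by_next_run(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Order rows by next_run_at, pushing rows without one to the end."""
    return sorted(rows, key=lambda r: r.get("next_run_at") or FAR_FUTURE)


rows = [
    {"id": "b", "next_run_at": "2026-05-02T09:00:00Z"},
    {"id": "c", "next_run_at": None},
    {"id": "a", "next_run_at": "2026-05-01T09:00:00Z"},
]
ordered_ids = [r["id"] for r in sort_by_next_run(rows)]
```

Lexicographic order matches chronological order only while all strings share one format and one UTC offset, which is an assumption about what the repository returns.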
def _cancel_scheduled_task(self, task_id: str) -> str:
if not task_id or not looks_like_uuid(str(task_id)):
return "Error: task_id must be a valid id."
with db_session() as conn:
repo = SchedulesRepository(conn)
# Agentless: scope cancel to user + originating conversation so a
# user can only cancel tasks they created in the current chat.
if not self.agent_id:
row = repo.get(task_id, self.user_id)
if row is None or row.get("agent_id") is not None or (
str(row.get("origin_conversation_id") or "")
!= str(self.conversation_id or "")
):
return (
"Error: scheduled task not found or already terminal."
)
ok = repo.cancel(task_id, self.user_id)
if not ok:
return "Error: scheduled task not found or already terminal."
return json.dumps({"task_id": str(task_id), "status": "cancelled"})
def _name_from_instruction(instruction: str, *, max_len: int = 80) -> str:
"""Compact display name derived from the instruction's first line."""
first_line = instruction.strip().split("\n", 1)[0]
if len(first_line) <= max_len:
return first_line
return first_line[: max_len - 1] + "…"
def _iso_utc(value: Any) -> Optional[str]:
"""Render a datetime (or ISO string) as RFC3339 UTC; ``None`` passes through."""
if value is None:
return None
if isinstance(value, str):
try:
value = datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return value
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
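To make the normalisation rules of `_iso_utc` concrete, the sketch below copies its logic into a standalone function (named `iso_utc` here so it can run on its own) and exercises the four input kinds it handles. The sample timestamps are invented for illustration.

```python
from datetime import datetime, timezone
from typing import Any, Optional


def iso_utc(value: Any) -> Optional[str]:
    """Standalone copy of the _iso_utc logic shown above."""
    if value is None:
        return None
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return value  # unparseable strings pass through unchanged
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)  # naive means UTC
    return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")


examples = {
    "zulu": iso_utc("2026-04-15T12:00:00Z"),
    "offset": iso_utc("2026-04-15T14:00:00+02:00"),
    "naive": iso_utc(datetime(2026, 4, 15, 12, 0)),
    "garbage": iso_utc("not a date"),
}
```

The first three entries all come out as `2026-04-15T12:00:00Z`, and the last is returned untouched.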
def _safe_default_allowlist(
agent_id: Optional[str], user_id: str,
) -> List[str]:
"""Return ids of available tools whose actions are all non-approval.
Agent-bound: the agent's ``agents.tools`` entries.
Agentless: the user's active ``user_tools`` rows plus synthesized default
chat tools (resolved against ``settings.DEFAULT_CHAT_TOOLS`` and the
user's ``tool_preferences.disabled_default_tools`` opt-outs).
"""
from application.agents.default_tools import (
resolve_tool_by_id,
synthesized_default_tools,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.repositories.users import UsersRepository
def _is_safe(row: Dict[str, Any]) -> bool:
actions = row.get("actions") or []
return not any(a.get("require_approval") for a in actions)
safe_ids: List[str] = []
try:
with db_readonly() as conn:
tools_repo = UserToolsRepository(conn)
if agent_id:
agent = AgentsRepository(conn).get(agent_id, user_id)
tool_ids = (agent or {}).get("tools") or []
for raw_id in tool_ids:
tool_id = str(raw_id)
row = resolve_tool_by_id(
tool_id, user_id, user_tools_repo=tools_repo,
)
if not row or not _is_safe(row):
continue
safe_ids.append(tool_id)
else:
# Agentless: explicit user_tools (active=true) + synthesized
# defaults respecting the user's opt-out preferences.
user_doc = UsersRepository(conn).get(user_id)
for row in tools_repo.list_active_for_user(user_id):
if not _is_safe(row):
continue
safe_ids.append(str(row["id"]))
for default_row in synthesized_default_tools(user_doc):
if not _is_safe(default_row):
continue
safe_ids.append(str(default_row["id"]))
except Exception: # pragma: no cover — best-effort fallback
logger.exception("scheduler: default allowlist build failed")
return []
return safe_ids
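The allowlist rule in `_safe_default_allowlist` is that a tool qualifies only when none of its actions requires approval. The sketch below isolates that filter and runs it on invented tool rows; `safe_tool_ids` and the sample data are illustrative, and the database lookups from the real function are left out.

```python
from typing import Any, Dict, List


def is_safe(row: Dict[str, Any]) -> bool:
    """True when no action on the tool is flagged require_approval."""
    actions = row.get("actions") or []
    return not any(a.get("require_approval") for a in actions)


def safe_tool_ids(rows: List[Dict[str, Any]]) -> List[str]:
    """Keep only the ids of tools that pass is_safe, in input order."""
    return [str(row["id"]) for row in rows if is_safe(row)]


tools = [
    {"id": "t1", "actions": [{"name": "search", "require_approval": False}]},
    {"id": "t2", "actions": [{"name": "send", "require_approval": True},
                             {"name": "read"}]},
    {"id": "t3", "actions": None},  # no actions recorded
]
allowed = safe_tool_ids(tools)
```

A single approval-gated action excludes the whole tool, and a tool with no recorded actions counts as safe under this rule.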

View File

@@ -1,19 +1,10 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.session import db_readonly, db_session
def _status_from_completed(completed: Any) -> str:
"""Translate the PG ``completed`` boolean to the legacy status string.
The frontend (and prior LLM-facing tool output) expects
``"open"`` / ``"completed"``. Keeping that contract at the tool
boundary insulates callers from the schema change.
"""
return "completed" if bool(completed) else "open"
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class TodoListTool(Tool):
@@ -34,6 +25,7 @@ class TodoListTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -43,27 +35,11 @@ class TodoListTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
``default_{uid}`` fallback is neither a UUID nor a row in
``user_tools`` — binding it would crash ``invalid input syntax for
type uuid`` and even if it didn't the FK would reject it. Mirror
the MemoryTool guard and no-op in that case.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
@@ -80,12 +56,6 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: TodoListTool is not configured with a persistent "
"tool_id; todo storage is unavailable for this session."
)
self._last_artifact_id = None
if action_name == "list":
@@ -221,10 +191,28 @@ class TodoListTool(Tool):
return None
def _get_next_todo_id(self) -> int:
"""Get the next sequential todo_id for this user and tool.
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find the maximum todo_id
max_id = 0
for todo in todos:
todo_id = self._coerce_todo_id(todo.get("todo_id"))
if todo_id is not None:
max_id = max(max_id, todo_id)
return max_id + 1
def _list(self) -> str:
"""List all todos for the user."""
with db_readonly() as conn:
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
if not todos:
return "No todos found."
@@ -233,7 +221,7 @@ class TodoListTool(Tool):
for doc in todos:
todo_id = doc.get("todo_id")
title = doc.get("title", "Untitled")
status = _status_from_completed(doc.get("completed"))
status = doc.get("status", "open")
line = f"[{todo_id}] {title} ({status})"
result_lines.append(line)
@@ -241,23 +229,27 @@ class TodoListTool(Tool):
return "\n".join(result_lines)
def _create(self, title: str) -> str:
"""Create a new todo item.
``TodosRepository.create`` allocates the per-tool monotonic
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
scoped to ``tool_id``), so we no longer need a separate read-then-
write step here.
"""
"""Create a new todo item."""
title = (title or "").strip()
if not title:
return "Error: Title is required."
with db_session() as conn:
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
now = datetime.now()
todo_id = self._get_next_todo_id()
todo_id = row.get("todo_id")
if row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
doc = {
"todo_id": todo_id,
"user_id": self.user_id,
"tool_id": self.tool_id,
"title": title,
"status": "open",
"created_at": now,
"updated_at": now,
}
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -266,21 +258,21 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_readonly() as conn:
doc = TodosRepository(conn).get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
title = doc.get("title", "Untitled")
status = _status_from_completed(doc.get("completed"))
status = doc.get("status", "open")
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
return result
def _update(self, todo_id: Optional[Any], title: str) -> str:
"""Update a todo's title by ID."""
@@ -292,19 +284,16 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.update_title_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id, title
)
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} updated to: {title}"
@@ -314,17 +303,16 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} marked as completed."
@@ -334,18 +322,12 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.delete_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} deleted."
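The `_create` docstring on the Postgres side describes allocating `todo_id` as `COALESCE(MAX(todo_id),0)+1` scoped to `tool_id`, inside the same transaction as the insert. The sketch below shows that pattern under loud assumptions: it uses the standard-library `sqlite3` module instead of Postgres, a cut-down hypothetical `todos` table, and a single connection, so it does not show how `TodosRepository.create` handles concurrent writers.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    """In-memory database with a reduced, hypothetical todos table."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE todos (tool_id TEXT NOT NULL, "
        "todo_id INTEGER NOT NULL, title TEXT NOT NULL)"
    )
    return conn


def create_todo(conn: sqlite3.Connection, tool_id: str, title: str) -> int:
    """Insert a todo, allocating the next per-tool todo_id in one statement."""
    with conn:  # commits on success, rolls back on error
        conn.execute(
            "INSERT INTO todos (tool_id, todo_id, title) "
            "SELECT ?, COALESCE(MAX(todo_id), 0) + 1, ? "
            "FROM todos WHERE tool_id = ?",
            (tool_id, title, tool_id),
        )
        row = conn.execute(
            "SELECT MAX(todo_id) FROM todos WHERE tool_id = ?", (tool_id,)
        ).fetchone()
    return row[0]
```

Compared with the Mongo `_get_next_todo_id` scan followed by `insert_one`, the read and the write happen in one statement, so no separate read-then-write step is needed in the tool.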

View File

@@ -57,29 +57,6 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
# Gemini's SDK natively returns ``args`` as a dict, but the
# resume path (``gen_continuation``) stringifies it for the
# assistant message. Coerce a JSON string back into a dict;
# fall back to an empty dict on malformed input so downstream
# ``call_args.items()`` doesn't crash the stream.
if isinstance(call_args, str):
try:
call_args = json.loads(call_args)
except (json.JSONDecodeError, TypeError):
logger.warning(
"Google call.arguments was not valid JSON; "
"falling back to empty args for %s",
getattr(call, "name", "<unknown>"),
)
call_args = {}
if not isinstance(call_args, dict):
logger.warning(
"Google call.arguments has unexpected type %s; "
"falling back to empty args for %s",
type(call_args).__name__,
getattr(call, "name", "<unknown>"),
)
call_args = {}
resolved = self._resolve_via_mapping(call.name)
if resolved:
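The coercion block in this hunk accepts either a dict or a JSON string for `call.arguments` and falls back to an empty dict otherwise. Pulled out into a helper, the same rules look roughly like the sketch below; `coerce_call_args` is a hypothetical name, not a function in `ToolActionParser`.

```python
import json
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def coerce_call_args(raw: Any, tool_name: str = "<unknown>") -> Dict[str, Any]:
    """Return tool-call arguments as a dict, or {} when they are unusable."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            logger.warning("arguments for %s were not valid JSON", tool_name)
            return {}
    if not isinstance(raw, dict):
        # Covers None, lists, numbers, and JSON strings that decode to non-objects.
        logger.warning(
            "arguments for %s have unexpected type %s",
            tool_name,
            type(raw).__name__,
        )
        return {}
    return raw
```

Returning `{}` keeps a later `call_args.items()` from raising, at the cost of calling the tool with no arguments when the payload is malformed.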

View File

@@ -28,10 +28,7 @@ class ToolManager:
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if (
tool_name in {"mcp_tool", "notes", "memory", "todo_list", "scheduler"}
and user_id
):
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
@@ -39,10 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if (
tool_name in {"mcp_tool", "memory", "todo_list", "notes", "scheduler"}
and user_id
):
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -12,13 +12,12 @@ from application.agents.workflows.schemas import (
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
@@ -107,8 +106,10 @@ class WorkflowAgent(BaseAgent):
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
if not self.workflow_id:
logger.error("Missing workflow ID for load")
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
@@ -119,61 +120,61 @@ class WorkflowAgent(BaseAgent):
)
return None
with db_readonly() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
if workflow_row is None:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible "
f"for user {owner_id}"
)
return None
pg_workflow_id = str(workflow_row["id"])
graph_version = workflow_row.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
node_rows = WorkflowNodesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow = Workflow(
name=workflow_row.get("name"),
description=workflow_row.get("description"),
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
nodes = [
WorkflowNode(
id=n["node_id"],
workflow_id=pg_workflow_id,
type=n["node_type"],
title=n.get("title") or "Node",
description=n.get("description"),
position=n.get("position") or {"x": 0, "y": 0},
config=n.get("config") or {},
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
for n in node_rows
]
edges = [
WorkflowEdge(
id=e["edge_id"],
workflow_id=pg_workflow_id,
source=e.get("source_id"),
target=e.get("target_id"),
sourceHandle=e.get("source_handle"),
targetHandle=e.get("target_handle"),
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
for e in edge_rows
]
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
@@ -187,6 +188,10 @@ class WorkflowAgent(BaseAgent):
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
user=owner_id,
@@ -198,20 +203,23 @@ class WorkflowAgent(BaseAgent):
completed_at=datetime.now(timezone.utc),
)
if not self.workflow_id or not owner_id:
return
with db_session() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(
self.workflow_id, owner_id,
)
if workflow_row is None:
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
legacy_mongo_id = (
str(result.inserted_id)
if getattr(result, "inserted_id", None) is not None
else None
)
def _pg_write(repo: WorkflowRunsRepository) -> None:
if not self.workflow_id or not owner_id or not legacy_mongo_id:
return
WorkflowRunsRepository(conn).create(
str(workflow_row["id"]),
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
self.workflow_id, owner_id,
)
if workflow is None:
return
repo.create(
workflow["id"],
owner_id,
run.status.value,
inputs=run.inputs,
@@ -219,7 +227,10 @@ class WorkflowAgent(BaseAgent):
steps=[step.model_dump(mode="json") for step in run.steps],
started_at=run.created_at,
ended_at=run.completed_at,
legacy_mongo_id=legacy_mongo_id,
)
dual_write(WorkflowRunsRepository, _pg_write)
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
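Two small decisions recur in the Postgres side of this hunk: whether a workflow id is a native UUID or a legacy Mongo id, and how `current_graph_version` is clamped to a positive integer. The sketch below isolates both. `looks_like_uuid_sketch`, `lookup_kind` and `normalize_graph_version` are hypothetical names; the real `looks_like_uuid` in `base_repository` is not shown in this diff and may be stricter.

```python
import uuid
from typing import Any


def looks_like_uuid_sketch(value: Any) -> bool:
    """True when the value parses as a UUID (assumed behaviour of looks_like_uuid)."""
    try:
        uuid.UUID(str(value))
    except ValueError:
        return False
    return True


def lookup_kind(workflow_id: str) -> str:
    """Pick the lookup path: native id column or legacy Mongo id column."""
    return "uuid" if looks_like_uuid_sketch(workflow_id) else "legacy"


def normalize_graph_version(raw: Any) -> int:
    """Coerce to int and fall back to 1 for missing, invalid or non-positive values."""
    try:
        version = int(raw)
    except (ValueError, TypeError):
        return 1
    return version if version > 0 else 1
```

A 24-character Mongo ObjectId string fails UUID parsing, so it is routed to the legacy lookup, which corresponds to the `get_by_legacy_id` branch above.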

View File

@@ -2,6 +2,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -80,7 +81,24 @@ class WorkflowEdgeCreate(BaseModel):
class WorkflowEdge(WorkflowEdgeCreate):
pass
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
@@ -102,7 +120,25 @@ class WorkflowNodeCreate(BaseModel):
class WorkflowNode(WorkflowNodeCreate):
pass
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
@@ -113,10 +149,26 @@ class WorkflowCreate(BaseModel):
class Workflow(WorkflowCreate):
id: Optional[str] = None
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
@@ -157,7 +209,7 @@ class WorkflowRunCreate(BaseModel):
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = None
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
user: Optional[str] = None
status: ExecutionStatus = ExecutionStatus.PENDING
@@ -166,3 +218,25 @@ class WorkflowRun(BaseModel):
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
doc = {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}
if self.user:
doc["user"] = self.user
doc["user_id"] = self.user
return doc

View File

@@ -200,9 +200,6 @@ class WorkflowEngine:
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.sources:
self._retrieve_node_sources(node_config)
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
@@ -211,26 +208,15 @@ class WorkflowEngine:
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
# Inherit BYOM scope from parent agent so owner-stored BYOM
# resolves on shared workflows.
node_user_id = getattr(self.agent, "model_user_id", None) or (
self.agent.decoded_token.get("sub")
if isinstance(self.agent.decoded_token, dict)
else None
)
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(
node_model_id or "", user_id=node_user_id
)
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(
node_model_id, user_id=node_user_id
)
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
@@ -243,7 +229,6 @@ class WorkflowEngine:
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"model_user_id": getattr(self.agent, "model_user_id", None),
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
@@ -470,29 +455,6 @@ class WorkflowEngine:
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
"""Retrieve documents from the node's sources for template resolution."""
from application.retriever.retriever_creator import RetrieverCreator
query = self.state.get("query", "")
if not query:
return
try:
retriever = RetrieverCreator.create_retriever(
node_config.retriever or "classic",
source={"active_docs": node_config.sources},
chat_history=[],
prompt="",
chunks=int(node_config.chunks) if node_config.chunks else 2,
decoded_token=self.agent.decoded_token,
)
docs = retriever.search(query)
if docs:
self.agent.retrieved_docs = docs
except Exception:
logger.exception("Failed to retrieve docs for workflow node")
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(

View File

@@ -1,4 +1,4 @@
"""0001 initial schema — consolidated baseline for user-data tables.
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:
@@ -167,19 +167,14 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE user_tools (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
custom_name TEXT,
display_name TEXT,
description TEXT,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
status BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
custom_name TEXT,
display_name TEXT,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -193,8 +188,7 @@ def upgrade() -> None:
agent_id UUID,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
generated_tokens INTEGER NOT NULL DEFAULT 0,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -210,8 +204,7 @@ def upgrade() -> None:
user_id TEXT,
endpoint TEXT,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
data JSONB,
mongo_id TEXT
data JSONB
);
"""
)
@@ -227,8 +220,7 @@ def upgrade() -> None:
api_key TEXT,
query TEXT,
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -236,14 +228,12 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE agent_folders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -251,24 +241,13 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE sources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
language TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
model TEXT,
type TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
retriever TEXT,
sync_frequency TEXT,
tokens TEXT,
file_path TEXT,
remote_data JSONB,
directory_structure JSONB,
file_name_map JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
type TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -276,38 +255,33 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE agents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
agent_type TEXT,
status TEXT NOT NULL,
key CITEXT UNIQUE,
image TEXT,
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
chunks INTEGER,
retriever TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
json_schema JSONB,
models JSONB,
default_model_id TEXT,
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
workflow_id UUID,
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
token_limit INTEGER,
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
request_limit INTEGER,
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
shared BOOLEAN NOT NULL DEFAULT false,
shared_token CITEXT UNIQUE,
shared_metadata JSONB,
incoming_webhook_token CITEXT UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
last_used_at TIMESTAMPTZ,
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
agent_type TEXT,
status TEXT NOT NULL,
key CITEXT UNIQUE,
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
chunks INTEGER,
retriever TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
json_schema JSONB,
models JSONB,
default_model_id TEXT,
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
token_limit INTEGER,
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
request_limit INTEGER,
shared BOOLEAN NOT NULL DEFAULT false,
incoming_webhook_token CITEXT UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
last_used_at TIMESTAMPTZ,
legacy_mongo_id TEXT
);
"""
)
@@ -325,11 +299,6 @@ def upgrade() -> None:
upload_path TEXT NOT NULL,
mime_type TEXT,
size BIGINT,
content TEXT,
token_count INTEGER,
openai_file_id TEXT,
google_file_uri TEXT,
metadata JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
@@ -344,7 +313,6 @@ def upgrade() -> None:
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
path TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
@@ -353,16 +321,13 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE todos (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
todo_id INTEGER,
title TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT false,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -370,15 +335,13 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE notes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -386,18 +349,12 @@ def upgrade() -> None:
op.execute(
"""
CREATE TABLE connector_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
provider TEXT NOT NULL,
server_url TEXT,
session_token TEXT UNIQUE,
user_email TEXT,
status TEXT,
token_info JSONB,
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
provider TEXT NOT NULL,
session_data JSONB NOT NULL,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
@@ -497,14 +454,6 @@ def upgrade() -> None:
);
"""
)
# Backfill the agents.workflow_id FK now that workflows exists.
# The column was created without a FK (forward reference to a table
# that hadn't been declared yet); add the constraint here so workflow
# deletion still cascades through to agent unset.
op.execute(
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
)
op.execute(
"""
@@ -590,26 +539,13 @@ def upgrade() -> None:
)
op.execute(
# MCP and OAuth connectors share the ``provider`` slot, so the
# dedup key is ``(user_id, server_url, provider)``: MCP rows
# differentiate by server_url (one per MCP server), OAuth rows
# have server_url = NULL and differentiate by provider alone.
# COALESCE lets NULL server_url participate in the constraint.
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
"CREATE UNIQUE INDEX connector_sessions_user_provider_uidx "
"ON connector_sessions (user_id, provider);"
)
op.execute(
"CREATE INDEX connector_sessions_expiry_idx "
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
)
op.execute(
"CREATE INDEX connector_sessions_server_url_idx "
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
@@ -651,10 +587,6 @@ def upgrade() -> None:
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
op.execute(
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
@@ -684,54 +616,20 @@ def upgrade() -> None:
)
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
op.execute(
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
op.execute(
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
op.execute(
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
)
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
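
The two `connector_sessions` unique indexes in the hunks above differ in how they treat a NULL `server_url`. The following is a minimal sketch of why the wider key wraps `server_url` in `COALESCE`. It assumes Python's built-in `sqlite3` as a stand-in for Postgres (chosen only so the example runs without a server), a table trimmed to three columns, and hypothetical sample values. Both engines treat NULLs as distinct in a unique index by default, so without `COALESCE` two OAuth rows for the same user and provider would both be accepted.

```python
import sqlite3


def build_db() -> sqlite3.Connection:
    """Create a trimmed connector_sessions table with the COALESCE key."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE connector_sessions ("
        " user_id TEXT NOT NULL,"
        " provider TEXT NOT NULL,"
        " server_url TEXT)"
    )
    # Same shape as connector_sessions_user_endpoint_uidx in the diff.
    conn.execute(
        "CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
        "ON connector_sessions (user_id, COALESCE(server_url, ''), provider)"
    )
    return conn


def try_insert(conn, user_id, provider, server_url) -> bool:
    """Return True if the row was stored, False if the unique index rejected it."""
    try:
        conn.execute(
            "INSERT INTO connector_sessions (user_id, provider, server_url) "
            "VALUES (?, ?, ?)",
            (user_id, provider, server_url),
        )
        return True
    except sqlite3.IntegrityError:
        return False
```

With this key, one user can hold several MCP sessions for different servers, but only one OAuth session (NULL `server_url`) per provider.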

@@ -1,37 +0,0 @@
"""0002 app_metadata — singleton key/value table for instance-wide state.
Used by the startup version-check client to persist the anonymous
instance UUID and a one-shot "notice shown" flag. Both values are tiny
plain-text strings; this is a deliberate generic-config table rather
than dedicated columns so future one-off settings (telemetry opt-in
timestamps, feature-flag overrides, etc.) don't each need their own
migration.
Revision ID: 0002_app_metadata
Revises: 0001_initial
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_app_metadata"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE app_metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS app_metadata;")

@@ -1,65 +0,0 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE user_custom_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
upstream_model_id TEXT NOT NULL,
display_name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
base_url TEXT NOT NULL,
api_key_encrypted TEXT NOT NULL,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
enabled BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"CREATE INDEX user_custom_models_user_id_idx "
"ON user_custom_models (user_id);"
)
# Mirror the project-wide invariants set up in 0001_initial:
# * user_id FK with ON DELETE RESTRICT (deferrable),
# * ensure_user_exists() trigger so the parent users row autocreates,
# * set_updated_at() trigger.
op.execute(
"ALTER TABLE user_custom_models "
"ADD CONSTRAINT user_custom_models_user_id_fk "
"FOREIGN KEY (user_id) REFERENCES users(user_id) "
"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
op.execute(
"CREATE TRIGGER user_custom_models_ensure_user "
"BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER user_custom_models_set_updated_at "
"BEFORE UPDATE ON user_custom_models "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS user_custom_models;")

@@ -1,217 +0,0 @@
"""0004 durability foundation — idempotency, tool-call log, ingest checkpoint.
Adds ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts``,
``ingest_chunk_progress``, and per-row status flags on
``conversation_messages`` and ``pending_tool_state``. Also adds
``token_usage.source`` and ``token_usage.request_id`` so per-channel
cost attribution (``agent_stream`` / ``title`` / ``compression`` /
``rag_condense`` / ``fallback``) is queryable and multi-call agent runs
can be DISTINCT-collapsed into a single user request for rate limiting.
Revision ID: 0004_durability_foundation
Revises: 0003_user_custom_models
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0004_durability_foundation"
down_revision: Union[str, None] = "0003_user_custom_models"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ------------------------------------------------------------------
# New tables
# ------------------------------------------------------------------
# ``attempt_count`` bounds the per-Celery-task idempotency wrapper's
# retry loop so a poison message can't run forever; default 0 means
# existing rows behave as if no attempts have run yet.
op.execute(
"""
CREATE TABLE task_dedup (
idempotency_key TEXT PRIMARY KEY,
task_name TEXT NOT NULL,
task_id TEXT NOT NULL,
result_json JSONB,
status TEXT NOT NULL
CHECK (status IN ('pending', 'completed', 'failed')),
attempt_count INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE webhook_dedup (
idempotency_key TEXT PRIMARY KEY,
agent_id UUID NOT NULL,
task_id TEXT NOT NULL,
response_json JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
# FK on ``message_id`` uses ``ON DELETE SET NULL`` so the journal row
# survives parent-message deletion (compliance / cost-attribution).
op.execute(
"""
CREATE TABLE tool_call_attempts (
call_id TEXT PRIMARY KEY,
message_id UUID
REFERENCES conversation_messages (id)
ON DELETE SET NULL,
tool_id UUID,
tool_name TEXT NOT NULL,
action_name TEXT NOT NULL,
arguments JSONB NOT NULL,
result JSONB,
error TEXT,
status TEXT NOT NULL
CHECK (status IN (
'proposed', 'executed', 'confirmed',
'compensated', 'failed'
)),
attempted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE ingest_chunk_progress (
source_id UUID PRIMARY KEY,
total_chunks INT NOT NULL,
embedded_chunks INT NOT NULL DEFAULT 0,
last_index INT NOT NULL DEFAULT -1,
last_updated TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
# ------------------------------------------------------------------
# Column additions on existing tables
# ------------------------------------------------------------------
# DEFAULT 'complete' backfills existing rows — they're already done.
op.execute(
"""
ALTER TABLE conversation_messages
ADD COLUMN status TEXT NOT NULL DEFAULT 'complete'
CHECK (status IN ('pending', 'streaming', 'complete', 'failed')),
ADD COLUMN request_id TEXT;
"""
)
op.execute(
"""
ALTER TABLE pending_tool_state
ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'
CHECK (status IN ('pending', 'resuming')),
ADD COLUMN resumed_at TIMESTAMPTZ;
"""
)
# Default ``agent_stream`` backfills historical rows under the
# assumption that they were written from the primary path: before this
# fix, the only writer was the error branch reading ``agent.llm``.
# ``request_id`` is the stream-scoped UUID stamped by the route on
# ``agent.llm`` so multi-tool agent runs (which produce N rows)
# collapse to one request via DISTINCT in ``count_in_range``.
# Side-channel sources (``title`` / ``compression`` / ``rag_condense``
# / ``fallback``) leave it NULL and are excluded from the request
# count by source filter.
op.execute(
"""
ALTER TABLE token_usage
ADD COLUMN source TEXT NOT NULL DEFAULT 'agent_stream',
ADD COLUMN request_id TEXT;
"""
)
# ------------------------------------------------------------------
# Indexes — partial where the predicate selects only non-terminal rows
# ------------------------------------------------------------------
op.execute(
"CREATE INDEX conversation_messages_pending_ts_idx "
"ON conversation_messages (timestamp) "
"WHERE status IN ('pending', 'streaming');"
)
op.execute(
"CREATE INDEX tool_call_attempts_pending_ts_idx "
"ON tool_call_attempts (attempted_at) "
"WHERE status IN ('proposed', 'executed');"
)
op.execute(
"CREATE INDEX tool_call_attempts_message_idx "
"ON tool_call_attempts (message_id) "
"WHERE message_id IS NOT NULL;"
)
op.execute(
"CREATE INDEX pending_tool_state_resuming_ts_idx "
"ON pending_tool_state (resumed_at) "
"WHERE status = 'resuming';"
)
op.execute(
"CREATE INDEX webhook_dedup_agent_idx "
"ON webhook_dedup (agent_id);"
)
op.execute(
"CREATE INDEX task_dedup_pending_attempts_idx "
"ON task_dedup (attempt_count) WHERE status = 'pending';"
)
# Cost-attribution dashboards filter ``token_usage`` by ``source`` and a
# ``timestamp`` range; index ``(source, timestamp)`` so they stay cheap.
op.execute(
"CREATE INDEX token_usage_source_ts_idx "
"ON token_usage (source, timestamp);"
)
# Partial index — only rows with a stamped request_id participate
# in the DISTINCT count. NULL rows fall through to the COUNT(*)
# branch in the repository query.
op.execute(
"CREATE INDEX token_usage_request_id_idx "
"ON token_usage (request_id) "
"WHERE request_id IS NOT NULL;"
)
op.execute(
"CREATE TRIGGER tool_call_attempts_set_updated_at "
"BEFORE UPDATE ON tool_call_attempts "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
# CASCADE so the downgrade stays safe if later migrations FK into these.
for table in (
"ingest_chunk_progress",
"tool_call_attempts",
"webhook_dedup",
"task_dedup",
):
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
op.execute(
"ALTER TABLE conversation_messages "
"DROP COLUMN IF EXISTS request_id, "
"DROP COLUMN IF EXISTS status;"
)
op.execute(
"ALTER TABLE pending_tool_state "
"DROP COLUMN IF EXISTS resumed_at, "
"DROP COLUMN IF EXISTS status;"
)
op.execute("DROP INDEX IF EXISTS token_usage_request_id_idx;")
op.execute("DROP INDEX IF EXISTS token_usage_source_ts_idx;")
op.execute(
"ALTER TABLE token_usage "
"DROP COLUMN IF EXISTS request_id, "
"DROP COLUMN IF EXISTS source;"
)
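
The comments in the 0004 migration describe how `request_id` lets a multi-call agent run count as one request for rate limiting. The real `count_in_range` query is not part of this diff, so the following is only a sketch of one plausible shape. It assumes `sqlite3` as a stand-in for Postgres, a trimmed `token_usage` table, no timestamp range, and a hypothetical helper name `count_requests`. Rows with a stamped `request_id` collapse via `DISTINCT`, rows with a NULL `request_id` count individually, and side-channel sources are filtered out.

```python
import sqlite3

# Hypothetical query shape; the repository's actual SQL is not shown in the diff.
COUNT_REQUESTS_SQL = """
SELECT COUNT(DISTINCT request_id)
     + COUNT(CASE WHEN request_id IS NULL THEN 1 END)
FROM token_usage
WHERE api_key = ? AND source = 'agent_stream'
"""


def count_requests(rows, api_key) -> int:
    """Count user-visible requests for api_key from (api_key, source, request_id) rows."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE token_usage ("
        " api_key TEXT,"
        " source TEXT NOT NULL DEFAULT 'agent_stream',"
        " request_id TEXT)"
    )
    conn.executemany(
        "INSERT INTO token_usage (api_key, source, request_id) VALUES (?, ?, ?)",
        rows,
    )
    (count,) = conn.execute(COUNT_REQUESTS_SQL, (api_key,)).fetchone()
    conn.close()
    return count
```

Three rows sharing one `request_id` therefore consume a single unit of the request limit, while a legacy row without a `request_id` still counts once.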

@@ -1,44 +0,0 @@
"""0005 ingest_chunk_progress.attempt_id — per-attempt resume scoping.
Without this column, a completed checkpoint row poisoned every later
embed call on the same ``source_id``: a sync after an upload finished
read the upload's terminal ``last_index`` and either embedded zero
chunks (if new ``total_docs <= last_index + 1``) or stacked new chunks
on top of the old vectors (if ``total_docs > last_index + 1``).
``attempt_id`` is stamped from ``self.request.id`` (Celery's stable
task id, which survives ``acks_late`` retries of the same task but
differs across separate task invocations). The repository's
``init_progress`` upsert resets ``last_index`` / ``embedded_chunks``
when the incoming ``attempt_id`` differs from the stored one — so a
fresh sync starts from chunk 0 while a retry of the same task resumes
from the last checkpointed chunk.
Revision ID: 0005_ingest_attempt_id
Revises: 0004_durability_foundation
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0005_ingest_attempt_id"
down_revision: Union[str, None] = "0004_durability_foundation"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE ingest_chunk_progress
ADD COLUMN attempt_id TEXT;
"""
)
def downgrade() -> None:
op.execute(
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS attempt_id;"
)
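
The 0005 docstring describes an `init_progress` upsert that resets the checkpoint when the incoming `attempt_id` differs from the stored one. The repository code is not in this diff, so the following is a hedged sketch of that behaviour. It assumes `sqlite3` in place of Postgres, and the SQL text plus the helpers `make_db` and `record_chunk` are illustrative. SQLite's `IS` is used as the NULL-safe comparison; in Postgres the equivalent would be `IS NOT DISTINCT FROM`.

```python
import sqlite3

# Illustrative upsert: keep the checkpoint for the same attempt, reset otherwise.
INIT_PROGRESS_SQL = """
INSERT INTO ingest_chunk_progress
    (source_id, total_chunks, embedded_chunks, last_index, attempt_id)
VALUES (?, ?, 0, -1, ?)
ON CONFLICT (source_id) DO UPDATE SET
    total_chunks = excluded.total_chunks,
    embedded_chunks = CASE
        WHEN ingest_chunk_progress.attempt_id IS excluded.attempt_id
        THEN ingest_chunk_progress.embedded_chunks ELSE 0 END,
    last_index = CASE
        WHEN ingest_chunk_progress.attempt_id IS excluded.attempt_id
        THEN ingest_chunk_progress.last_index ELSE -1 END,
    attempt_id = excluded.attempt_id
"""


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE ingest_chunk_progress ("
        " source_id TEXT PRIMARY KEY,"
        " total_chunks INT NOT NULL,"
        " embedded_chunks INT NOT NULL DEFAULT 0,"
        " last_index INT NOT NULL DEFAULT -1,"
        " attempt_id TEXT)"
    )
    return conn


def init_progress(conn, source_id, total_chunks, attempt_id) -> int:
    """Upsert the checkpoint row and return the index to resume after."""
    conn.execute(INIT_PROGRESS_SQL, (source_id, total_chunks, attempt_id))
    (last_index,) = conn.execute(
        "SELECT last_index FROM ingest_chunk_progress WHERE source_id = ?",
        (source_id,),
    ).fetchone()
    return last_index


def record_chunk(conn, source_id, index) -> None:
    """Checkpoint one embedded chunk."""
    conn.execute(
        "UPDATE ingest_chunk_progress "
        "SET last_index = ?, embedded_chunks = embedded_chunks + 1 "
        "WHERE source_id = ?",
        (index, source_id),
    )
```

A retry that passes the same `attempt_id` resumes from the stored `last_index`; a new task invocation with a different id starts again from `-1`.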

@@ -1,57 +0,0 @@
"""0006 task_dedup lease columns — running-lease for in-flight tasks.
Without these, ``with_idempotency`` only short-circuits *completed*
rows. A late-ack redelivery (Redis ``visibility_timeout`` exceeded by a
long ingest, or a hung-but-alive worker) hands the same message to a
second worker; ``_claim_or_bump`` only bumped the attempt counter and
both workers ran the task body in parallel — duplicate vector writes,
duplicate token spend, duplicate webhook side effects.
``lease_owner_id`` + ``lease_expires_at`` turn that into an atomic
compare-and-swap. The wrapper claims a lease at entry, refreshes it via
a 30 s heartbeat thread, and finalises the row with
``status='completed'``, which makes the lease moot. A second worker
hitting the same key sees a fresh lease and calls
``self.retry(countdown=LEASE_TTL)`` instead of running the task body.
A crashed worker's lease expires after ``LEASE_TTL`` seconds, and the
next retry can claim it.
Revision ID: 0006_idempotency_lease
Revises: 0005_ingest_attempt_id
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0006_idempotency_lease"
down_revision: Union[str, None] = "0005_ingest_attempt_id"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE task_dedup
ADD COLUMN lease_owner_id TEXT,
ADD COLUMN lease_expires_at TIMESTAMPTZ;
"""
)
# Reconciler's stuck-pending sweep filters by
# ``(status='pending', lease_expires_at < now() - 60s, attempt_count >= 5)``.
# Partial index keeps the scan small even under heavy task throughput.
op.execute(
"CREATE INDEX task_dedup_pending_lease_idx "
"ON task_dedup (lease_expires_at) "
"WHERE status = 'pending';"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS task_dedup_pending_lease_idx;")
op.execute(
"ALTER TABLE task_dedup "
"DROP COLUMN IF EXISTS lease_expires_at, "
"DROP COLUMN IF EXISTS lease_owner_id;"
)
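
The lease described in the 0006 docstring can be illustrated with a single conditional `UPDATE`. This is a sketch under stated assumptions, not the project's `with_idempotency` wrapper: it uses `sqlite3` instead of Postgres, stores `lease_expires_at` as plain seconds rather than `TIMESTAMPTZ`, picks a hypothetical `LEASE_TTL` of 60, and takes `now` as an explicit argument so the behaviour is deterministic. The helper names `claim_lease` and `finalize` are illustrative.

```python
import sqlite3

LEASE_TTL = 60.0  # hypothetical value; the real constant is not shown in the diff

# The WHERE clause is the compare-and-swap: the row changes only if the lease
# is unclaimed, already ours (heartbeat refresh), or expired.
CLAIM_SQL = """
UPDATE task_dedup
SET lease_owner_id = :owner,
    lease_expires_at = :now + :ttl
WHERE idempotency_key = :key
  AND status = 'pending'
  AND (lease_owner_id IS NULL
       OR lease_owner_id = :owner
       OR lease_expires_at <= :now)
"""


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE task_dedup ("
        " idempotency_key TEXT PRIMARY KEY,"
        " status TEXT NOT NULL DEFAULT 'pending',"
        " lease_owner_id TEXT,"
        " lease_expires_at REAL)"
    )
    return conn


def claim_lease(conn, key, owner, now) -> bool:
    """Try to take or refresh the lease; True means this worker may run the task."""
    conn.execute(
        "INSERT OR IGNORE INTO task_dedup (idempotency_key) VALUES (?)", (key,)
    )
    cur = conn.execute(
        CLAIM_SQL, {"key": key, "owner": owner, "now": now, "ttl": LEASE_TTL}
    )
    return cur.rowcount == 1


def finalize(conn, key, owner) -> bool:
    """Mark the task completed; only the current lease owner may do so."""
    cur = conn.execute(
        "UPDATE task_dedup SET status = 'completed' "
        "WHERE idempotency_key = ? AND lease_owner_id = ?",
        (key, owner),
    )
    return cur.rowcount == 1
```

A worker that gets `False` back would retry later rather than run the task body, which is the behaviour the docstring attributes to `self.retry(countdown=LEASE_TTL)`.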

@@ -1,40 +0,0 @@
"""0007 message_events — durable journal of chat-stream events.
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
``(message_id, sequence_no)``, ``created_at`` indexed for retention
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
Revision ID: 0007_message_events
Revises: 0006_idempotency_lease
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0007_message_events"
down_revision: Union[str, None] = "0006_idempotency_lease"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE message_events (
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
sequence_no INTEGER NOT NULL,
event_type TEXT NOT NULL,
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (message_id, sequence_no)
);
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
op.execute("DROP TABLE IF EXISTS message_events;")
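
The composite key `(message_id, sequence_no)` in 0007 is what makes the snapshot half of the pattern cheap to read: a reconnecting client can ask for every journaled event after the last sequence number it saw. The following is a rough sketch of that read path. It assumes `sqlite3` in place of Postgres and JSON stored as text instead of `JSONB`; the helper names `journal_event` and `replay_after` are hypothetical and are not the project's journal API. The live "tail" half is outside the scope of this migration and is not modelled.

```python
import json
import sqlite3


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE message_events ("
        " message_id TEXT NOT NULL,"
        " sequence_no INTEGER NOT NULL,"
        " event_type TEXT NOT NULL,"
        " payload TEXT NOT NULL DEFAULT '{}',"
        " PRIMARY KEY (message_id, sequence_no))"
    )
    return conn


def journal_event(conn, message_id, sequence_no, event_type, payload) -> None:
    """Append one stream event to the journal."""
    conn.execute(
        "INSERT INTO message_events (message_id, sequence_no, event_type, payload) "
        "VALUES (?, ?, ?, ?)",
        (message_id, sequence_no, event_type, json.dumps(payload)),
    )


def replay_after(conn, message_id, last_seen):
    """Return (sequence_no, event_type, payload) for events newer than last_seen."""
    rows = conn.execute(
        "SELECT sequence_no, event_type, payload FROM message_events "
        "WHERE message_id = ? AND sequence_no > ? ORDER BY sequence_no",
        (message_id, last_seen),
    ).fetchall()
    return [(seq, kind, json.loads(body)) for seq, kind, body in rows]
```

Passing `-1` as `last_seen` replays the whole message; passing the last received sequence number returns only what the client missed.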

@@ -1,44 +0,0 @@
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
The reconciler's stalled-ingest sweep had no terminal write, so a dead
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
``init_progress`` resets it to ``'active'`` on reingest.
Revision ID: 0008_ingest_progress_status
Revises: 0007_message_events
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0008_ingest_progress_status"
down_revision: Union[str, None] = "0007_message_events"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
op.execute(
"""
ALTER TABLE ingest_chunk_progress
ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
CHECK (status IN ('active', 'stalled'));
"""
)
# Partial index for the reconciler's stalled-ingest sweep.
op.execute(
"CREATE INDEX ingest_chunk_progress_active_idx "
"ON ingest_chunk_progress (last_updated) "
"WHERE status = 'active';"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;")
op.execute(
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;"
)
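
The 0008 docstring says the `status` column lets the reconciler escalate a stalled checkpoint once and then stop re-selecting it. The reconciler itself is not in this diff, so the following is only a sketch of that idea. It assumes `sqlite3` in place of Postgres, `last_updated` stored as plain seconds, and an unfinished-ingest test of `embedded_chunks < total_chunks`, which is my assumption. The helper names `add_checkpoint` and `sweep_stalled` are hypothetical.

```python
import sqlite3
from typing import List


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE ingest_chunk_progress ("
        " source_id TEXT PRIMARY KEY,"
        " total_chunks INT NOT NULL,"
        " embedded_chunks INT NOT NULL DEFAULT 0,"
        " last_updated REAL NOT NULL,"
        " status TEXT NOT NULL DEFAULT 'active'"
        "   CHECK (status IN ('active', 'stalled')))"
    )
    return conn


def add_checkpoint(conn, source_id, total, embedded, last_updated) -> None:
    conn.execute(
        "INSERT INTO ingest_chunk_progress "
        "(source_id, total_chunks, embedded_chunks, last_updated) "
        "VALUES (?, ?, ?, ?)",
        (source_id, total, embedded, last_updated),
    )


def sweep_stalled(conn, now, stall_after) -> List[str]:
    """Flag active, unfinished checkpoints idle for longer than stall_after."""
    rows = conn.execute(
        "SELECT source_id FROM ingest_chunk_progress "
        "WHERE status = 'active' "
        "  AND embedded_chunks < total_chunks "
        "  AND last_updated < ? "
        "ORDER BY source_id",
        (now - stall_after,),
    ).fetchall()
    stalled = [row[0] for row in rows]
    # The terminal write: once flagged, the row no longer matches the sweep.
    conn.executemany(
        "UPDATE ingest_chunk_progress SET status = 'stalled' WHERE source_id = ?",
        [(source_id,) for source_id in stalled],
    )
    return stalled
```

Because the sweep filters on `status = 'active'`, a second pass over the same dead ingest returns nothing, so an alert would fire only once.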

@@ -1,83 +0,0 @@
"""0009 default chat tools — users.tool_preferences + memories.tool_id.
Adds ``users.tool_preferences`` JSONB and drops the
``memories.tool_id`` FK to ``user_tools`` (synthetic default-tool ids
have no ``user_tools`` row). Delete-cascade for real tools is kept via
an AFTER DELETE trigger on ``user_tools``. Idempotent both ways.
Revision ID: 0009_tool_preferences
Revises: 0008_ingest_progress_status
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0009_tool_preferences"
down_revision: Union[str, None] = "0008_ingest_progress_status"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE users
ADD COLUMN IF NOT EXISTS tool_preferences JSONB
NOT NULL DEFAULT '{}'::jsonb;
"""
)
op.execute(
"ALTER TABLE memories DROP CONSTRAINT IF EXISTS memories_tool_id_fkey;"
)
op.execute(
"""
CREATE OR REPLACE FUNCTION cleanup_tool_memories() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
DELETE FROM memories WHERE tool_id = OLD.id;
RETURN OLD;
END;
$$;
"""
)
# DROP-then-CREATE keeps this idempotent; CREATE OR REPLACE TRIGGER is
# only available on PostgreSQL 14 and later.
op.execute(
"DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
)
op.execute(
"CREATE TRIGGER user_tools_cleanup_memories "
"AFTER DELETE ON user_tools "
"FOR EACH ROW EXECUTE FUNCTION cleanup_tool_memories();"
)
def downgrade() -> None:
op.execute(
"DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
)
op.execute("DROP FUNCTION IF EXISTS cleanup_tool_memories();")
# DESTRUCTIVE: restoring the FK requires every memories.tool_id to
# reference a real user_tools row. Any memory written by a built-in
# default tool (synthetic uuid5 id, no user_tools row) is permanently
# DELETED here so the constraint can be re-created. Downgrading 0009
# therefore loses all built-in-memory-tool data — by necessity, since
# the restored schema cannot represent it.
op.execute(
"""
DELETE FROM memories
WHERE tool_id IS NOT NULL
AND tool_id NOT IN (SELECT id FROM user_tools);
"""
)
op.execute(
"""
ALTER TABLE memories
ADD CONSTRAINT memories_tool_id_fkey
FOREIGN KEY (tool_id) REFERENCES user_tools(id) ON DELETE CASCADE;
"""
)
op.execute("ALTER TABLE users DROP COLUMN IF EXISTS tool_preferences;")

@@ -1,147 +0,0 @@
"""0010 scheduler — schedules + schedule_runs tables.
Revision ID: 0010_schedules
Revises: 0009_tool_preferences
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0010_schedules"
down_revision: Union[str, None] = "0009_tool_preferences"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE schedules (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
trigger_type TEXT NOT NULL,
name TEXT,
instruction TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
cron TEXT,
run_at TIMESTAMPTZ,
timezone TEXT NOT NULL DEFAULT 'UTC',
next_run_at TIMESTAMPTZ,
last_run_at TIMESTAMPTZ,
end_at TIMESTAMPTZ,
tool_allowlist JSONB NOT NULL DEFAULT '[]'::jsonb,
model_id TEXT,
token_budget INTEGER,
origin_conversation_id UUID REFERENCES conversations(id) ON DELETE SET NULL,
created_via TEXT NOT NULL DEFAULT 'ui',
consecutive_failure_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT schedules_trigger_type_chk
CHECK (trigger_type IN ('once', 'recurring')),
CONSTRAINT schedules_status_chk
CHECK (status IN ('active', 'paused', 'completed', 'cancelled')),
CONSTRAINT schedules_created_via_chk
CHECK (created_via IN ('chat', 'ui')),
CONSTRAINT schedules_recurring_cron_chk
CHECK (trigger_type <> 'recurring' OR cron IS NOT NULL),
CONSTRAINT schedules_once_run_at_chk
CHECK (trigger_type <> 'once' OR run_at IS NOT NULL)
);
"""
)
op.execute(
"CREATE INDEX schedules_user_idx ON schedules (user_id);"
)
op.execute(
"CREATE INDEX schedules_agent_idx ON schedules (agent_id);"
)
# Dispatcher hot path: status='active' AND next_run_at <= now().
op.execute(
"CREATE INDEX schedules_due_idx "
"ON schedules (status, next_run_at) "
"WHERE status = 'active';"
)
op.execute(
"CREATE TRIGGER schedules_set_updated_at "
"BEFORE UPDATE ON schedules "
"FOR EACH ROW EXECUTE FUNCTION set_updated_at();"
)
op.execute(
"""
CREATE TABLE schedule_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
schedule_id UUID NOT NULL REFERENCES schedules(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
status TEXT NOT NULL DEFAULT 'pending',
scheduled_for TIMESTAMPTZ NOT NULL,
trigger_source TEXT NOT NULL DEFAULT 'cron',
started_at TIMESTAMPTZ,
finished_at TIMESTAMPTZ,
output TEXT,
output_truncated BOOLEAN NOT NULL DEFAULT false,
error TEXT,
error_type TEXT,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
generated_tokens INTEGER NOT NULL DEFAULT 0,
conversation_id UUID,
message_id UUID,
celery_task_id TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT schedule_runs_status_chk
CHECK (status IN (
'pending', 'running', 'success', 'failed', 'skipped', 'timeout'
)),
CONSTRAINT schedule_runs_trigger_source_chk
CHECK (trigger_source IN ('cron', 'manual')),
CONSTRAINT schedule_runs_error_type_chk
CHECK (error_type IS NULL OR error_type IN (
'auth_expired', 'tool_not_allowed', 'budget_exceeded',
'timeout', 'agent_error', 'internal', 'missed', 'overlap'
))
);
"""
)
# Dedup primitive: racing dispatchers hit ON CONFLICT on this index.
op.execute(
"CREATE UNIQUE INDEX schedule_runs_dedup_uidx "
"ON schedule_runs (schedule_id, scheduled_for);"
)
op.execute(
"CREATE INDEX schedule_runs_schedule_recent_idx "
"ON schedule_runs (schedule_id, scheduled_for DESC);"
)
op.execute(
"CREATE INDEX schedule_runs_user_idx ON schedule_runs (user_id);"
)
op.execute(
"CREATE INDEX schedule_runs_running_idx "
"ON schedule_runs (status, started_at) "
"WHERE status = 'running';"
)
op.execute(
"CREATE TRIGGER schedule_runs_set_updated_at "
"BEFORE UPDATE ON schedule_runs "
"FOR EACH ROW EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
# Drop triggers explicitly (grep-able) before CASCADE-dropping the tables.
op.execute(
"DROP TRIGGER IF EXISTS schedule_runs_set_updated_at ON schedule_runs;"
)
op.execute("DROP TABLE IF EXISTS schedule_runs CASCADE;")
op.execute(
"DROP TRIGGER IF EXISTS schedules_set_updated_at ON schedules;"
)
op.execute("DROP TABLE IF EXISTS schedules CASCADE;")
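
The 0010 comment calls `schedule_runs_dedup_uidx` the dedup primitive that racing dispatchers hit with `ON CONFLICT`. The following is a minimal sketch of that claim step. It assumes `sqlite3` in place of Postgres, a table trimmed to the two key columns, and a hypothetical helper name `claim_run`; the dispatcher's real insert is not part of this diff.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE schedule_runs ("
        " schedule_id TEXT NOT NULL,"
        " scheduled_for TEXT NOT NULL,"
        " status TEXT NOT NULL DEFAULT 'pending')"
    )
    # Same key as schedule_runs_dedup_uidx in the migration.
    conn.execute(
        "CREATE UNIQUE INDEX schedule_runs_dedup_uidx "
        "ON schedule_runs (schedule_id, scheduled_for)"
    )
    return conn


def claim_run(conn, schedule_id, scheduled_for) -> bool:
    """Insert the run row; True means this dispatcher won the slot."""
    cur = conn.execute(
        "INSERT INTO schedule_runs (schedule_id, scheduled_for) VALUES (?, ?) "
        "ON CONFLICT (schedule_id, scheduled_for) DO NOTHING",
        (schedule_id, scheduled_for),
    )
    # rowcount is 0 when the conflict clause suppressed the insert.
    return cur.rowcount == 1
```

Only the dispatcher that gets `True` would go on to enqueue the task, so a given `(schedule_id, scheduled_for)` slot runs at most once even when several dispatchers wake up together.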

@@ -1,53 +0,0 @@
"""0011 scheduler — make schedules.agent_id / schedule_runs.agent_id nullable.
Agentless schedules (created from agentless chats via the dual-registered
``scheduler`` default chat tool) carry ``agent_id IS NULL``. The existing
FK and its ``ON DELETE CASCADE`` semantics on ``agents(id)`` are
unaffected: Postgres only cascades when the parent row is deleted, and
rows with a NULL ``agent_id`` never match a parent.
Revision ID: 0011_schedules_nullable_agent
Revises: 0010_schedules
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0011_schedules_nullable_agent"
down_revision: Union[str, None] = "0010_schedules"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("ALTER TABLE schedules ALTER COLUMN agent_id DROP NOT NULL;")
op.execute("ALTER TABLE schedule_runs ALTER COLUMN agent_id DROP NOT NULL;")
def downgrade() -> None:
# Agentless rows have agent_id IS NULL by design, so restoring NOT NULL
# would require deleting or rewriting them. Fail loudly instead if any
# exist.
op.execute(
"""
DO $$
DECLARE
sched_nulls INTEGER;
run_nulls INTEGER;
BEGIN
SELECT count(*) INTO sched_nulls
FROM schedules WHERE agent_id IS NULL;
SELECT count(*) INTO run_nulls
FROM schedule_runs WHERE agent_id IS NULL;
IF sched_nulls > 0 OR run_nulls > 0 THEN
RAISE EXCEPTION
'Cannot downgrade 0011: agentless rows present '
'(schedules=%, schedule_runs=%). '
'Delete or reassign them before retrying.',
sched_nulls, run_nulls;
END IF;
END$$;
"""
)
op.execute("ALTER TABLE schedule_runs ALTER COLUMN agent_id SET NOT NULL;")
op.execute("ALTER TABLE schedules ALTER COLUMN agent_id SET NOT NULL;")

@@ -102,8 +102,6 @@ class AnswerResource(Resource, BaseAnswerResource):
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
"reserved_message_id": processor.reserved_message_id,
"request_id": processor.request_id,
},
)
else:

@@ -1,38 +1,23 @@
import datetime
import json
import logging
import time
import uuid
from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import (
ConversationService,
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import MessageUpdateOutcome
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.events.publisher import publish_user_event
from application.streaming.event_replay import format_sse_event
from application.streaming.message_journal import (
BatchedJournalWriter,
record_event,
)
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -45,6 +30,10 @@ class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
@@ -102,8 +91,8 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
@@ -124,32 +113,41 @@ class BaseAnswerResource:
)
token_limit = int(
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
request_limit = int(
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
end_date = datetime.datetime.now(datetime.timezone.utc)
token_usage_collection = self.db["token_usage"]
end_date = datetime.datetime.now()
start_date = end_date - datetime.timedelta(hours=24)
if limited_token_mode or limited_request_mode:
with db_readonly() as conn:
token_repo = TokenUsageRepository(conn)
if limited_token_mode:
daily_token_usage = token_repo.sum_tokens_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_repo.count_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_request_usage = 0
match_query = {
"timestamp": {"$gte": start_date, "$lte": end_date},
"api_key": api_key,
}
if limited_token_mode:
token_pipeline = [
{"$match": match_query},
{
"$group": {
"_id": None,
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
},
]
token_result = list(token_usage_collection.aggregate(token_pipeline))
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
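
Review note: the two spellings of the limit lookup in this hunk are not equivalent when an agent record stores an explicit null. The sketch below is a minimal illustration, assuming a hypothetical `DEFAULT_LIMITS` dict in place of `settings.DEFAULT_AGENT_LIMITS` and plain dicts in place of agent records; it is not the project's code.

```python
# Hypothetical stand-in for settings.DEFAULT_AGENT_LIMITS (value is illustrative).
DEFAULT_LIMITS = {"token_limit": 50000}


def limit_with_or(agent: dict) -> int:
    # Falls back when the key is missing OR the stored value is falsy (None, 0, "").
    return int(agent.get("token_limit") or DEFAULT_LIMITS["token_limit"])


def limit_with_get_default(agent: dict) -> int:
    # Falls back only when the key is missing; a stored None reaches int() and raises.
    return int(agent.get("token_limit", DEFAULT_LIMITS["token_limit"]))
```

The `or` form also replaces a stored `0` with the default, which may or may not be the intended meaning of a zero limit.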
@@ -189,7 +187,6 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
model_user_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
@@ -215,199 +212,13 @@ class BaseAnswerResource:
Yields:
Server-sent event strings
"""
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
query_metadata: Dict[str, Any] = {}
paused = False
# One id shared across the WAL row, primary LLM (token_usage
# attribution), the SSE event, and resumed continuations.
request_id = (
_continuation.get("request_id") if _continuation else None
) or str(uuid.uuid4())
# Reserve the placeholder row before the LLM call so a crash
# mid-stream still leaves the question queryable. Continuations
# reuse the original placeholder.
reserved_message_id: Optional[str] = None
wal_eligible = should_save_conversation and not _continuation
if wal_eligible:
try:
reservation = self.conversation_service.save_user_question(
conversation_id=conversation_id,
question=question,
decoded_token=decoded_token,
attachment_ids=attachment_ids,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
model_id=model_id or self.default_model_id,
request_id=request_id,
index=index,
)
conversation_id = reservation["conversation_id"]
reserved_message_id = reservation["message_id"]
except Exception as e:
logger.error(
f"Failed to reserve message row before stream: {e}",
exc_info=True,
)
elif _continuation and _continuation.get("reserved_message_id"):
reserved_message_id = _continuation["reserved_message_id"]
primary_llm = getattr(agent, "llm", None)
if primary_llm is not None:
primary_llm._request_id = request_id
# Flipped to ``streaming`` on first chunk; reconciler uses this
# to tell "never started" from "in flight".
streaming_marked = False
# Heartbeat goes into ``metadata.last_heartbeat_at`` (not
# ``updated_at``, which reconciler-side writes share) and uses
# ``time.monotonic`` so a blocked event loop can't fake fresh.
STREAM_HEARTBEAT_INTERVAL = 60
last_heartbeat_at = time.monotonic()
def _mark_streaming_once() -> None:
nonlocal streaming_marked, last_heartbeat_at
if streaming_marked or not reserved_message_id:
return
try:
self.conversation_service.update_message_status(
reserved_message_id, "streaming",
)
except Exception:
logger.exception(
"update_message_status streaming failed for %s",
reserved_message_id,
)
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"initial heartbeat seed failed for %s",
reserved_message_id,
)
streaming_marked = True
last_heartbeat_at = time.monotonic()
def _heartbeat_streaming() -> None:
nonlocal last_heartbeat_at
if not reserved_message_id or not streaming_marked:
return
now_mono = time.monotonic()
if now_mono - last_heartbeat_at < STREAM_HEARTBEAT_INTERVAL:
return
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"stream heartbeat update failed for %s",
reserved_message_id,
)
last_heartbeat_at = now_mono
# Correlates tool_call_attempts rows with this message.
if reserved_message_id and getattr(agent, "tool_executor", None):
try:
agent.tool_executor.message_id = reserved_message_id
except Exception:
logger.debug(
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
reserved_message_id,
)
# The reservation above may create the conversation row (first turn in
# a new chat). Propagate that fresh id to the tool_executor so tools
# that need a conversation home (e.g. ``scheduler`` in agentless chats)
# see it on the very first call instead of waiting for the next turn.
if conversation_id and getattr(agent, "tool_executor", None):
try:
agent.tool_executor.conversation_id = str(conversation_id)
except Exception:
logger.debug(
"Could not set tool_executor.conversation_id post-reserve",
)
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
# threaded through both the wire format (``id: <seq>\\n``) and
# the journal write so a reconnecting client can ``Last-Event-
# ID`` past anything they already saw. Continuations resume
# against the original ``reserved_message_id`` — seed the
# allocator from the journal's high-water mark so we don't
# collide on the duplicate-PK and silently lose every emit
# past the resume point.
sequence_no = -1
if _continuation and reserved_message_id:
try:
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
reserved_message_id
)
if latest is not None:
sequence_no = latest
except Exception:
logger.exception(
"Continuation seq seed lookup failed for message_id=%s; "
"falling back to seq=-1 (duplicate-PK collisions will "
"be swallowed)",
reserved_message_id,
)
# One batched journal writer per stream.
journal_writer: Optional[BatchedJournalWriter] = (
BatchedJournalWriter(reserved_message_id)
if reserved_message_id
else None
)
def _emit(payload: dict) -> str:
"""Format-and-journal one SSE event.
With a reserved ``message_id``, buffers into the journal and
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
legacy ``data: ...\\n\\n`` framing.
"""
nonlocal sequence_no
if not reserved_message_id or journal_writer is None:
return f"data: {json.dumps(payload)}\n\n"
sequence_no += 1
seq = sequence_no
event_type = (
payload.get("type", "data")
if isinstance(payload, dict)
else "data"
)
normalised = payload if isinstance(payload, dict) else {"value": payload}
journal_writer.record(seq, event_type, normalised)
return format_sse_event(normalised, seq)
try:
# Surface the placeholder id before any LLM tokens so a
# mid-handshake disconnect still has a row to tail-poll.
if reserved_message_id:
yield _emit(
{
"type": "message_id",
"message_id": reserved_message_id,
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"request_id": request_id,
}
)
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
if _continuation:
gen_iter = agent.gen_continuation(
@@ -420,24 +231,18 @@ class BaseAnswerResource:
gen_iter = agent.gen(query=question)
for line in gen_iter:
# Cheap closure check that only hits the DB when the
# heartbeat interval has elapsed.
_heartbeat_streaming()
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
_mark_streaming_once()
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
yield _emit(
{"type": "answer", "answer": line["answer"]}
)
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
_mark_streaming_once()
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
@@ -448,58 +253,54 @@ class BaseAnswerResource:
)
truncated_sources.append(truncated_source)
if truncated_sources:
yield _emit(
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
yield _emit({"type": "thought", "thought": line["thought"]})
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
yield _emit(
{
"type": "error",
"error": sanitize_api_error(
line.get("error", "An error occurred")
),
}
)
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
yield _emit(
{
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
)
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# First-turn pause needs a conversation row to attach to.
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
@@ -513,7 +314,6 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
conversation_id = (
self.conversation_service.save_conversation(
@@ -538,7 +338,6 @@ class BaseAnswerResource:
exc_info=True,
)
state_saved = False
if conversation_id:
try:
cont_service = ContinuationService()
@@ -551,9 +350,6 @@ class BaseAnswerResource:
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# BYOM scope; without it resume falls
# back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
@@ -562,87 +358,30 @@ class BaseAnswerResource:
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
# Reused on resume so the same WAL row
# is finalised and request_id stays
# consistent across token_usage rows.
"reserved_message_id": reserved_message_id,
"request_id": request_id,
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
state_saved = True
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
# Notify the user out-of-band so they can navigate
# back to the conversation and decide on the
# pending tool calls. Gated on ``state_saved``: a
# missing pending_tool_state row would 404 the
# resume endpoint, so an unfulfillable notification
# is worse than no notification.
user_id_for_event = (
decoded_token.get("sub") if decoded_token else None
)
if state_saved and user_id_for_event and conversation_id:
pending_calls = continuation.get(
"pending_tool_calls", []
) if continuation else []
# Trim each pending tool call to its identifying
# metadata so a tool with a multi-MB argument
# doesn't blow out the per-event payload size
# cap. The resume page fetches full args from
# ``pending_tool_state`` regardless.
pending_summaries = [
{
k: tc.get(k)
for k in (
"call_id",
"tool_name",
"action_name",
"name",
)
if isinstance(tc, dict) and tc.get(k) is not None
}
for tc in (pending_calls or [])
if isinstance(tc, dict)
]
publish_user_event(
user_id_for_event,
"tool.approval.required",
{
"conversation_id": str(conversation_id),
"message_id": reserved_message_id,
"pending_tool_calls": pending_summaries,
},
scope={
"kind": "conversation",
"id": str(conversation_id),
},
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
yield _emit({"type": "id", "id": str(conversation_id)})
yield _emit({"type": "end"})
# Drain the terminal ``end`` so a reconnecting client
# sees it on snapshot — same reason as the main exit.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Model-owner scope so title-gen uses owner's BYOM key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (decoded_token.get("sub") if decoded_token else None),
)
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
@@ -655,51 +394,27 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
# Title-gen only; agent stream tokens live on ``agent.llm``.
llm._token_usage_source = "title"
if should_save_conversation:
if reserved_message_id is not None:
self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="complete",
title_inputs={
"llm": llm,
"question": question,
"response": response_full,
"model_id": model_id or self.default_model_id,
"fallback_name": (
question[:50] if question else "New Conversation"
),
},
)
else:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
@@ -722,22 +437,9 @@ class BaseAnswerResource:
)
else:
conversation_id = None
# Resume finished cleanly; drop the continuation row.
# Crash-paths leave it ``resuming`` for the janitor to revert.
if _continuation and conversation_id:
try:
cont_service = ContinuationService()
cont_service.delete_state(
str(conversation_id),
decoded_token.get("sub", "local"),
)
except Exception as e:
logger.error(
f"Failed to delete continuation state on resume "
f"completion: {e}",
exc_info=True,
)
yield _emit({"type": "id", "id": str(conversation_id)})
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
@@ -765,130 +467,56 @@ class BaseAnswerResource:
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
try:
with db_session() as conn:
UserLogsRepository(conn).insert(
user_id=log_data.get("user"),
endpoint="stream_answer",
data=log_data,
)
except Exception as log_err:
logger.error(
f"Failed to persist stream_answer user log: {log_err}",
exc_info=True,
)
self.user_logs_collection.insert_one(log_data)
yield _emit({"type": "end"})
# Drain the journal buffer so the terminal ``end`` event is
# visible to any reconnecting client. Without this the
# client could snapshot up to the last flush boundary and
# then live-tail waiting for an ``end`` that's still
# sitting in memory.
if journal_writer is not None:
journal_writer.close()
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_logs import UserLogsRepository
dual_write(
UserLogsRepository,
lambda repo, d=log_data: repo.insert(
user_id=d.get("user"),
endpoint="stream_answer",
data=d,
),
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Drain any buffered events before the terminal one-shot
# ``record_event`` below — keeps the journal's seq order
# contiguous (buffered events ... terminal event). ``close``
# is idempotent; pairing it with ``flush`` matches the
# normal-exit and error branches so any future ``record()``
# past this point would log instead of silently buffering.
if journal_writer is not None:
journal_writer.flush()
journal_writer.close()
# Save partial response
# Whether the DB row was flipped to ``complete`` during this
# abort handler. Drives the choice of terminal journal event
# below: journal ``end`` only when the row actually matches,
# else journal ``error`` so a reconnecting client sees a
# failed terminal state instead of a blank "success".
finalized_complete = False
if should_save_conversation and response_full:
try:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Resolve under model-owner scope so shared-agent
# title-gen uses owner BYOM, not deployment default.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
llm._token_usage_source = "title"
if reserved_message_id is not None:
outcome = self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="complete",
title_inputs={
"llm": llm,
"question": question,
"response": response_full,
"model_id": model_id or self.default_model_id,
"fallback_name": (
question[:50] if question else "New Conversation"
),
},
)
# ``ALREADY_COMPLETE`` means the normal-path
# finalize at line 632 won the race: the DB row
# is already at ``complete`` and the reconnect
# journal should reflect that with ``end``,
# not a spurious ``error``.
finalized_complete = outcome in (
MessageUpdateOutcome.UPDATED,
MessageUpdateOutcome.ALREADY_COMPLETE,
)
else:
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# No journal row to gate, but flag the save as
# successful for symmetry with the WAL path.
finalized_complete = True
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
@@ -912,94 +540,16 @@ class BaseAnswerResource:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
# Journal a terminal event so reconnecting clients stop tailing;
# ``end`` only when the row is ``complete``, else ``error``.
if reserved_message_id is not None:
try:
sequence_no += 1
if finalized_complete:
# Match the wire shape ``_emit({"type": "end"})``
# uses on the normal path — the replay terminal
# check at ``event_replay._payload_is_terminal``
# reads ``payload.type``, and the frontend parses
# the same key off ``data:``.
record_event(
reserved_message_id,
sequence_no,
"end",
{"type": "end"},
)
else:
# Nothing was persisted under the complete status
# — mark the row failed so the reconciler doesn't
# need to sweep it, and journal an ``error`` so a
# reconnecting client surfaces the same failure
# the UI would show on a live error.
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=ConnectionError(
"client disconnected before response was persisted"
),
)
except Exception as fin_err:
logger.error(
f"Failed to mark aborted message failed: {fin_err}",
exc_info=True,
)
record_event(
reserved_message_id,
sequence_no,
"error",
{
"type": "error",
"error": "Stream aborted before any response was produced.",
"code": "client_disconnect",
},
)
except Exception as journal_err:
logger.error(
f"Failed to journal terminal event on abort: {journal_err}",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
if reserved_message_id is not None:
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=e,
)
except Exception as fin_err:
logger.error(
f"Failed to finalize errored message: {fin_err}",
exc_info=True,
)
yield _emit(
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
# Drain the terminal ``error`` event we just yielded so a
# reconnecting client sees it on snapshot.
if journal_writer is not None:
journal_writer.close()
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
@@ -1021,22 +571,8 @@ class BaseAnswerResource:
for line in stream:
try:
# Each chunk may carry an ``id: <seq>`` header before
# the ``data:`` line. Pull just the ``data:`` body so
# the JSON decode doesn't choke on the SSE framing.
event_data = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_data = raw[len("data:") :].lstrip()
break
if not event_data:
continue
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
# The ``message_id`` event is informational for the
# streaming consumer and has no synchronous-API field;
# skip it so the type-switch below doesn't KeyError.
if event.get("type") == "message_id":
continue
if event["type"] == "id":
conversation_id = event["id"]
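
Review note: the two parsing strategies in this hunk behave differently once a chunk carries an `id: <seq>` header. The sketch below is a standalone illustration of the line-scanning approach; the helper name `extract_sse_data` and the sample chunks are assumptions made for the example, not part of the codebase.

```python
import json
from typing import Optional


def extract_sse_data(chunk: str) -> Optional[dict]:
    """Return the decoded JSON body of the first ``data:`` line, or None."""
    for raw in chunk.split("\n"):
        if raw.startswith("data:"):
            body = raw[len("data:"):].lstrip()
            return json.loads(body) if body else None
    return None


# Legacy framing: a bare data line.
legacy = 'data: {"type": "answer", "answer": "hi"}\n\n'
# Journaled framing: an id header precedes the data line.
tagged = 'id: 7\ndata: {"type": "answer", "answer": "hi"}\n\n'

assert extract_sse_data(legacy) == extract_sse_data(tagged)
```

By contrast, `chunk.replace("data: ", "").strip()` leaves the `id: 7` line in front of the JSON for the tagged chunk, so `json.loads` would fail on it. It would also strip the substring `data: ` wherever it appears inside the payload text.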

@@ -1,135 +0,0 @@
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
Authenticates the caller, verifies ``message_id`` belongs to the user,
then hands off to ``build_message_event_stream`` for snapshot+tail.
"""
from __future__ import annotations
import logging
import re
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from sqlalchemy import text
from application.core.settings import settings
from application.storage.db.session import db_readonly
from application.streaming.event_replay import (
DEFAULT_KEEPALIVE_SECONDS,
DEFAULT_POLL_TIMEOUT_SECONDS,
build_message_event_stream,
)
logger = logging.getLogger(__name__)
messages_bp = Blueprint("message_stream", __name__)
# A message_id is the canonical UUID hex format. Reject anything else
# before the SQL layer so a malformed cookie can't surface as a 500.
_MESSAGE_ID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# ``sequence_no`` is a non-negative decimal integer. Anything else is
# corrupt client state — fall through to a fresh-replay cursor and let
# the snapshot reader catch the client up.
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
if raw is None:
return None
raw = raw.strip()
if not raw or not _SEQUENCE_NO_RE.match(raw):
return None
return int(raw)
def _user_owns_message(message_id: str, user_id: str) -> bool:
"""Return True iff ``message_id`` belongs to ``user_id``."""
try:
with db_readonly() as conn:
row = conn.execute(
text(
"""
SELECT 1 FROM conversation_messages
WHERE id = CAST(:id AS uuid)
AND user_id = :u
LIMIT 1
"""
),
{"id": message_id, "u": user_id},
).first()
return row is not None
except Exception:
logger.exception(
"Ownership lookup failed for message_id=%s user_id=%s",
message_id,
user_id,
)
return False
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
if not _MESSAGE_ID_RE.match(message_id):
return make_response(
jsonify({"success": False, "message": "Invalid message id"}),
400,
)
if not _user_owns_message(message_id, user_id):
# Don't disclose whether the row exists — a malicious caller
# gets the same 404 whether the id is bogus, taken by another
# user, or simply gone.
return make_response(
jsonify({"success": False, "message": "Not found"}),
404,
)
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalise_last_event_id(raw_cursor)
keepalive_seconds = float(
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
)
@stream_with_context
def generate() -> Iterator[str]:
try:
yield from build_message_event_stream(
message_id,
last_event_id=last_event_id,
keepalive_seconds=keepalive_seconds,
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
)
except GeneratorExit:
return
except Exception:
logger.exception(
"Reconnect stream crashed for message_id=%s user_id=%s",
message_id,
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
message_id,
user_id,
last_event_id if last_event_id is not None else "-",
)
return response
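
Review note: the cursor handling in the removed endpoint can be exercised in isolation. The sketch below restates the normalisation logic under public names (an assumption for readability) and shows which `Last-Event-ID` values it would accept.

```python
import re
from typing import Optional

# Non-negative decimal integer, as in the removed module.
SEQUENCE_NO_RE = re.compile(r"^\d+$")


def normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
    """Return the cursor as an int, or None to request a fresh replay."""
    if raw is None:
        return None
    raw = raw.strip()
    if not raw or not SEQUENCE_NO_RE.match(raw):
        return None
    return int(raw)


# Sample header values and how they are treated.
samples = [" 42 ", "0", "-1", "1.5", "abc", "", None]
cursors = [normalise_last_event_id(s) for s in samples]
```

Rejecting malformed cursors by returning `None` means a corrupt client value degrades to a full replay instead of an error response.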

@@ -1,21 +1,28 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents."""
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
search_model = answer_ns.model(
"SearchModel",
@@ -32,10 +39,116 @@ class SearchResource(Resource):
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if sources and isinstance(sources, list) and len(sources) > 0:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
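Review note: a minimal sketch of the merge step in `_search_vectorstores`, assuming the vectorstore hits are already fetched as plain dicts. The function names are hypothetical, and the title handling is simplified (the `post_title` fallback and non-string titles are left out).

```python
from typing import Dict, Iterable, List


def per_source_budget(chunks: int, n_sources: int) -> int:
    """Split the requested chunk count across sources, at least one each."""
    return max(1, chunks // n_sources)


def merge_results(per_source: Dict[str, Iterable[dict]], chunks: int) -> List[dict]:
    """Merge per-source hits, dropping near-duplicates, capped at `chunks`."""
    results: List[dict] = []
    seen = set()
    for source_id, docs in per_source.items():
        for doc in docs:
            if len(results) >= chunks:
                break
            text = doc.get("text", doc.get("page_content", ""))
            # Two hits count as duplicates when their first 200 characters match.
            key = hash(text[:200])
            if key in seen:
                continue
            seen.add(key)
            metadata = doc.get("metadata", {})
            title = metadata.get("title", "")
            # Keep only the last path segment, else fall back to filename/content.
            title = title.split("/")[-1] if title else metadata.get("filename", text[:50] + "...")
            results.append(
                {"text": text, "title": title, "source": metadata.get("source", source_id)}
            )
        if len(results) >= chunks:
            break
    return results
```

Because the key is built from a 200-character prefix, two long chunks that share an opening paragraph are treated as one result; that seems to be the intended trade-off, but it is worth knowing when results look sparse.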
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json() or {}
data = request.get_json()
question = data.get("question")
api_key = data.get("api_key")
@@ -43,13 +156,31 @@ class SearchResource(Resource):
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
try:
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response({"error": "Search failed"}, 500)

View File

@@ -109,14 +109,11 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
"reserved_message_id": processor.reserved_message_id,
"request_id": processor.request_id,
},
),
mimetype="text/event-stream",
@@ -148,7 +145,6 @@ class StreamResource(Resource, BaseAnswerResource):
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
),
mimetype="text/event-stream",
)

View File

@@ -49,7 +49,6 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
@@ -58,18 +57,16 @@ class CompressionOrchestrator:
Args:
conversation_id: Conversation ID
user_id: Caller's user id — used for conversation access checks
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
model_user_id: BYOM-resolution scope (model owner); defaults
to ``user_id`` for built-in / caller-owned models.
Returns:
CompressionResult with summary and recent queries
"""
try:
# Conversation row is owned by the caller, not the model owner.
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
@@ -80,14 +77,9 @@ class CompressionOrchestrator:
)
return CompressionResult.failure("Conversation not found")
# Use model-owner scope so per-user BYOM context windows
# (e.g. 8k) compute the threshold against the right limit.
registry_user_id = model_user_id or user_id
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation,
model_id,
current_query_tokens,
user_id=registry_user_id,
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
@@ -95,12 +87,7 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
@@ -115,8 +102,6 @@ class CompressionOrchestrator:
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
user_id: Optional[str] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform the actual compression operation.
@@ -126,8 +111,6 @@ class CompressionOrchestrator:
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
user_id: Caller's id (for conversation reload after compression)
model_user_id: BYOM-resolution scope (model owner)
Returns:
CompressionResult
@@ -140,17 +123,11 @@ class CompressionOrchestrator:
else model_id
)
# Use model-owner scope so provider/api_key resolves to the
# owner's BYOM record (shared-agent dispatch).
caller_user_id = user_id
if caller_user_id is None and isinstance(decoded_token, dict):
caller_user_id = decoded_token.get("sub")
registry_user_id = model_user_id or caller_user_id
provider = get_provider_from_model_id(
compression_model, user_id=registry_user_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
@@ -158,11 +135,7 @@ class CompressionOrchestrator:
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Side-channel LLM tag — distinguishes compression rows
# from primary stream rows for cost-attribution dashboards.
compression_llm._token_usage_source = "compression"
# Create compression service with DB update capability
compression_service = CompressionService(
@@ -194,12 +167,9 @@ class CompressionOrchestrator:
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload under caller (conversation is owned by caller).
reload_user_id = caller_user_id
if reload_user_id is None and isinstance(decoded_token, dict):
reload_user_id = decoded_token.get("sub")
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=reload_user_id
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
@@ -222,21 +192,16 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: Caller's user id — used for conversation access checks
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
model_user_id: BYOM-resolution scope (model owner). For
shared-agent dispatch this is the agent owner; defaults
to ``user_id`` so built-in / caller-owned models are
unaffected.
Returns:
CompressionResult
@@ -258,12 +223,7 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:

View File

@@ -106,13 +106,8 @@ class CompressionService:
f"using model {self.model_id}"
)
# See note in conversation_service.py: ``self.model_id`` is
# the registry id (UUID for BYOM); the LLM's own model_id is
# what the provider's API actually expects.
response = self.llm.gen(
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
max_tokens=4000,
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response

View File

@@ -30,7 +30,6 @@ class CompressionThresholdChecker:
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
user_id: str | None = None,
) -> bool:
"""
Determine if compression is needed.
@@ -39,8 +38,6 @@ class CompressionThresholdChecker:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if tokens >= threshold% of context window
@@ -51,7 +48,7 @@ class CompressionThresholdChecker:
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id, user_id=user_id)
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
@@ -76,24 +73,20 @@ class CompressionThresholdChecker:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
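Review note: the threshold rule in `should_compress` reduces to one comparison. A minimal sketch follows; the 0.8 default is an assumption for illustration, since the actual `threshold_percentage` value is not visible in this diff.

```python
def should_compress(
    history_tokens: int,
    query_tokens: int,
    context_limit: int,
    threshold_percentage: float = 0.8,  # assumed value, not taken from the diff
) -> bool:
    """True when history plus the current query reaches the threshold."""
    total_tokens = history_tokens + query_tokens
    threshold = int(context_limit * threshold_percentage)
    return total_tokens >= threshold
```

This is why the context limit lookup matters: the same 6,500-token conversation triggers compression against an 8,000-token window but not against a 128,000-token one.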
def check_message_tokens(
self, messages: list, model_id: str, user_id: str | None = None
) -> bool:
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id, user_id=user_id)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:

View File

@@ -12,12 +12,6 @@ logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
# Per-image token estimate. Provider tokenizers vary widely
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
# depends on resolution/detail we can't see here. Errs slightly high
# so the threshold check stays conservative.
_IMAGE_PART_TOKEN_ESTIMATE = 1500
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
@@ -35,36 +29,12 @@ class TokenCounter:
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, image parts, etc.)
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += TokenCounter._count_content_part(item)
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def _count_content_part(item: Dict) -> int:
# Image/file attachments are billed by the provider per image,
# not proportional to the inline bytes/base64 string.
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
# which trips spurious compression and overflows downstream
# input limits.
item_type = item.get("type")
if "files" in item:
files = item.get("files")
count = len(files) if isinstance(files, list) and files else 1
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
if "image_url" in item or item_type in {
"image",
"image_url",
"input_image",
"file",
}:
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
return num_tokens_from_string(str(item))
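Review note: a self-contained sketch of why `_count_content_part` exists. `approx_tokens` is a hypothetical stand-in for `num_tokens_from_string` (roughly four characters per token); the 1500 constant is the estimate from the diff.

```python
from typing import Dict

IMAGE_PART_TOKEN_ESTIMATE = 1500  # flat per-image estimate from the diff


def approx_tokens(text: str) -> int:
    """Hypothetical tokenizer stand-in: about four characters per token."""
    return max(1, len(text) // 4)


def count_content_part(item: Dict) -> int:
    """Charge attachments a flat estimate instead of tokenizing their bytes."""
    if "files" in item:
        files = item.get("files")
        count = len(files) if isinstance(files, list) and files else 1
        return IMAGE_PART_TOKEN_ESTIMATE * count
    if "image_url" in item or item.get("type") in {
        "image",
        "image_url",
        "input_image",
        "file",
    }:
        return IMAGE_PART_TOKEN_ESTIMATE
    # Everything else (text parts, tool calls) is counted from its string form.
    return approx_tokens(str(item))
```

Under these assumptions, a part holding a one-million-character base64 payload counts as 1,500 tokens here, while tokenizing `str(item)` would report roughly 250,000.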
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True

View File

@@ -1,39 +1,63 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to Postgres so the client can
the full execution state is persisted to MongoDB so the client can
resume later by sending tool_actions.
"""
import datetime
import logging
from typing import Any, Dict, List, Optional
from application.storage.db.base_repository import looks_like_uuid
from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
from application.storage.db.serialization import coerce_pg_native as _make_serializable
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
# Re-export so the existing tests at tests/api/answer/services/test_continuation_service_pg.py
# can keep importing ``_make_serializable`` from here.
__all__ = ["_make_serializable", "ContinuationService", "PENDING_STATE_TTL_SECONDS"]
def _make_serializable(obj: Any) -> Any:
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
if isinstance(obj, ObjectId):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
"""Manages pending tool-call state in Postgres."""
"""Manages pending tool-call state in MongoDB."""
def __init__(self):
# No-op constructor retained for call-site compatibility. State
# lives in Postgres now; each operation opens its own short-lived
# session rather than holding a connection on the service.
pass
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.collection = db["pending_tool_state"]
self._ensure_indexes()
def _ensure_indexes(self):
try:
self.collection.create_index(
"expires_at", expireAfterSeconds=0
)
self.collection.create_index(
[("conversation_id", 1), ("user", 1)], unique=True
)
except Exception:
# Indexes may already exist, or mongomock may not support TTL indexes
pass
def save_state(
self,
@@ -48,10 +72,6 @@ class ContinuationService:
) -> str:
"""Save execution state for later continuation.
``conversation_id`` may be a Postgres UUID or the legacy Mongo
``ObjectId`` string — the latter is resolved via
``conversations.legacy_mongo_id`` to find the matching row.
Args:
conversation_id: The conversation this state belongs to.
user: Owner user ID.
@@ -63,26 +83,45 @@ class ContinuationService:
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID (conversation_id as provided) of the saved state.
The string ID of the saved state document.
"""
with db_session() as conn:
now = datetime.datetime.now(datetime.timezone.utc)
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
doc = {
"conversation_id": conversation_id,
"user": user,
"messages": _make_serializable(messages),
"pending_tool_calls": _make_serializable(pending_tool_calls),
"tools_dict": _make_serializable(tools_dict),
"tool_schemas": _make_serializable(tool_schemas),
"agent_config": _make_serializable(agent_config),
"client_tools": _make_serializable(client_tools) if client_tools else None,
"created_at": now,
"expires_at": expires_at,
}
# Upsert — only one pending state per conversation per user
result = self.collection.replace_one(
{"conversation_id": conversation_id, "user": user},
doc,
upsert=True,
)
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
# Dual-write to Postgres — upsert against the same Mongo conversation
# by resolving its UUID via conversations.legacy_mongo_id.
def _pg_save(_: PendingToolStateRepository) -> None:
conn = _._conn # reuse the existing transaction
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
# would raise and poison the save. Surface the mismatch so
# the caller can decide (the stream loop in routes/base.py
# already wraps this in try/except).
raise ValueError(
f"Cannot save continuation state: conversation_id "
f"{conversation_id!r} is neither a PG UUID nor a "
f"backfilled legacy Mongo id."
)
PendingToolStateRepository(conn).save_state(
pg_conv_id,
if conv is None:
return
_.save_state(
conv["id"],
user,
messages=_make_serializable(messages),
pending_tool_calls=_make_serializable(pending_tool_calls),
@@ -92,11 +131,8 @@ class ContinuationService:
client_tools=_make_serializable(client_tools) if client_tools else None,
)
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
return conversation_id
dual_write(PendingToolStateRepository, _pg_save)
return state_id
def load_state(
self, conversation_id: str, user: str
@@ -106,58 +142,34 @@ class ContinuationService:
Returns:
The state dict, or None if no pending state exists.
"""
with db_readonly() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → no state can exist for it.
return None
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
doc = self.collection.find_one(
{"conversation_id": conversation_id, "user": user}
)
if not doc:
return None
doc["_id"] = str(doc["_id"])
return doc
def delete_state(self, conversation_id: str, user: str) -> bool:
"""Delete pending state after successful resumption.
Returns:
True if a row was deleted.
True if a document was deleted.
"""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → nothing to delete.
return False
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
if deleted:
result = self.collection.delete_one(
{"conversation_id": conversation_id, "user": user}
)
if result.deleted_count:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
return deleted
def mark_resuming(self, conversation_id: str, user: str) -> bool:
"""Flip the pending row to ``resuming`` so a crashed resume can be retried."""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
return False
flipped = PendingToolStateRepository(conn).mark_resuming(
pg_conv_id, user
)
if flipped:
logger.info(
f"Marked continuation state as resuming for conversation "
f"{conversation_id}"
)
return flipped
# Dual-write to Postgres — delete the same row.
def _pg_delete(repo: PendingToolStateRepository) -> None:
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
if conv is None:
return
repo.delete_state(conv["id"], user)
dual_write(PendingToolStateRepository, _pg_delete)
return result.deleted_count > 0
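Review note: the contract both storage variants share is "one pending state per (conversation, user), expiring after 30 minutes". A minimal in-memory sketch of that contract; the class is hypothetical and checks expiry on read, whereas a MongoDB TTL index removes expired documents in a periodic background pass, so expiry there is approximate.

```python
import datetime
from typing import Any, Dict, Optional, Tuple

PENDING_STATE_TTL_SECONDS = 30 * 60  # same 30-minute TTL as the diff


class InMemoryPendingState:
    """Hypothetical stand-in for the pending_tool_state store."""

    def __init__(self) -> None:
        self._rows: Dict[Tuple[str, str], Dict[str, Any]] = {}

    def save_state(
        self,
        conversation_id: str,
        user: str,
        state: Dict[str, Any],
        now: Optional[datetime.datetime] = None,
    ) -> str:
        now = now or datetime.datetime.now(datetime.timezone.utc)
        expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
        # Upsert: a second save for the same key replaces the first.
        self._rows[(conversation_id, user)] = {
            **state,
            "created_at": now,
            "expires_at": expires_at,
        }
        return conversation_id

    def load_state(
        self,
        conversation_id: str,
        user: str,
        now: Optional[datetime.datetime] = None,
    ) -> Optional[Dict[str, Any]]:
        now = now or datetime.datetime.now(datetime.timezone.utc)
        row = self._rows.get((conversation_id, user))
        if row is None or row["expires_at"] <= now:
            return None
        return row

    def delete_state(self, conversation_id: str, user: str) -> bool:
        return self._rows.pop((conversation_id, user), None) is not None
```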

View File

@@ -1,61 +1,46 @@
"""Conversation persistence service backed by Postgres.
Handles create / append / update / compression for conversations during
the answer-streaming path. Connections are opened per-operation rather
than held for the duration of a stream.
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import text as sql_text
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import (
ConversationsRepository,
MessageUpdateOutcome,
)
from application.storage.db.session import db_readonly, db_session
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from bson import ObjectId
logger = logging.getLogger(__name__)
# Shown to the user if the worker dies mid-stream and the response is never finalised.
TERMINATED_RESPONSE_PLACEHOLDER = (
"Response was terminated prior to completion, try regenerating."
)
class ConversationService:
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.conversations_collection = db["conversations"]
self.agents_collection = db["agents"]
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with owner-or-shared access control.
Returns a dict in the legacy Mongo shape — ``queries`` is a list
of message dicts (prompt/response/...) — for compatibility with
the streaming pipeline that consumes this shape.
"""
"""Retrieve a conversation with proper access control"""
if not conversation_id or not user_id:
return None
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
messages = repo.get_messages(str(conv["id"]))
conv["queries"] = messages
conv["_id"] = str(conv["id"])
return conv
conversation = self.conversations_collection.find_one(
{
"_id": ObjectId(conversation_id),
"$or": [{"user": user_id}, {"shared_with": user_id}],
}
)
if not conversation:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
conversation["_id"] = str(conversation["_id"])
return conversation
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
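Review note: both variants of `get_conversation` apply the same owner-or-shared rule. A minimal sketch of that predicate over a plain dict; the function name is hypothetical, and it mirrors the Mongo filter `{"$or": [{"user": user_id}, {"shared_with": user_id}]}`, where `shared_with` may be a single value or a list.

```python
from typing import Any, Dict, Optional


def can_access(conversation: Optional[Dict[str, Any]], user_id: Optional[str]) -> bool:
    """True when the user owns the conversation or it is shared with them."""
    if not conversation or not user_id:
        return False
    if conversation.get("user") == user_id:
        return True
    shared = conversation.get("shared_with") or []
    if isinstance(shared, str):
        shared = [shared]
    return user_id in shared
```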
@@ -79,11 +64,7 @@ class ConversationService:
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in Postgres.
Returns the string conversation id (PG UUID as string, or the
caller-provided id if it was already a UUID).
"""
"""Save or update a conversation in the database"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
@@ -91,47 +72,117 @@ class ConversationService:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# Trim huge inline source text to a reasonable max before persist.
# Trim each source's text to at most 1,000 characters before saving.
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
message_payload = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
}
if metadata:
message_payload["metadata"] = metadata
if conversation_id is not None and index is not None:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
# Update existing conversation with new query
result = self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": sources,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
**(
{f"queries.{index}.metadata": metadata}
if metadata
else {}
),
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
# Dual-write to Postgres: update the message at :index and
# truncate anything after it, mirroring Mongo's $set+$slice.
def _pg_update_at_index(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
repo.update_message_at(conv_pg_id, index, message_payload)
repo.truncate_after(conv_pg_id, index)
return
repo.update_message_at(conv["id"], index, {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
**({"metadata": metadata} if metadata else {}),
})
repo.truncate_after(conv["id"], index)
dual_write(ConversationsRepository, _pg_update_at_index)
return conversation_id
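Review note: the regenerate path does two things, overwrite the message at `index` and drop everything after it (`$set` followed by `$push` with an empty `$each` and `$slice: index + 1` on the Mongo side, `update_message_at` plus `truncate_after` on the Postgres side). A minimal list-based sketch of the combined effect, with a hypothetical function name:

```python
from typing import Any, Dict, List


def update_and_truncate(
    queries: List[Dict[str, Any]], index: int, payload: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Return a new list with queries[index] updated and later entries dropped."""
    if index < 0 or index >= len(queries):
        raise ValueError("Conversation not found or unauthorized")
    kept = queries[: index + 1]
    # Field-level update: keys missing from the payload keep their old values.
    kept[index] = {**kept[index], **payload}
    return kept
```

The sketch returns a copy rather than mutating in place, which is a choice made for the example only; the database operations in the diff modify the stored conversation.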
elif conversation_id:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
# Append new message to existing conversation
result = self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id), "user": user_id},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
**({"metadata": metadata} if metadata else {}),
}
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
# Dual-write to Postgres: append the same message.
def _pg_append(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
# append_message expects 'metadata' key either way; normalise.
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return
repo.append_message(conv["id"], {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
"metadata": metadata or {},
})
dual_write(ConversationsRepository, _pg_append)
return conversation_id
else:
# Create new conversation
messages_summary = [
{
"role": "system",
@@ -146,317 +197,125 @@ class ConversationService:
},
]
# ``model_id`` here is the registry id (a UUID for BYOM
# records). The LLM's own ``model_id`` is the upstream name
# LLMCreator resolved at construction time — that's what
# the provider's API expects. Built-ins are unaffected.
completion = llm.gen(
model=getattr(llm, "model_id", None) or model_id,
messages=messages_summary,
max_tokens=500,
model=model_id, messages=messages_summary, max_tokens=500
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
resolved_api_key: Optional[str] = None
resolved_agent_id: Optional[str] = None
if api_key:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if agent:
resolved_api_key = agent.get("key")
if agent_id:
resolved_agent_id = agent_id
query_doc = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
if metadata:
query_doc["metadata"] = metadata
with db_session() as conn:
repo = ConversationsRepository(conn)
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [query_doc],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
agent = self.agents_collection.find_one({"key": api_key})
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
inserted_id = str(result.inserted_id)
# Dual-write to Postgres: create the conversation row with
# legacy_mongo_id and append the first message.
def _pg_create(repo: ConversationsRepository) -> None:
conv = repo.create(
user_id,
completion,
agent_id=resolved_agent_id,
api_key=resolved_api_key,
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
shared_token=(
shared_token
if (resolved_agent_id and is_shared_usage)
else None
),
agent_id=conversation_data.get("agent_id"),
api_key=conversation_data.get("api_key"),
is_shared_usage=conversation_data.get("is_shared_usage", False),
shared_token=conversation_data.get("shared_token"),
legacy_mongo_id=inserted_id,
)
conv_pg_id = str(conv["id"])
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return conv_pg_id
repo.append_message(conv["id"], {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
"metadata": metadata or {},
})
def save_user_question(
self,
conversation_id: Optional[str],
question: str,
decoded_token: Dict[str, Any],
*,
attachment_ids: Optional[List[str]] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
request_id: Optional[str] = None,
status: str = "pending",
index: Optional[int] = None,
) -> Dict[str, str]:
"""Reserve the placeholder message row before the LLM call.
``index`` triggers regenerate semantics: messages at
``position >= index`` are truncated so the new placeholder
lands at ``position = index`` rather than appending.
Returns ``{"conversation_id", "message_id", "request_id"}``.
"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
request_id = request_id or str(uuid.uuid4())
resolved_api_key: Optional[str] = None
resolved_agent_id: Optional[str] = None
if api_key and not conversation_id:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if agent:
resolved_api_key = agent.get("key")
if agent_id:
resolved_agent_id = agent_id
with db_session() as conn:
repo = ConversationsRepository(conn)
if conversation_id:
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
# Regenerate / edit-prior-question: drop the message at
# ``index`` and everything after it so the new
# ``reserve_message`` lands at ``position=index`` rather
# than appending at the end of the conversation.
if isinstance(index, int) and index >= 0:
repo.truncate_after(conv_pg_id, keep_up_to=index - 1)
else:
fallback_name = (question[:50] if question else "New Conversation")
conv = repo.create(
user_id,
fallback_name,
agent_id=resolved_agent_id,
api_key=resolved_api_key,
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
shared_token=(
shared_token
if (resolved_agent_id and is_shared_usage)
else None
),
)
conv_pg_id = str(conv["id"])
row = repo.reserve_message(
conv_pg_id,
prompt=question,
placeholder_response=TERMINATED_RESPONSE_PLACEHOLDER,
request_id=request_id,
status=status,
attachments=attachment_ids,
model_id=model_id,
)
message_id = str(row["id"])
return {
"conversation_id": conv_pg_id,
"message_id": message_id,
"request_id": request_id,
}
def update_message_status(self, message_id: str, status: str) -> bool:
"""Cheap status-only transition (e.g. ``pending → streaming``)."""
if not message_id:
return False
with db_session() as conn:
return ConversationsRepository(conn).update_message_status(
message_id, status,
)
def heartbeat_message(self, message_id: str) -> bool:
"""Bump ``message_metadata.last_heartbeat_at`` so the reconciler's
staleness sweep counts the row as alive. No-ops on terminal rows.
"""
if not message_id:
return False
with db_session() as conn:
return ConversationsRepository(conn).heartbeat_message(message_id)
def finalize_message(
self,
message_id: str,
response: str,
*,
thought: str = "",
sources: Optional[List[Dict[str, Any]]] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
model_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
status: str = "complete",
error: Optional[BaseException] = None,
title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
"""Commit the response and tool_call confirms in one transaction.
The outcome propagates directly from ``update_message_by_id`` so
callers (notably the SSE abort handler) can tell a fresh
finalize from "the row was already terminal" — the latter must
still be treated as success when the prior state was
``complete``.
"""
if not message_id:
return MessageUpdateOutcome.INVALID
sources = sources or []
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
merged_metadata: Dict[str, Any] = dict(metadata or {})
if status == "failed" and error is not None:
merged_metadata.setdefault(
"error", f"{type(error).__name__}: {str(error)}"
)
update_fields: Dict[str, Any] = {
"response": response,
"status": status,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls or [],
"metadata": merged_metadata,
}
if model_id is not None:
update_fields["model_id"] = model_id
# Atomic message update + tool_call_attempts confirm; the
# ``only_if_non_terminal`` guard prevents a late stream from
# retracting a row the reconciler already escalated.
with db_session() as conn:
repo = ConversationsRepository(conn)
outcome = repo.update_message_by_id(
message_id, update_fields,
only_if_non_terminal=True,
)
if outcome is not MessageUpdateOutcome.UPDATED:
logger.warning(
f"finalize_message: no row updated for message_id={message_id} "
f"(outcome={outcome.value} — possibly already terminal)"
)
return outcome
repo.confirm_executed_tool_calls(message_id)
# Outside the txn — title-gen is a multi-second LLM round trip.
if title_inputs and status == "complete":
try:
with db_session() as conn:
self._maybe_generate_title(conn, message_id, title_inputs)
except Exception as e:
logger.error(
f"finalize_message title generation failed: {e}",
exc_info=True,
)
return MessageUpdateOutcome.UPDATED
def _maybe_generate_title(
self,
conn,
message_id: str,
title_inputs: Dict[str, Any],
) -> None:
"""Generate an LLM-summarised conversation name if one isn't set yet."""
llm = title_inputs.get("llm")
question = title_inputs.get("question") or ""
response = title_inputs.get("response") or ""
fallback_name = title_inputs.get("fallback_name") or question[:50]
if llm is None:
return
row = conn.execute(
sql_text(
"SELECT c.id, c.name FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
"WHERE m.id = CAST(:mid AS uuid)"
),
{"mid": message_id},
).fetchone()
if row is None:
return
conv_id, current_name = str(row[0]), row[1]
if current_name and current_name != fallback_name:
return
messages_summary = [
{
"role": "system",
"content": "You are a helpful assistant that creates concise conversation titles. "
"Summarize conversations in 3 words or less using the same language as the user.",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(
model=getattr(llm, "model_id", None) or title_inputs.get("model_id"),
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():
completion = fallback_name or "New Conversation"
conn.execute(
sql_text(
"UPDATE conversations SET name = :name, updated_at = now() "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conv_id, "name": completion.strip()},
)
dual_write(ConversationsRepository, _pg_create)
return inserted_id
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""Persist compression flags and append a compression point.
"""
Update conversation with compression metadata.
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
``compression_metadata`` but goes through the PG repo API.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
# conversation_id here comes from the streaming pipeline
# which has already resolved it; accept either UUID or
# legacy id for safety.
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
# Dual-write to Postgres: mirror $set + $push $slice.
def _pg_compression(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
if conv is None:
return
repo.set_compression_flags(
conv_pg_id,
conv["id"],
is_compressed=True,
last_compression_at=compression_metadata.get("timestamp"),
)
repo.append_compression_point(
conv_pg_id,
conv["id"],
compression_metadata,
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
dual_write(ConversationsRepository, _pg_compression)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
@@ -466,22 +325,39 @@ class ConversationService:
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""Append a synthetic compression summary message to the conversation."""
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get(
"timestamp", datetime.now(timezone.utc)
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
with db_session() as conn:
repo = ConversationsRepository(conn)
def _pg_append_summary(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.append_message(conv_pg_id, {
if conv is None:
return
repo.append_message(conv["id"], {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
@@ -491,9 +367,9 @@ class ConversationService:
"model_id": compression_metadata.get("model_used"),
"timestamp": timestamp,
})
logger.info(
f"Appended compression summary to conversation {conversation_id}"
)
dual_write(ConversationsRepository, _pg_append_summary)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
@@ -502,30 +378,20 @@ class ConversationService:
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""Fetch the stored compression metadata JSONB blob for a conversation."""
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
# Fallback to UUID lookup without user scoping — the
# caller already holds an authenticated conversation
# id from the streaming path. Gate on id shape so a
# non-UUID (legacy ObjectId that wasn't backfilled)
# doesn't reach CAST — the cast raises and spams the
# logs with a stack trace on every call.
if not looks_like_uuid(conversation_id):
return None
result = conn.execute(
sql_text(
"SELECT compression_metadata FROM conversations "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
row = result.fetchone()
return row[0] if row is not None else None
return conv.get("compression_metadata") if conv else None
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
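
The `$push` with `$each` and a negative `$slice` in `update_compression_metadata` keeps only the most recent compression points. A minimal sketch of that bounded append, assuming a plain Python list stands in for the stored array. `append_bounded` and `max_points` are hypothetical names used for illustration and do not appear in the repository.

```python
from typing import Any, Dict, List


def append_bounded(
    points: List[Dict[str, Any]],
    new_point: Dict[str, Any],
    max_points: int,
) -> List[Dict[str, Any]]:
    """Return a new list with new_point appended, keeping the last max_points items.

    Hypothetical helper: imitates MongoDB's $push with $each and a negative $slice.
    """
    if max_points <= 0:
        # A $slice of zero empties the array in MongoDB; mirror that here.
        return []
    # Append first, then keep only the tail, so the newest point always survives.
    return (points + [new_point])[-max_points:]


# Usage: the oldest point is dropped once the cap is reached.
history = [{"index": 1}, {"index": 2}, {"index": 3}]
history = append_bounded(history, {"index": 4}, max_points=3)
```

Returning a new list rather than mutating in place keeps the sketch easy to test. The real code delegates the trimming to the database (`$slice`) or to the repository's `append_compression_point(..., max_points=...)`.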

View File

@@ -5,8 +5,11 @@ import os
from pathlib import Path
from typing import Any, Dict, Optional, Set
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.agents.default_tools import synthesized_default_tools
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
@@ -17,17 +20,8 @@ from application.core.model_utils import (
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from sqlalchemy import text as sql_text
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
@@ -38,41 +32,28 @@ logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).
The ``prompts_collection`` parameter is retained for backwards
compatibility with call sites that still pass it positionally; it is
ignored post-cutover.
"""
del prompts_collection # unused — retained for call-site compatibility
# Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
# plain string ("default"/"creative"/legacy ObjectId). Normalise to str
# so both the preset lookup and the UUID-vs-legacy branching work.
# ``None`` / empty means "use the default prompt" — agents that never
# set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
if prompt_id is None or prompt_id == "":
prompt_id = "default"
elif not isinstance(prompt_id, str):
prompt_id = str(prompt_id)
Get a prompt by preset name or MongoDB ID
"""
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
# Maps for classic agent types
CLASSIC_PRESETS = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
# Agentic counterparts — same styles, but with search tool instructions
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}
preset_mapping = {
**CLASSIC_PRESETS,
**{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
}
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
@@ -82,18 +63,14 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
with db_readonly() as conn:
repo = PromptsRepository(conn)
prompt_doc = None
if looks_like_uuid(prompt_id):
prompt_doc = repo.get_for_rendering(prompt_id)
if prompt_doc is None:
prompt_doc = repo.get_by_legacy_id(prompt_id)
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
@@ -102,9 +79,12 @@ class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
# Legacy attribute retained as None for any external callers that
# introspect the processor; all DB access uses per-op connections.
self.prompts_collection = None
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
@@ -123,12 +103,6 @@ class StreamProcessor:
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
# BYOM-resolution scope, set by _validate_and_set_model.
self.model_user_id: Optional[str] = None
# WAL placeholder id pulled from continuation state on resume.
self.reserved_message_id: Optional[str] = None
# Carried through resumes so multi-pause runs keep one request_id.
self.request_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
@@ -199,23 +173,16 @@ class StreamProcessor:
for query in conversation.get("queries", [])
]
else:
# model_user_id keeps history trim aligned with the BYOM's
# actual context window instead of the default 128k.
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")),
model_id=self.model_id,
user_id=self.model_user_id,
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
try:
# initial_user_id for conversation access; model_user_id
# for BYOM context-window / provider lookups.
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_user_id=self.model_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
@@ -277,54 +244,29 @@ class StreamProcessor:
if not attachment_ids:
return []
attachments = []
try:
with db_readonly() as conn:
repo = AttachmentsRepository(conn)
for attachment_id in attachment_ids:
try:
attachment_doc = repo.get_any(str(attachment_id), user_id)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments
def _validate_and_set_model(self):
"""Pick model_id with agent authority on agent-bound chats."""
"""Validate and set model_id from request"""
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
# Caller picks from their own BYOM layer; agent defaults resolve
# under the owner's layer (shared agents have caller != owner).
caller_user_id = self.initial_user_id
owner_user_id = self.agent_config.get("user_id") or caller_user_id
# Agent-bound: agent's default_model_id wins, body's model_id is dropped.
agent_bound = self._agent_data is not None
if agent_bound:
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner_user_id
):
self.model_id = agent_default_model
self.model_user_id = owner_user_id
else:
self.model_id = get_default_model_id()
self.model_user_id = None
return
if requested_model:
if not validate_model_id(requested_model, user_id=caller_user_id):
if not validate_model_id(requested_model):
registry = ModelRegistry.get_instance()
available_models = [
m.id
for m in registry.get_enabled_models(user_id=caller_user_id)
]
available_models = [m.id for m in registry.get_enabled_models()]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
@@ -335,138 +277,86 @@ class StreamProcessor:
)
)
self.model_id = requested_model
self.model_user_id = caller_user_id
else:
self.model_id = get_default_model_id()
self.model_user_id = None
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
else:
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control."""
"""Get API key for agent with access control"""
if not agent_id:
return None, False, None
try:
with db_readonly() as conn:
# Lookup without user scoping — access control is done
# against ``user_id`` / ``shared_with`` / ``shared`` flags
# right below, matching the legacy Mongo semantics.
repo = AgentsRepository(conn)
agent = None
if looks_like_uuid(str(agent_id)):
result = conn.execute(
sql_text(
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
),
{"id": str(agent_id)},
)
row = result.fetchone()
if row is not None:
agent = row_to_dict(row)
if agent is None:
agent = repo.get_by_legacy_id(str(agent_id))
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found")
agent_owner = agent.get("user_id")
is_owner = agent_owner == user_id
is_shared_with_user = bool(agent.get("shared", False))
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
now = datetime.datetime.now(datetime.timezone.utc)
try:
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), agent_owner,
{"last_used_at": now},
)
except Exception:
logger.warning(
"Failed to update last_used_at for agent",
exc_info=True,
)
return (
str(agent["key"]) if agent.get("key") else None,
not is_owner,
agent.get("shared_token"),
)
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
"""Resolve agent metadata + the unioned source set for the given key."""
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
raise Exception("Invalid API Key, please generate a new key", 401)
sources_repo = SourcesRepository(conn)
# The repo dict uses "user_id" — the streaming path expects
# a "user" key (legacy Mongo shape) for identity propagation.
data: Dict[str, Any] = dict(agent)
data["user"] = agent.get("user_id")
# Active sources = primary + extras, primary first, deduplicated.
# ``_configure_source`` ignores an empty ``data["sources"]``,
# so the primary must appear in the union too — not only in
# the legacy ``data["source"]`` slot.
sources_list: list = []
seen: set = set()
owner = agent.get("user_id")
primary_id = agent.get("source_id")
# ``sources`` row may have NULL ``retriever``/``chunks`` —
# fall back to the agent's value (``dict.get`` returns None
# even when the key exists with value None).
if primary_id:
source_doc = sources_repo.get(str(primary_id), owner)
if source_doc:
sid = str(source_doc["id"])
data["source"] = sid
src_retriever = source_doc.get("retriever")
if src_retriever:
data["retriever"] = src_retriever
src_chunks = source_doc.get("chunks")
if src_chunks is not None:
data["chunks"] = src_chunks
sources_list.append(
{
"id": sid,
"retriever": src_retriever or "classic",
"chunks": (
src_chunks if src_chunks is not None
else data.get("chunks", "2")
),
}
)
seen.add(sid)
else:
data["source"] = None
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
for sid_raw in agent.get("extra_source_ids") or []:
if not sid_raw:
continue
source_doc = sources_repo.get(str(sid_raw), owner)
if not source_doc:
continue
sid = str(source_doc["id"])
if sid in seen:
continue
src_retriever = source_doc.get("retriever")
src_chunks = source_doc.get("chunks")
sources_list.append(
{
"id": sid,
"retriever": src_retriever or "classic",
"chunks": (
src_chunks if src_chunks is not None
else data.get("chunks", "2")
),
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
)
seen.add(sid)
data["sources"] = sources_list
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
@@ -579,10 +469,6 @@ class StreamProcessor:
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
# Owner identity — _validate_and_set_model reads this to
# resolve owner-stored BYOM default_model_id against the
# owner's per-user model layer rather than the caller's.
"user_id": self._agent_data.get("user"),
}
)
@@ -598,14 +484,8 @@ class StreamProcessor:
# Owner using their own agent
self.decoded_token = {"sub": self._agent_data.get("user")}
# PG row exposes the workflow as ``workflow_id`` (UUID column);
# legacy Mongo shape used the key ``workflow``. Accept either so
# API-key-invoked workflow agents bind correctly downstream.
wf_ref = self._agent_data.get("workflow") or self._agent_data.get(
"workflow_id"
)
if wf_ref:
self.agent_config["workflow"] = str(wf_ref)
if self._agent_data.get("workflow"):
self.agent_config["workflow"] = self._agent_data["workflow"]
self.agent_config["workflow_owner"] = self._agent_data.get("user")
else:
# No API key — default/workflow configuration
@@ -629,20 +509,15 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config; agent's values are authoritative when bound."""
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
# None for built-ins. Without ``user_id`` here, the doc budget
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
# the upstream context window for any small (e.g. 8k/32k) BYOM.
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, user_id=self.model_user_id
)
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
if self._agent_data is not None:
# Agent-bound: agent wins, body's retriever/chunks are dropped.
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
@@ -653,17 +528,18 @@ class StreamProcessor:
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
else:
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": retriever_name,
@@ -671,7 +547,7 @@ class StreamProcessor:
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval (agentless only)
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -685,7 +561,6 @@ class StreamProcessor:
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
model_user_id=self.model_user_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
@@ -745,27 +620,21 @@ class StreamProcessor:
filtering_enabled = required_tool_actions is not None
try:
user_tools_collection = self.db["user_tools"]
user_id = self.initial_user_id or "local"
agentless = self.agent_id is None
with db_readonly() as conn:
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
user_doc = (
UsersRepository(conn).get(user_id) if agentless else None
)
default_docs = (
synthesized_default_tools(user_doc) if agentless else []
user_tools = list(
user_tools_collection.find({"user": user_id, "status": True})
)
tool_docs = list(user_tools) + default_docs
if not tool_docs:
if not user_tools:
return None
tools_data = {}
for tool_doc in tool_docs:
for tool_doc in user_tools:
tool_name = tool_doc.get("name")
tool_id = str(tool_doc.get("_id") or tool_doc.get("id"))
is_default = bool(tool_doc.get("default"))
tool_id = str(tool_doc.get("_id"))
if filtering_enabled:
required_actions_by_name = required_tool_actions.get(
@@ -778,18 +647,11 @@ class StreamProcessor:
if not required_actions:
continue
else:
# No template names a default tool, so running its
# actions blind would only inject noise.
if is_default:
continue
required_actions = None
tool_data = self._fetch_tool_data(tool_doc, required_actions)
if tool_data:
# Defaults reachable by synthetic id only — the name
# key stays bound to an explicit row of the same name.
if not is_default:
tools_data[tool_name] = tool_data
tools_data[tool_name] = tool_data
tools_data[tool_id] = tool_data
return tools_data if tools_data else None
@@ -986,20 +848,6 @@ class StreamProcessor:
if not state:
raise ValueError("No pending tool state found for this conversation")
# Claim the resume up-front. ``mark_resuming`` only flips ``pending``
# → ``resuming``; if it returns False, another resume already
# claimed this row (status='resuming') — bail before any further
# LLM/tool work to avoid double-execution. The cleanup janitor
# reverts a stale ``resuming`` claim back to ``pending`` after the
# 10-minute grace window so the user can retry.
if not cont_service.mark_resuming(
conversation_id, self.initial_user_id,
):
raise ValueError(
"Resume already in progress for this conversation; "
"retry after the grace window if it stalls."
)
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
@@ -1007,11 +855,6 @@ class StreamProcessor:
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
# BYOM scope captured at initial dispatch. None for built-ins or
# caller-owned BYOM where decoded_token['sub'] is already the
# right scope; non-None for shared-agent owner BYOM where the
# caller's identity differs from the model owner's.
model_user_id = agent_config.get("model_user_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
@@ -1029,14 +872,12 @@ class StreamProcessor:
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.initial_user_id,
decoded_token=self.decoded_token,
agent_id=agent_id,
)
tool_executor.conversation_id = conversation_id
# Restore client tools so they stay available for subsequent LLM calls
@@ -1060,7 +901,6 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"model_user_id": model_user_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
@@ -1083,22 +923,12 @@ class StreamProcessor:
# Store config for the route layer
self.model_id = model_id
# Mirror ``model_user_id`` back onto the processor so the route
# layer (StreamResource) reads the owner scope captured at
# initial dispatch. Without this, ``processor.model_user_id``
# stays at the __init__ default (None) and complete_stream
# falls back to the caller's sub: the post-resume title-LLM
# save misses the owner's BYOM layer, and any second tool
# pause persists ``model_user_id=None`` — losing owner scope
# for every subsequent resume of this conversation.
self.model_user_id = model_user_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Reused on resume so the same WAL row gets finalised and
# request_id stays consistent across token_usage rows.
self.reserved_message_id = agent_config.get("reserved_message_id")
self.request_id = agent_config.get("request_id")
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions
@@ -1144,11 +974,8 @@ class StreamProcessor:
tools_data=tools_data,
)
# Use the user_id that resolved the model so owner-scoped BYOM
# records dispatch correctly on shared-agent requests.
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
provider = (
get_provider_from_model_id(self.model_id, user_id=model_user_id)
get_provider_from_model_id(self.model_id)
if self.model_id
else settings.LLM_PROVIDER
)
@@ -1159,10 +986,8 @@ class StreamProcessor:
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.agents.tool_executor import ToolExecutor
# Compute backup models: agent's configured models minus the active one.
# PG agents may carry an explicit ``models: NULL`` (not absent), so
# ``.get("models", [])`` isn't enough — coerce None → [].
agent_models = self.agent_config.get("models") or []
# Compute backup models: agent's configured models minus the active one
agent_models = self.agent_config.get("models", [])
backup_models = [m for m in agent_models if m != self.model_id]
llm = LLMCreator.create_llm(
@@ -1173,8 +998,6 @@ class StreamProcessor:
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
# Owner-scope on shared-agent BYOM dispatch.
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
@@ -1185,7 +1008,6 @@ class StreamProcessor:
user_api_key=self.agent_config["user_api_key"],
user=user,
decoded_token=self.decoded_token,
agent_id=self.agent_id,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools so they get merged in get_tools()
@@ -1193,11 +1015,11 @@ class StreamProcessor:
if client_tools:
tool_executor.client_tools = client_tools
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
@@ -1225,7 +1047,6 @@ class StreamProcessor:
"doc_token_limit", 50000
),
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,

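The `agent_models` lines in the hunk above turn on a detail of `dict.get`: the default is returned only when the key is absent, not when the key is present with a `None` value. A minimal sketch of the difference, assuming a plain dict stands in for `agent_config` (both helper names are hypothetical and not part of the codebase):

```python
def backup_models_with_default(agent_config, active_model):
    # Mirrors `.get("models", [])`: the default applies only if the key is missing.
    agent_models = agent_config.get("models", [])
    return [m for m in agent_models if m != active_model]


def backup_models_coerced(agent_config, active_model):
    # Mirrors `.get("models") or []`: an explicit None (e.g. a SQL NULL) also becomes [].
    agent_models = agent_config.get("models") or []
    return [m for m in agent_models if m != active_model]
```

With `{"models": None}`, the coerced variant yields an empty backup list, while the default-only variant raises `TypeError` because it tries to iterate over `None`.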
@@ -1,10 +1,12 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
from flask import (
Blueprint,
current_app,
@@ -15,18 +17,22 @@ from flask import (
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import (
ingest_connector_task,
)
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.parser.connectors.connector_creator import ConnectorCreator
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
sessions_collection = db["connector_sessions"]
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
@@ -62,14 +68,16 @@ class ConnectorAuth(Resource):
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
with db_session() as conn:
session_row = ConnectorSessionsRepository(conn).upsert(
user_id, provider, status="pending",
)
session_pg_id = str(session_row["id"])
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
state_dict = {
"provider": provider,
"object_id": session_pg_id,
"object_id": str(result.inserted_id)
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
@@ -152,25 +160,17 @@ class ConnectorsCallback(Resource):
sanitized_token_info = auth.sanitize_token_info(token_info)
# ``object_id`` in the OAuth state is the PG session row
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
# issued state). Try UUID update first; fall back to
# legacy id path.
patch = {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized",
}
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
if state_object_id:
value = str(state_object_id)
updated = False
if len(value) == 36 and "-" in value:
updated = repo.update(value, patch)
if not updated:
repo.update_by_legacy_id(value, patch)
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# Redirect to success page with session token and user email
return redirect(build_callback_redirect({
@@ -222,11 +222,8 @@ class ConnectorFiles(Resource):
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user:
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
@@ -291,11 +288,8 @@ class ConnectorValidateSession(Resource):
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user or not session.get("token_info"):
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session or "token_info" not in session:
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
@@ -306,11 +300,10 @@ class ConnectorValidateSession(Resource):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
row = repo.get_by_session_token(session_token)
if row:
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
)
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
@@ -354,11 +347,8 @@ class ConnectorDisconnect(Resource):
if session_token:
with db_session() as conn:
ConnectorSessionsRepository(conn).delete_by_session_token(
session_token,
)
sessions_collection.delete_one({"session_token": session_token})
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
@@ -395,28 +385,32 @@ class ConnectorSync(Resource):
}),
400
)
user_id = decoded_token.get('sub')
with db_readonly() as conn:
source = SourcesRepository(conn).get_any(source_id, user_id)
source = sources_collection.find_one({"_id": ObjectId(source_id)})
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
}),
404
)
# ``get_any`` already scopes by ``user_id``; an extra guard
# here would be dead code.
if source.get('user') != decoded_token.get('sub'):
return make_response(
jsonify({
"success": False,
"error": "Unauthorized access to source"
}),
403
)
remote_data = source.get('remote_data') or {}
if isinstance(remote_data, str):
try:
remote_data = json.loads(remote_data)
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
remote_data = {}
try:
if source.get('remote_data'):
remote_data = json.loads(source.get('remote_data'))
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
@@ -444,7 +438,7 @@ class ConnectorSync(Resource):
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=str(source.get('id') or source_id),
doc_id=source_id,
sync_frequency=source.get('sync_frequency', 'never')
)

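The connector hunks above pass the session row id through the OAuth `state` parameter as URL-safe base64 over JSON, and the callback then decides between the UUID path and the legacy id path with a shape check. A small sketch of that round trip, assuming nothing beyond the standard library (`decode_state` and `looks_like_pg_uuid` are hypothetical names; the callback's actual decoding is not shown in this diff):

```python
import base64
import json


def encode_state(provider, object_id):
    # Same encoding as in the diff: JSON text -> URL-safe base64 string.
    state_dict = {"provider": provider, "object_id": object_id}
    return base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()


def decode_state(state):
    # Hypothetical inverse of encode_state.
    return json.loads(base64.urlsafe_b64decode(state.encode()).decode())


def looks_like_pg_uuid(value):
    # Same shape check the callback uses before trying the UUID update:
    # canonical UUID strings are 36 characters long and contain dashes.
    return len(value) == 36 and "-" in value
```

A 24-character Mongo ObjectId string fails the shape check and so falls through to the legacy lookup. Note that base64 only encodes the state; it does not sign or encrypt it.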
View File

@@ -1,504 +0,0 @@
"""GET /api/events — user-scoped Server-Sent Events endpoint.
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import (
connection_counter_key,
replay_budget_key,
stream_id_compare,
stream_key,
topic_name,
)
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
events = Blueprint("event_stream", __name__)
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
# Anything else is a corrupted client cookie / IndexedDB residue and must
# not be passed to XRANGE — Redis would reject it and our truncation gate
# would silently fail.
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
# Only emitted at most once per process so a misconfigured deployment
# doesn't drown the logs.
_local_user_warned = False
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
"""Encode a payload as one SSE message terminated by a blank line.
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
so a stray CR in upstream content can't smuggle a premature line
boundary into the wire format.
"""
lines: list[str] = []
if event_id:
lines.append(f"id: {event_id}")
for line in _SSE_LINE_SPLIT.split(data):
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
def _decode(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return None
return str(value)
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
"""Return the id of the oldest entry still in the stream, or ``None``.
Used to detect ``Last-Event-ID`` having slid off the back of the
MAXLEN'd window.
"""
try:
info = redis_client.xinfo_stream(stream_key(user_id))
except Exception:
return None
if not isinstance(info, dict):
return None
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
# defence in depth in case ``decode_responses`` is ever flipped.
first_entry = info.get("first-entry") or info.get(b"first-entry")
if not first_entry:
return None
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
try:
return _decode(first_entry[0])
except Exception:
return None
def _allow_replay(
redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
"""Per-user sliding-window snapshot-replay budget.
Fails open on Redis errors or when the budget is disabled. Empty-backlog
no-cursor connects skip INCR so dev double-mounts don't trip 429.
"""
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
if budget <= 0:
return True
if redis_client is None:
return True
# Cheap pre-check: only INCR when we might actually replay. XLEN
# is one Redis op; the alternative (INCR every connect) is two
# ops AND wrongly counts no-op probes. The check is conservative:
# if ``last_event_id`` is set we always INCR, even if the cursor
# has already overtaken the latest entry — that case is rare and
# short-lived, and probing further would mean a redundant XRANGE.
if last_event_id is None:
try:
if int(redis_client.xlen(stream_key(user_id))) == 0:
return True
except Exception:
# XLEN probe failed; fall through to the INCR path so a
# transient Redis hiccup can't bypass the budget.
logger.debug(
"XLEN probe failed for replay budget check user=%s; "
"proceeding to INCR",
user_id,
)
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
key = replay_budget_key(user_id)
try:
used = int(redis_client.incr(key))
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
# the counter forever if INCR succeeds but EXPIRE raises on
# the seeding call. EXPIRE on an existing key resets the TTL
# to ``window`` — within ±1s of the per-window budget semantic.
redis_client.expire(key, window)
except Exception:
logger.debug(
"replay budget probe failed for user=%s; failing open",
user_id,
)
return True
return used <= budget
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
"""Validate the ``Last-Event-ID`` header / query param.
Returns the value unchanged when it parses as a Redis Streams id,
otherwise ``None`` — callers treat ``None`` as "client has nothing"
and replay from the start of the retained window. Invalid ids would
otherwise pass straight to XRANGE and surface as a quiet replay
failure plus broken truncation detection.
"""
if raw is None:
return None
raw = raw.strip()
if not raw or not _STREAM_ID_RE.match(raw):
return None
return raw
def _replay_backlog(
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
Capped at ``max_count`` rows; clients catch up across reconnects.
Parse failures are skipped; the Streams id is injected into the
envelope so replay matches live-tail shape.
"""
# Exclusive start: '(<id>' skips the already-delivered entry.
start = f"({last_event_id}" if last_event_id else "-"
try:
entries = redis_client.xrange(
stream_key(user_id), min=start, max="+", count=max_count
)
except Exception as exc:
logger.warning(
"xrange replay failed for user=%s last_id=%s err=%s",
user_id,
last_event_id or "-",
exc,
)
return
for entry_id, fields in entries:
entry_id_str = _decode(entry_id)
if not entry_id_str:
continue
# decode_responses=False on the cache client ⇒ field keys/values
# are bytes. The string-key fallback covers a future flip of that
# default without a forced refactor here.
raw_event = None
if isinstance(fields, dict):
raw_event = fields.get(b"event")
if raw_event is None:
raw_event = fields.get("event")
event_str = _decode(raw_event)
if not event_str:
continue
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
envelope["id"] = entry_id_str
event_str = json.dumps(envelope)
except Exception:
logger.debug(
"Replay envelope parse failed for entry %s; passing through raw",
entry_id_str,
)
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
def _truncation_notice_line(oldest_id: str) -> str:
"""SSE event the frontend can react to with a full-state refetch."""
return _format_sse(
json.dumps(
{
"type": "backlog.truncated",
"payload": {"oldest_retained_id": oldest_id},
}
)
)
@events.route("/api/events", methods=["GET"])
def stream_events() -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
# In dev deployments without AUTH_TYPE configured, every request
# resolves to user_id="local" and shares one stream. Surface this so
# an accidentally-multi-user dev box doesn't silently cross-stream.
global _local_user_warned
if user_id == "local" and not _local_user_warned:
logger.warning(
"SSE serving user_id='local' (AUTH_TYPE not set). "
"All clients on this deployment will share one event stream."
)
_local_user_warned = True
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalize_last_event_id(raw_last_event_id)
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
push_enabled = settings.ENABLE_SSE_PUSH
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
redis_client = get_redis_instance()
counter_key = connection_counter_key(user_id)
counted = False
if push_enabled and redis_client is not None and cap > 0:
try:
current = int(redis_client.incr(counter_key))
counted = True
except Exception:
current = 0
logger.debug(
"SSE connection counter INCR failed for user=%s", user_id
)
if counted:
# 1h safety TTL — orphaned counts from hard crashes self-heal.
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
try:
redis_client.expire(counter_key, 3600)
except Exception:
logger.debug(
"SSE connection counter EXPIRE failed for user=%s", user_id
)
if current > cap:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Too many concurrent SSE connections",
}
),
429,
)
# Replay budget is checked here, before the generator opens the
# stream, so a denial can surface as HTTP 429 instead of a silent
# snapshot skip. The earlier in-generator skip lost events between
# the client's cursor and the first live-tailed entry: the live
# tail still carried ``id:`` headers, the frontend advanced
# ``lastEventId`` to one of those ids, and the events in between
# were never reachable on the next reconnect. 429 keeps the
# cursor pinned and lets the frontend back off until the window
# slides (eventStreamClient.ts treats 429 as escalated backoff).
if push_enabled and redis_client is not None and not _allow_replay(
redis_client, user_id, last_event_id
):
if counted:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Replay budget exhausted",
}
),
429,
)
@stream_with_context
def generate() -> Iterator[str]:
connect_ts = time.monotonic()
replayed_count = 0
try:
# First frame primes intermediaries (Cloudflare, nginx) so they
# don't sit on a buffer waiting for body bytes.
yield ": connected\n\n"
if not push_enabled:
yield ": push_disabled\n\n"
return
replay_lines: list[str] = []
max_replayed_id: Optional[str] = None
replay_done = False
# If the client sent a malformed Last-Event-ID, surface the
# truncation notice synchronously *before* the subscribe
# loop. Buffering it into ``replay_lines`` would lose it
# when ``Topic.subscribe`` returns immediately (Redis down)
# — the loop body never runs, and the flush at line ~335
# never fires.
if last_event_id_invalid:
yield _truncation_notice_line("")
replayed_count += 1
def _on_subscribe_callback() -> None:
# Runs synchronously inside Topic.subscribe after the
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
# XADD captured by XRANGE *and* its PUBLISH buffered at
# the connection layer until we read it — closing the
# replay/subscribe race the design doc warns about.
#
# Truncation contract: ``backlog.truncated`` is emitted
# ONLY when the client's ``Last-Event-ID`` has slid off
# the MAXLEN'd window — that's the case where the
# journal is genuinely gone past the cursor and the
# frontend should clear its slice cursor and refetch
# state. Cap-hit skips the snapshot silently: the
# cursor advances via the per-entry ``id:`` headers
# and the frontend's slice keeps the latest id so the
# next reconnect resumes from there. Budget-exhausted
# never reaches this callback — the route 429s before
# opening the stream, keeping the cursor pinned.
# Conflating these with stale-cursor truncation would
# tell the client to clear its cursor and re-receive
# the same oldest-N entries on every reconnect —
# locking the user out of entries past N.
nonlocal max_replayed_id, replay_done
try:
if redis_client is None:
return
oldest = _oldest_retained_id(redis_client, user_id)
if (
last_event_id
and oldest
and stream_id_compare(last_event_id, oldest) < 0
):
# The Last-Event-ID has slid off the MAXLEN window.
# Tell the client so it can fetch full state.
replay_lines.append(_truncation_notice_line(oldest))
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
for entry_id, sse_line in _replay_backlog(
redis_client, user_id, last_event_id, replay_cap
):
replay_lines.append(sse_line)
max_replayed_id = entry_id
finally:
# Always flip the flag — even on partial-replay failure
# the outer loop must reach the flush step so we don't
# silently strand whatever entries did land.
replay_done = True
topic = Topic(topic_name(user_id))
last_keepalive = time.monotonic()
for payload in topic.subscribe(
on_subscribe=_on_subscribe_callback,
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
):
# Flush snapshot on the first iteration after the SUBSCRIBE
# callback ran. This runs at most once per connection.
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
now = time.monotonic()
if payload is None:
if now - last_keepalive >= keepalive_seconds:
yield ": keepalive\n\n"
last_keepalive = now
continue
event_str = _decode(payload) or ""
event_id: Optional[str] = None
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
candidate = envelope.get("id")
# Only trust ids that look like real Redis Streams
# ids (``ms`` or ``ms-seq``). A malformed or
# adversarial publisher could otherwise pin
# dedupe forever — a lex-greater bogus id would
# make every legitimate later id compare ``<=``
# and get dropped silently.
if isinstance(candidate, str) and _STREAM_ID_RE.match(
candidate
):
event_id = candidate
except Exception:
pass
# Dedupe: if this id was already covered by replay, drop it.
if (
event_id is not None
and max_replayed_id is not None
and stream_id_compare(event_id, max_replayed_id) <= 0
):
continue
yield _format_sse(event_str, event_id=event_id)
last_keepalive = now
# Topic.subscribe exited before the first yield (transient
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
# an immediate Redis-down return). The callback may already
# have populated the snapshot — flush it so the client gets
# the backlog instead of a silent drop. Safe no-op when the
# in-loop flush ran (it clear()'d the buffer) and when the
# callback never fired (replay_done stays False).
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
except GeneratorExit:
return
except Exception:
logger.exception(
"SSE event-stream generator crashed for user=%s", user_id
)
finally:
duration_s = time.monotonic() - connect_ts
logger.info(
"event.disconnect user=%s duration_s=%.1f replayed=%d",
user_id,
duration_s,
replayed_count,
)
if counted and redis_client is not None:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s on disconnect",
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"event.connect user=%s last_event_id=%s%s",
user_id,
last_event_id or "-",
" (rejected_invalid)" if last_event_id_invalid else "",
)
return response

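The event-stream module above relies on three small pieces: SSE framing, `Last-Event-ID` validation, and numeric comparison of Redis Streams ids. The sketch below restates the first two as they appear in the file and adds a hypothetical `compare_stream_ids`, since the real `stream_id_compare` is imported from `application.events.keys` and its body is not shown in this diff.

```python
import re

_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
_STREAM_ID = re.compile(r"^\d+(-\d+)?$")


def format_sse(data, event_id=None):
    # One SSE message: optional id line, one data line per input line, blank-line terminator.
    lines = []
    if event_id:
        lines.append(f"id: {event_id}")
    for line in _LINE_SPLIT.split(data):
        lines.append(f"data: {line}")
    return "\n".join(lines) + "\n\n"


def normalize_last_event_id(raw):
    # Accept only `ms` or `ms-seq`; anything else is treated as "no cursor".
    if raw is None:
        return None
    raw = raw.strip()
    if not raw or not _STREAM_ID.match(raw):
        return None
    return raw


def compare_stream_ids(a, b):
    # Hypothetical stand-in for stream_id_compare: compare (ms, seq) as integers,
    # treating a missing sequence part as 0. Returns -1, 0, or 1.
    def parse(value):
        ms, _, seq = value.partition("-")
        return int(ms), int(seq or 0)

    pa, pb = parse(a), parse(b)
    return (pa > pb) - (pa < pb)
```

The numeric comparison matters because plain string comparison orders `"9-0"` after `"10-0"`, which would break both the truncation check and the replay dedupe.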
@@ -3,16 +3,18 @@ import datetime
import json
from flask import Blueprint, request, send_from_directory, jsonify
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_session
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -54,21 +56,21 @@ def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = request.form["user"]
user = request.form["user"]
if "name" not in request.form:
return {"status": "no name"}
job_name = request.form["name"]
tokens = request.form["tokens"]
retriever = request.form["retriever"]
source_id = request.form["id"]
id = request.form["id"]
type = request.form["type"]
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
@@ -87,8 +89,8 @@ def upload_index_files():
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{source_id}"
index_base_path = f"indexes/{id}"
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
logger.error("No file_faiss part")
@@ -109,48 +111,46 @@ def upload_index_files():
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
now = datetime.datetime.now(datetime.timezone.utc)
update_fields = {
"name": job_name,
"type": type,
"language": job_name,
"date": now,
"model": settings.EMBEDDINGS_NAME,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
with db_session() as conn:
repo = SourcesRepository(conn)
existing = None
if looks_like_uuid(source_id):
existing = repo.get(source_id, user)
if existing is None:
existing = repo.get_by_legacy_id(source_id, user)
if existing is not None:
repo.update(str(existing["id"]), user, update_fields)
else:
repo.create(
job_name,
source_id=source_id if looks_like_uuid(source_id) else None,
user_id=user,
type=type,
tokens=tokens,
retriever=retriever,
remote_data=remote_data,
sync_frequency=sync_frequency,
file_path=file_path,
directory_structure=directory_structure,
file_name_map=file_name_map,
language=job_name,
model=settings.EMBEDDINGS_NAME,
date=now,
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
)
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{"$set": update_fields},
)
else:
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
return {"status": "ok"}

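The Postgres side of the hunk above resolves an incoming source id in two steps: try it as a UUID primary key, then fall back to the legacy Mongo id, and create a new row only if both miss. A rough sketch of that lookup order, assuming an in-memory list of rows in place of `SourcesRepository`; this `looks_like_uuid` is a hypothetical stand-in for the helper in `base_repository`, whose body is not shown here, and user scoping is omitted for brevity:

```python
import uuid


def looks_like_uuid(value):
    # Hypothetical check: accept only the canonical dashed UUID form.
    text = str(value)
    try:
        return str(uuid.UUID(text)) == text.lower()
    except ValueError:
        return False


def resolve_source(rows, source_id):
    # rows: list of dicts with "id" and "legacy_mongo_id" keys (in-memory stand-in).
    if looks_like_uuid(source_id):
        for row in rows:
            if row["id"] == source_id:
                return row
    # Fall back to the legacy id, as the diff does after a UUID miss.
    for row in rows:
        if row.get("legacy_mongo_id") == source_id:
            return row
    return None
```

A `None` result corresponds to the `repo.create(...)` branch in the diff, where a non-UUID id is stored as `legacy_mongo_id` rather than as the primary key.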
@@ -3,50 +3,29 @@ Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from sqlalchemy import text as _sql_text
from application.api import api
from application.storage.db.base_repository import looks_like_uuid
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
if not folder_id:
return None
if looks_like_uuid(folder_id):
row = repo.get(folder_id, user)
if row is not None:
return row
return repo.get_by_legacy_id(folder_id, user)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
def _serialize_folder(f: dict) -> dict:
created_at = f.get("created_at")
updated_at = f.get("updated_at")
return {
"id": str(f["id"]),
"name": f.get("name"),
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
}
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
@@ -56,9 +35,17 @@ class AgentFolders(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
folders = AgentFoldersRepository(conn).list_for_user(user)
result = [_serialize_folder(f) for f in folders]
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
@@ -82,34 +69,28 @@ class AgentFolders(Resource):
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id_input = data.get("parent_id")
description = data.get("description")
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
pg_parent_id = None
if parent_id_input:
parent = _resolve_folder_id(repo, parent_id_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
pg_parent_id = str(parent["id"])
folder = repo.create(
user, data["name"],
description=description,
parent_id=pg_parent_id,
)
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
dual_write(
AgentFoldersRepository,
lambda repo, u=user, n=data["name"]: repo.create(u, n),
)
return make_response(
jsonify(
{
"id": str(folder["id"]),
"name": folder["name"],
"parent_id": pg_parent_id,
}
),
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as err:
@@ -125,51 +106,26 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_rows = conn.execute(
_sql_text(
"SELECT id, name, description FROM agents "
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
"ORDER BY created_at DESC"
),
{"user_id": user, "fid": pg_folder_id},
).fetchall()
agents_list = [
{
"id": str(row._mapping["id"]),
"name": row._mapping["name"],
"description": row._mapping.get("description", "") or "",
}
for row in agents_rows
]
subfolders = folders_repo.list_children(pg_folder_id, user)
subfolders_list = [
{"id": str(sf["id"]), "name": sf["name"]}
for sf in subfolders
]
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
return make_response(
jsonify(
{
"id": pg_folder_id,
"name": folder["name"],
"parent_id": (
str(folder["parent_id"]) if folder.get("parent_id") else None
),
"agents": agents_list,
"subfolders": subfolders_list,
}
),
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
200,
)
except Exception as err:
@@ -186,57 +142,19 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
update_fields: dict = {}
if "name" in data:
update_fields["name"] = data["name"]
if "description" in data:
update_fields["description"] = data["description"]
if "parent_id" in data:
parent_input = data.get("parent_id")
if parent_input:
if parent_input == folder_id or parent_input == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
parent = _resolve_folder_id(repo, parent_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
if str(parent["id"]) == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
update_fields["parent_id"] = str(parent["id"])
else:
update_fields["parent_id"] = None
if update_fields:
repo.update(pg_folder_id, user, update_fields)
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
@@ -248,24 +166,19 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
# Clear folder assignments from agents; self-FK
# ``ON DELETE SET NULL`` handles child folders.
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
deleted = repo.delete(pg_folder_id, user)
if not deleted:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
dual_write(
AgentFoldersRepository,
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
)
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
@@ -292,29 +205,26 @@ class MoveAgentToFolder(Resource):
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id_input = data["agent_id"]
folder_id_input = data.get("folder_id")
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
try:
with db_session() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.get_any(agent_id_input, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}),
404,
)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
@@ -342,25 +252,25 @@ class BulkMoveAgents(Resource):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id_input = data.get("folder_id")
folder_id = data.get("folder_id")
try:
with db_session() as conn:
agents_repo = AgentsRepository(conn)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
for agent_id_input in agent_ids:
agent = agents_repo.get_any(agent_id_input, user)
if agent is not None:
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agents", err)

File diff suppressed because it is too large


@@ -3,17 +3,23 @@
import datetime
import secrets
from bson import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.core.settings import settings
from application.api.user.base import resolve_tool_details
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.api.user.base import (
agents_collection,
db,
ensure_user_doc,
resolve_tool_details,
user_tools_collection,
users_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
@@ -21,38 +27,6 @@ agents_sharing_ns = Namespace(
)
def _serialize_agent_basic(agent: dict) -> dict:
"""Shape a PG agent row into the API response dict."""
source_id = agent.get("source_id")
return {
"id": str(agent["id"]),
"user": agent.get("user_id", ""),
"name": agent.get("name", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"description": agent.get("description", ""),
"source": str(source_id) if source_id else "",
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
"retriever": agent.get("retriever", "classic") or "classic",
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
@api.doc(
@@ -69,33 +43,73 @@ class SharedAgent(Resource):
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
with db_readonly() as conn:
shared_agent = AgentsRepository(conn).find_by_shared_token(
shared_token,
)
query = {
"shared_publicly": True,
"shared_token": shared_token,
}
shared_agent = agents_collection.find_one(query)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["id"])
data = _serialize_agent_basic(shared_agent)
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"image": (
generate_image_url(shared_agent["image"])
if shared_agent.get("image")
else ""
),
"description": shared_agent.get("description", ""),
"source": (
str(source_doc["_id"])
if isinstance(shared_agent.get("source"), DBRef)
and (source_doc := db.dereference(shared_agent.get("source")))
else ""
),
"chunks": shared_agent.get("chunks", "0"),
"retriever": shared_agent.get("retriever", "classic"),
"prompt_id": shared_agent.get("prompt_id", "default"),
"tools": shared_agent.get("tools", []),
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"json_schema": shared_agent.get("json_schema"),
"limited_token_mode": shared_agent.get("limited_token_mode", False),
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": shared_agent.get("limited_request_mode", False),
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
if data["tools"]:
enriched_tools = []
for detail in data["tool_details"]:
enriched_tools.append(detail.get("name", ""))
for tool in data["tools"]:
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
owner_id = shared_agent.get("user_id")
owner_id = shared_agent.get("user")
if user_id != owner_id:
with db_session() as conn:
users_repo = UsersRepository(conn)
users_repo.upsert(user_id)
users_repo.add_shared(user_id, agent_id)
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -112,73 +126,55 @@ class SharedAgents(Resource):
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
with db_session() as conn:
users_repo = UsersRepository(conn)
user_doc = users_repo.upsert(user_id)
shared_with_ids = (
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
user_doc = ensure_user_doc(user_id)
shared_with_ids = user_doc.get("agent_preferences", {}).get(
"shared_with_me", []
)
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
shared_agents_cursor = agents_collection.find(
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
)
shared_agents = list(shared_agents_cursor)
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
if stale_ids:
users_collection.update_one(
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM agents "
"WHERE id = ANY(CAST(:ids AS uuid[])) "
"AND shared = true"
),
{"ids": uuid_ids},
)
shared_agents = [dict(row._mapping) for row in result.fetchall()]
else:
shared_agents = []
found_ids_set = {str(agent["id"]) for agent in shared_agents}
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
stale_ids.extend(non_uuid_ids)
if stale_ids:
users_repo.remove_shared_bulk(user_id, stale_ids)
pinned_ids = set(
user_doc.get("agent_preferences", {}).get("pinned", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = []
for agent in shared_agents:
agent_id_str = str(agent["id"])
list_shared_agents.append(
{
"id": agent_id_str,
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(
agent.get("tools", []) or []
),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"pinned": agent_id_str in pinned_ids,
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
)
list_shared_agents = [
{
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"pinned": str(agent["_id"]) in pinned_ids,
"shared": agent.get("shared_publicly", False),
"shared_token": agent.get("shared_token", ""),
"shared_metadata": agent.get("shared_metadata", {}),
}
for agent in shared_agents
]
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
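Review note: the Postgres side of this hunk splits `shared_with_me` into UUID-shaped ids and legacy leftovers, queries only the UUIDs, and prunes everything that was not found. A minimal sketch of that partitioning, assuming `looks_like_uuid` behaves roughly like a strict parse with the standard `uuid` module. Its real implementation in `application.storage.db.base_repository` is not shown in this diff, so `looks_like_uuid_sketch` and `partition_shared_ids` are hypothetical stand-ins.

```python
import uuid


def looks_like_uuid_sketch(value: object) -> bool:
    """Assumed stand-in for looks_like_uuid: True for canonical UUID strings."""
    if not isinstance(value, str):
        return False
    try:
        return str(uuid.UUID(value)) == value.lower()
    except ValueError:
        return False


def partition_shared_ids(shared_ids, found_ids):
    """Split ids into (queryable UUIDs, stale ids to prune).

    found_ids is the set of ids the agents query actually returned. In the
    diff the query runs between the two steps; it is passed in here to keep
    the sketch free of database access. Non-UUID ids (for example 24-hex
    ObjectId strings) are always treated as stale.
    """
    uuid_ids = [sid for sid in shared_ids if looks_like_uuid_sketch(sid)]
    non_uuid_ids = [sid for sid in shared_ids if not looks_like_uuid_sketch(sid)]
    stale = [sid for sid in uuid_ids if sid not in found_ids]
    stale.extend(non_uuid_ids)
    return uuid_ids, stale
```

Filtering before the query matters on the Postgres side because a non-UUID string in a `CAST(... AS uuid[])` parameter would make the whole statement fail rather than just miss a row.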
@@ -232,43 +228,44 @@ class ShareAgent(Resource):
),
400,
)
shared_token = None
try:
with db_session() as conn:
repo = AgentsRepository(conn)
agent = repo.get_any(agent_id, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
shared_token = secrets.token_urlsafe(32)
repo.update(
str(agent["id"]), user,
{
"shared": True,
"shared_token": shared_token,
try:
agent_oid = ObjectId(agent_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid agent ID"}), 400
)
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(datetime.timezone.utc),
}
shared_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{
"$set": {
"shared_publicly": shared,
"shared_metadata": shared_metadata,
},
)
else:
repo.update(
str(agent["id"]), user,
{
"shared": False,
"shared_token": None,
"shared_metadata": None,
},
)
"shared_token": shared_token,
}
},
)
else:
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{"$set": {"shared_publicly": shared, "shared_token": None}},
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)
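Review note: in the unshare branch of this hunk, `update_one` receives the `$unset` document as a third positional argument. In PyMongo the third positional parameter of `Collection.update_one` is `upsert`, so that document is not treated as part of the update. A minimal sketch of building a single update document that carries both operators. `build_share_update` is a hypothetical helper and is not part of the diff.

```python
import datetime
import secrets
from typing import Any, Optional


def build_share_update(
    shared: bool, username: str
) -> tuple[dict[str, Any], Optional[str]]:
    """Return (update_document, shared_token) for a share/unshare request."""
    if shared:
        token = secrets.token_urlsafe(32)
        update = {
            "$set": {
                "shared_publicly": True,
                "shared_token": token,
                "shared_metadata": {
                    "shared_by": username,
                    "shared_at": datetime.datetime.now(datetime.timezone.utc),
                },
            }
        }
        return update, token
    # Both operators live in one document: the second argument to update_one.
    update = {
        "$set": {"shared_publicly": False, "shared_token": None},
        "$unset": {"shared_metadata": ""},
    }
    return update, None
```

The result would then be passed as `agents_collection.update_one({"_id": agent_oid, "user": user}, update)`, leaving `upsert` at its default.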


@@ -1,20 +1,15 @@
"""Agent management webhook handlers."""
import secrets
import uuid
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.base import require_agent
from application.api.user.base import agents_collection, require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.session import db_readonly, db_session
agents_webhooks_ns = Namespace(
@@ -22,37 +17,6 @@ agents_webhooks_ns = Namespace(
)
_IDEMPOTENCY_KEY_MAX_LEN = 256
def _read_idempotency_key():
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
key = request.headers.get("Idempotency-Key")
if not key:
return None, None
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
return None, make_response(
jsonify(
{
"success": False,
"message": (
f"Idempotency-Key exceeds maximum length of "
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
),
}
),
400,
)
return key, None
def _scoped_idempotency_key(idempotency_key, scope):
"""``{scope}:{key}`` so different agents can't collide on the same key."""
if not idempotency_key or not scope:
return None
return f"{scope}:{idempotency_key}"
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
@api.doc(
@@ -70,8 +34,9 @@ class AgentWebhook(Resource):
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
with db_readonly() as conn:
agent = AgentsRepository(conn).get_any(agent_id, user)
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": user}
)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
@@ -79,11 +44,10 @@ class AgentWebhook(Resource):
webhook_token = agent.get("incoming_webhook_token")
if not webhook_token:
webhook_token = secrets.token_urlsafe(32)
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), user,
{"incoming_webhook_token": webhook_token},
)
agents_collection.update_one(
{"_id": ObjectId(agent_id), "user": user},
{"$set": {"incoming_webhook_token": webhook_token}},
)
base_url = settings.API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
except Exception as err:
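Review note: both sides of this hunk reuse a stored `incoming_webhook_token` when one exists and mint a new one otherwise, then build the listener URL from `settings.API_URL`. A minimal sketch of that step in isolation. `webhook_url` is a hypothetical helper, and persisting a newly minted token is left to the caller, as in the diff.

```python
import secrets
from typing import Optional


def webhook_url(api_url: str, existing_token: Optional[str]) -> tuple[str, str]:
    """Return (token, full_url), minting a token only when none is stored."""
    token = existing_token or secrets.token_urlsafe(32)
    base_url = api_url.rstrip("/")  # avoid a double slash in the final URL
    return token, f"{base_url}/api/webhooks/agents/{token}"
```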
@@ -103,7 +67,7 @@ class AgentWebhook(Resource):
class AgentWebhookListener(Resource):
method_decorators = [require_agent]
def _enqueue_webhook_task(self, agent_id_str, payload, source_method, agent=None):
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
if not payload:
current_app.logger.warning(
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
@@ -112,94 +76,26 @@ class AgentWebhookListener(Resource):
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
)
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
# Resolve to PG UUID first so dedup writes don't crash on legacy ids.
agent_uuid = None
if agent is not None:
candidate = str(agent.get("id") or "")
if looks_like_uuid(candidate):
agent_uuid = candidate
if idempotency_key and agent_uuid is None:
current_app.logger.warning(
"Skipping webhook idempotency dedup: agent %s has non-UUID id",
agent_id_str,
)
idempotency_key = None
# Agent-scoped (webhooks have no user_id).
scoped_key = _scoped_idempotency_key(idempotency_key, agent_uuid)
# Claim before enqueue; the loser returns the winner's task_id.
predetermined_task_id = None
if scoped_key:
predetermined_task_id = str(uuid.uuid4())
with db_session() as conn:
claimed = IdempotencyRepository(conn).record_webhook(
key=scoped_key,
agent_id=agent_uuid,
task_id=predetermined_task_id,
response_json={
"success": True, "task_id": predetermined_task_id,
},
)
if claimed is None:
with db_readonly() as conn:
cached = IdempotencyRepository(conn).get_webhook(scoped_key)
if cached is not None:
return make_response(jsonify(cached["response_json"]), 200)
return make_response(
jsonify({"success": True, "task_id": "deduplicated"}), 200
)
try:
apply_kwargs = dict(
kwargs={
"agent_id": agent_id_str,
"payload": payload,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
},
task = process_agent_webhook.delay(
agent_id=agent_id_str,
payload=payload,
)
if predetermined_task_id is not None:
apply_kwargs["task_id"] = predetermined_task_id
task = process_agent_webhook.apply_async(**apply_kwargs)
current_app.logger.info(
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
)
response_payload = {"success": True, "task_id": task.id}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
except Exception as err:
current_app.logger.error(
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
exc_info=True,
)
if scoped_key:
# Roll back the claim so a retry can succeed.
try:
with db_session() as conn:
conn.execute(
sql_text(
"DELETE FROM webhook_dedup "
"WHERE idempotency_key = :k"
),
{"k": scoped_key},
)
except Exception:
current_app.logger.exception(
"Failed to release webhook_dedup claim for key=%s",
scoped_key,
)
return make_response(
jsonify({"success": False, "message": "Error processing webhook"}), 500
)
@api.doc(
description=(
"Webhook listener for agent events (POST). Expects JSON payload, which "
"is used to trigger processing. Honors an optional ``Idempotency-Key`` "
"header: a repeat request with the same key within 24h returns the "
"original cached response and does not re-enqueue the task."
),
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
)
def post(self, webhook_token, agent, agent_id_str):
payload = request.get_json()
@@ -213,20 +109,11 @@ class AgentWebhookListener(Resource):
),
400,
)
return self._enqueue_webhook_task(
agent_id_str, payload, source_method="POST", agent=agent,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
@api.doc(
description=(
"Webhook listener for agent events (GET). Uses URL query parameters as "
"payload to trigger processing. Honors an optional ``Idempotency-Key`` "
"header: a repeat request with the same key within 24h returns the "
"original cached response and does not re-enqueue the task."
),
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
)
def get(self, webhook_token, agent, agent_id_str):
payload = request.args.to_dict(flat=True)
return self._enqueue_webhook_task(
agent_id_str, payload, source_method="GET", agent=agent,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
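Review note: the Postgres-side listener in this file claims an agent-scoped idempotency key before enqueuing, replays the stored response to duplicates, and releases the claim if enqueuing fails. A minimal sketch of that flow with an in-memory dict standing in for the `webhook_dedup` table. `InMemoryDedup` and `handle_webhook` are hypothetical, the 24h expiry mentioned in the endpoint description is not modeled, and the sketch always pre-generates a task id, whereas the diff does so only when a key is present.

```python
import uuid
from typing import Callable, Optional

IDEMPOTENCY_KEY_MAX_LEN = 256  # same limit as in the diff


def scoped_key(key: Optional[str], scope: Optional[str]) -> Optional[str]:
    """Prefix the key with the agent id so agents cannot collide."""
    if not key or not scope:
        return None
    return f"{scope}:{key}"


class InMemoryDedup:
    """Hypothetical stand-in for the webhook_dedup table."""

    def __init__(self) -> None:
        self._rows: dict[str, dict] = {}

    def claim(self, key: str, response: dict) -> bool:
        """Insert-if-absent; True means this caller won the claim."""
        if key in self._rows:
            return False
        self._rows[key] = response
        return True

    def get(self, key: str) -> Optional[dict]:
        return self._rows.get(key)

    def release(self, key: str) -> None:
        self._rows.pop(key, None)


def handle_webhook(store, agent_id, key, enqueue: Callable[[str], None]):
    """Return (status_code, body) for one webhook delivery."""
    if key and len(key) > IDEMPOTENCY_KEY_MAX_LEN:
        return 400, {"success": False, "message": "Idempotency-Key too long"}
    scoped = scoped_key(key, agent_id)
    task_id = str(uuid.uuid4())
    response = {"success": True, "task_id": task_id}
    if scoped and not store.claim(scoped, response):
        # A retry or the loser of a race: replay the winner's response.
        return 200, store.get(scoped)
    try:
        enqueue(task_id)
    except Exception:
        if scoped:
            store.release(scoped)  # let a later retry claim the key again
        return 500, {"success": False, "message": "Error processing webhook"}
    return 200, response
```

Claiming before enqueuing is what lets two concurrent deliveries with the same key resolve to a single task. In a real table the claim would need to be an atomic insert-if-absent, which the dict only approximates.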


@@ -2,84 +2,26 @@
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.api.user.base import (
agents_collection,
conversations_collection,
generate_date_range,
generate_hourly_range,
generate_minute_range,
token_usage_collection,
user_logs_collection,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly
analytics_ns = Namespace(
"analytics", description="Analytics and reporting operations", path="/api"
)
_FILTER_BUCKETS = {
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
}
def _range_for_filter(filter_option: str):
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
Returns ``None`` on invalid filter.
"""
if filter_option not in _FILTER_BUCKETS:
return None
end_date = datetime.datetime.now(datetime.timezone.utc)
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
else:
days = {
"last_7_days": 6,
"last_15_days": 14,
"last_30_days": 29,
}[filter_option]
start_date = end_date - datetime.timedelta(days=days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
return start_date, end_date, bucket_unit, pg_fmt
def _intervals_for_filter(filter_option, start_date, end_date):
if filter_option == "last_hour":
return generate_minute_range(start_date, end_date)
if filter_option == "last_24_hour":
return generate_hourly_range(start_date, end_date)
return generate_date_range(start_date, end_date)
def _resolve_api_key(conn, api_key_id, user_id):
"""Look up the ``agents.key`` value for a given agent id.
Scoped by ``user_id`` so an authenticated caller can't probe another
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
"""
if not api_key_id:
return None
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
return (agent or {}).get("key") if agent else None
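Review note: the helpers above compute a reporting window per filter, and the handlers then zero-fill every bucket before overlaying the counted rows. A minimal sketch of the day-bucket case using only the standard library. `day_window` and `zero_filled_days` are hypothetical names, `now` is passed in explicitly so the behaviour is reproducible, and the bucket generation is an assumption because `generate_date_range` is not shown in this diff. The `last_hour` and `last_24_hour` filters are left out.

```python
import datetime

# Day offsets as used in the diff: N days back plus today gives N + 1 buckets.
DAYS_BACK = {"last_7_days": 6, "last_15_days": 14, "last_30_days": 29}


def day_window(filter_option: str, now: datetime.datetime):
    """Return (start, end) covering whole days, or None for unknown filters."""
    if filter_option not in DAYS_BACK:
        return None
    start = now - datetime.timedelta(days=DAYS_BACK[filter_option])
    start = start.replace(hour=0, minute=0, second=0, microsecond=0)
    end = now.replace(hour=23, minute=59, second=59, microsecond=999999)
    return start, end


def zero_filled_days(start, end, counts):
    """Build {YYYY-MM-DD: count} with 0 for days that have no rows."""
    buckets = {}
    day = start.date()
    while day <= end.date():
        buckets[day.strftime("%Y-%m-%d")] = 0
        day += datetime.timedelta(days=1)
    for bucket, count in counts.items():
        buckets[bucket] = int(count)
    return buckets
```

Zero-filling first keeps the response shape stable for charts, since buckets with no messages still appear with a count of 0.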
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
@@ -90,7 +32,13 @@ class GetMessageAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@@ -102,54 +50,88 @@ class GetMessageAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Count messages per bucket, filtered by the conversation's
# owner (user_id) and optionally the agent api_key. The
# ``user_id`` filter is always applied post-cutover to
# prevent cross-tenant leakage on admin dashboards.
clauses = [
"c.user_id = :user_id",
"m.timestamp >= :start",
"m.timestamp <= :end",
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
"COUNT(*) AS count "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
rows = conn.execute(_sql_text(sql), params).fetchall()
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
try:
match_stage = {
"$match": {
"user": user,
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{
"$match": {
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
{
"$group": {
"_id": {
"$dateToString": {
"format": group_format,
"date": "$queries.timestamp",
}
},
"count": {"$sum": 1},
}
},
{"$sort": {"_id": 1}},
]
message_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_messages = {interval: 0 for interval in intervals}
for row in rows:
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
for entry in message_data:
daily_messages[entry["_id"]] = entry["count"]
except Exception as err:
current_app.logger.error(
f"Error getting message analytics: {err}", exc_info=True
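Note on the hunk above: both the SQL variant and the aggregation variant end the same way. Every expected bucket label is seeded with zero, and the counts returned by the query are then written over those zeros. A minimal sketch of that step, assuming hourly labels in the `%Y-%m-%d %H:00` shape used for `last_24_hour`; `hourly_buckets` and `fill_buckets` are hypothetical names for illustration, not the repository's `generate_hourly_range` or `_intervals_for_filter`:

```python
import datetime


def hourly_buckets(start, end):
    """Yield one 'YYYY-MM-DD HH:00' label per hour from start through end."""
    # Truncate to the top of the hour so labels line up with the grouped rows.
    current = start.replace(minute=0, second=0, microsecond=0)
    while current <= end:
        yield current.strftime("%Y-%m-%d %H:00")
        current += datetime.timedelta(hours=1)


def fill_buckets(intervals, rows):
    """Seed every bucket with zero, then overlay the counts the query returned."""
    counts = {label: 0 for label in intervals}
    for label, count in rows:
        counts[label] = int(count)
    return counts


start = datetime.datetime(2026, 4, 14, 10, 30, tzinfo=datetime.timezone.utc)
end = datetime.datetime(2026, 4, 14, 13, 5, tzinfo=datetime.timezone.utc)
# Hypothetical query output: only hours with activity come back from the database.
rows = [("2026-04-14 11:00", 4), ("2026-04-14 13:00", 2)]
result = fill_buckets(hourly_buckets(start, end), rows)
```

The overlay only works if the database formats its bucket label exactly like the Python side does. A mismatch does not raise; it adds a stray key next to the zero-filled one.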
@@ -170,7 +152,13 @@ class GetTokenAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@@ -182,36 +170,123 @@ class GetTokenAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, bucket_unit, _pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# ``bucketed_totals`` applies user_id / api_key filters
# directly — no need to reshape a Mongo pipeline.
rows = TokenUsageRepository(conn).bucketed_totals(
bucket_unit=bucket_unit,
user_id=user,
api_key=api_key,
timestamp_gte=start_date,
timestamp_lt=end_date,
)
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in rows:
daily_token_usage[entry["bucket"]] = int(
entry["prompt_tokens"] + entry["generated_tokens"]
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
group_stage = {
"$group": {
"_id": {
"minute": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
group_stage = {
"$group": {
"_id": {
"hour": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
group_stage = {
"$group": {
"_id": {
"day": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
try:
match_stage = {
"$match": {
"user_id": user,
"timestamp": {"$gte": start_date, "$lte": end_date},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
token_usage_data = token_usage_collection.aggregate(
[
match_stage,
group_stage,
{"$sort": {"_id": 1}},
]
)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in token_usage_data:
if filter_option == "last_hour":
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
elif filter_option == "last_24_hour":
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
else:
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
except Exception as err:
current_app.logger.error(
f"Error getting token analytics: {err}", exc_info=True
@@ -232,7 +307,13 @@ class GetFeedbackAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@@ -244,64 +325,128 @@ class GetFeedbackAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Feedback lives inside the ``conversation_messages.feedback``
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
# There is no scalar ``feedback_timestamp`` column — extract
# the timestamp from the JSONB and cast it to timestamptz for
# the range filter + bucket grouping.
clauses = [
"c.user_id = :user_id",
"m.feedback IS NOT NULL",
"(m.feedback->>'timestamp')::timestamptz >= :start",
"(m.feedback->>'timestamp')::timestamptz <= :end",
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char("
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
") AS bucket, "
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
rows = conn.execute(_sql_text(sql), params).fetchall()
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
try:
match_stage = {
"$match": {
"queries.feedback_timestamp": {
"$gte": start_date,
"$lte": end_date,
},
"queries.feedback": {"$exists": True},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{"$match": {"queries.feedback": {"$exists": True}}},
{
"$group": {
"_id": {"time": date_field, "feedback": "$queries.feedback"},
"count": {"$sum": 1},
}
},
{
"$group": {
"_id": "$_id.time",
"positive": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "LIKE"]},
"$count",
0,
]
}
},
"negative": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "DISLIKE"]},
"$count",
0,
]
}
},
}
},
{"$sort": {"_id": 1}},
]
feedback_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_feedback = {
interval: {"positive": 0, "negative": 0} for interval in intervals
}
for row in rows:
bucket = row._mapping["bucket"]
daily_feedback[bucket] = {
"positive": int(row._mapping["positive"] or 0),
"negative": int(row._mapping["negative"] or 0),
for entry in feedback_data:
daily_feedback[entry["_id"]] = {
"positive": entry["positive"],
"negative": entry["negative"],
}
except Exception as err:
current_app.logger.error(
@@ -339,89 +484,47 @@ class GetUserLogs(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
data = request.get_json()
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
page_size = int(data.get("page_size", 10))
skip = (page - 1) * page_size
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
logs_repo = UserLogsRepository(conn)
if api_key:
# ``find_by_api_key`` filters on ``data->>'api_key'``
# — the PG shape of the legacy top-level ``api_key``
# filter. Paginate client-side using offset/limit.
all_rows = logs_repo.find_by_api_key(api_key)
offset = (page - 1) * page_size
window = all_rows[offset: offset + page_size + 1]
items = window
else:
items, has_more_flag = logs_repo.list_paginated(
user_id=user,
page=page,
page_size=page_size,
)
# list_paginated already trims to page_size and
# returns has_more separately.
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more_flag,
}
),
200,
)
has_more = len(items) > page_size
items = items[:page_size]
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
except Exception as err:
current_app.logger.error(
f"Error getting user logs: {err}", exc_info=True
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
query = {"user": user}
if api_key:
query = {"api_key": api_key}
items_cursor = (
user_logs_collection.find(query)
.sort("timestamp", -1)
.skip(skip)
.limit(page_size + 1)
)
items = list(items_cursor)
results = [
{
"id": str(item.get("_id")),
"action": item.get("action"),
"level": item.get("level"),
"user": item.get("user"),
"question": item.get("question"),
"sources": item.get("sources"),
"retriever_params": item.get("retriever_params"),
"timestamp": item.get("timestamp"),
}
for item in items[:page_size]
]
has_more = len(items) > page_size
return make_response(
jsonify(
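Note on the `GetUserLogs` hunk above: both variants read `page_size + 1` rows so that `has_more` can be derived without a separate count query. A minimal sketch of that lookahead on a plain list; `paginate` is a hypothetical helper for illustration and does not exist in the repository:

```python
def paginate(items, page, page_size):
    """Return (page_items, has_more) using a one-row lookahead."""
    skip = (page - 1) * page_size
    # Fetch one row more than requested; its presence signals another page.
    window = items[skip : skip + page_size + 1]
    has_more = len(window) > page_size
    return window[:page_size], has_more


logs = [f"log-{i}" for i in range(25)]
first_page, first_has_more = paginate(logs, page=1, page_size=10)
last_page, last_has_more = paginate(logs, page=3, page_size=10)
```

The extra row has to be trimmed before the response is built, otherwise a page would carry `page_size + 1` entries.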


@@ -4,16 +4,13 @@ import os
import tempfile
from pathlib import Path
import uuid
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.cache import get_redis_instance
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
@@ -51,13 +48,14 @@ def _resolve_authenticated_user():
return safe_filename(decoded_token.get("sub"))
if api_key:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
from application.api.user.base import agents_collection
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user_id"))
return safe_filename(agent.get("user"))
return None
@@ -159,7 +157,7 @@ class StoreAttachment(Resource):
for idx, file in enumerate(files):
try:
attachment_id = uuid.uuid4()
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
@@ -214,10 +212,6 @@ class StoreAttachment(Resource):
{
"success": True,
"task_id": tasks[0]["task_id"],
# Surface the attachment_id so the frontend
# can correlate ``attachment.*`` SSE events
# to this row and skip the polling fallback.
"attachment_id": tasks[0]["attachment_id"],
"message": "File uploaded successfully. Processing started.",
}
),


@@ -8,15 +8,15 @@ import uuid
from functools import wraps
from typing import Optional, Tuple
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, Response
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from sqlalchemy import text as _sql_text
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
@@ -24,6 +24,56 @@ from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
agents_collection.create_index(
[("shared", 1)],
name="shared_index",
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
@@ -55,115 +105,69 @@ def generate_date_range(start_date, end_date):
def ensure_user_doc(user_id):
"""
Ensure a Postgres ``users`` row exists for ``user_id``.
Returns the row as a dict with the shape legacy callers expect — in
particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
``shared_with_me`` list keys always present).
Ensure user document exists with proper agent preferences structure.
Args:
user_id: The user ID to ensure
Returns:
The user document as a dict.
The user document
"""
with db_session() as conn:
user_doc = UsersRepository(conn).upsert(user_id)
default_prefs = {
"pinned": [],
"shared_with_me": [],
}
user_doc = users_collection.find_one_and_update(
{"user_id": user_id},
{"$setOnInsert": {"agent_preferences": default_prefs}},
upsert=True,
return_document=ReturnDocument.AFTER,
)
prefs = user_doc.get("agent_preferences", {})
updates = {}
if "pinned" not in prefs:
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
prefs = user_doc.get("agent_preferences") or {}
if not isinstance(prefs, dict):
prefs = {}
prefs.setdefault("pinned", [])
prefs.setdefault("shared_with_me", [])
user_doc["agent_preferences"] = prefs
return user_doc
def resolve_tool_details(tool_ids):
"""
Resolve tool IDs to their display details.
Accepts Postgres UUIDs, legacy Mongo ObjectId strings, or the
synthetic ids of default chat tools / agent-selectable builtins
(mixed lists are supported). Synthetic ids are resolved in memory;
real ids are looked up via ``get_any``. Unknown ids are silently
skipped.
Resolve tool IDs to their details.
Args:
tool_ids: List of tool IDs (UUIDs, legacy ObjectId strings, or
synthetic default-tool / builtin ids).
tool_ids: List of tool IDs
Returns:
List of tool details with ``id``, ``name``, and ``display_name``.
List of tool details with id, name, and display_name
"""
if not tool_ids:
return []
from application.agents.default_tools import (
is_synthesized_tool_id,
synthesize_tool_by_name,
synthesized_tool_name_for_id,
)
uuid_ids: list[str] = []
legacy_ids: list[str] = []
default_details: list[dict] = []
valid_ids = []
for tid in tool_ids:
if not tid:
try:
valid_ids.append(ObjectId(tid))
except Exception:
continue
tid_str = str(tid)
if is_synthesized_tool_id(tid_str):
synth = synthesize_tool_by_name(synthesized_tool_name_for_id(tid_str))
if synth is not None:
default_details.append(
{
"id": tid_str,
"name": synth.get("name", ""),
"display_name": synth.get("display_name", ""),
}
)
continue
if looks_like_uuid(tid_str):
uuid_ids.append(tid_str)
else:
legacy_ids.append(tid_str)
if not uuid_ids and not legacy_ids:
return default_details
rows: list[dict] = []
with db_readonly() as conn:
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE id = ANY(CAST(:ids AS uuid[]))"
),
{"ids": uuid_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
if legacy_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE legacy_mongo_id = ANY(:ids)"
),
{"ids": legacy_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
return default_details + [
tools = user_tools_collection.find(
{"_id": {"$in": valid_ids}}
) if valid_ids else []
return [
{
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
"name": tool.get("name", "") or "",
"display_name": (
tool.get("custom_name")
or tool.get("display_name")
or tool.get("name", "")
or ""
),
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("customName")
or tool.get("displayName")
or tool.get("name", ""),
}
for tool in rows
for tool in tools
]
@@ -233,15 +237,14 @@ def require_agent(func):
@wraps(func)
def wrapper(*args, **kwargs):
from application.storage.db.repositories.agents import AgentsRepository
webhook_token = kwargs.get("webhook_token")
if not webhook_token:
return make_response(
jsonify({"success": False, "message": "Webhook token missing"}), 400
)
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
agent = agents_collection.find_one(
{"incoming_webhook_token": webhook_token}, {"_id": 1}
)
if not agent:
current_app.logger.warning(
f"Webhook attempt with invalid token: {webhook_token}"
@@ -250,7 +253,7 @@ def require_agent(func):
jsonify({"success": False, "message": "Agent not found"}), 404
)
kwargs["agent"] = agent
kwargs["agent_id_str"] = str(agent["id"])
kwargs["agent_id_str"] = str(agent["_id"])
return func(*args, **kwargs)
return wrapper
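Note on `resolve_tool_details` in the hunk above: the Postgres variant splits incoming ids into UUIDs and legacy ObjectId strings before querying, so that each group can be matched against its own column. A minimal sketch of that split, assuming a UUID check based on the standard `uuid` module. `is_uuid` is a hypothetical stand-in for the repository's `looks_like_uuid`, whose implementation is not shown in this diff, and the synthetic default-tool ids are left out:

```python
import uuid


def is_uuid(value):
    """Hypothetical stand-in for looks_like_uuid: True if value parses as a UUID."""
    try:
        uuid.UUID(str(value))
    except ValueError:
        return False
    return True


def partition_tool_ids(tool_ids):
    """Split ids into (uuid_ids, legacy_ids), skipping empty entries."""
    uuid_ids, legacy_ids = [], []
    for tid in tool_ids:
        if not tid:
            continue
        tid_str = str(tid)
        if is_uuid(tid_str):
            uuid_ids.append(tid_str)
        else:
            legacy_ids.append(tid_str)
    return uuid_ids, legacy_ids


uuid_ids, legacy_ids = partition_tool_ids(
    ["3f2b8c1e-9a4d-4f6b-8c2e-1d5a7b9c0e3f", "507f1f77bcf86cd799439011", None, ""]
)
```

A 24-character ObjectId string fails UUID parsing, so it lands in the legacy group and can be looked up by its legacy id instead.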


@@ -2,19 +2,14 @@
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.answer.services.conversation_service import (
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.api.user.base import attachments_collection, conversations_collection
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.message_events import MessageEventsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
conversations_ns = Namespace(
@@ -39,16 +34,21 @@ class DeleteConversation(Resource):
)
user_id = decoded_token["sub"]
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is not None:
repo.delete(str(conv["id"]), user_id)
conversations_collection.delete_one(
{"_id": ObjectId(conversation_id), "user": user_id}
)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
def _pg_delete(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is not None:
repo.delete(conv["id"], user_id)
dual_write(ConversationsRepository, _pg_delete)
return make_response(jsonify({"success": True}), 200)
@@ -63,13 +63,17 @@ class DeleteAllConversations(Resource):
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
with db_session() as conn:
ConversationsRepository(conn).delete_all_for_user(user_id)
conversations_collection.delete_many({"user": user_id})
except Exception as err:
current_app.logger.error(
f"Error deleting all conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
dual_write(
ConversationsRepository,
lambda r, uid=user_id: r.delete_all_for_user(uid),
)
return make_response(jsonify({"success": True}), 200)
@@ -82,21 +86,26 @@ class GetConversations(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
conversations = ConversationsRepository(conn).list_for_user(
user_id, limit=30
conversations = (
conversations_collection.find(
{
"$or": [
{"api_key": {"$exists": False}},
{"agent_id": {"$exists": True}},
],
"user": decoded_token.get("sub"),
}
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{
"id": str(conversation["id"]),
"id": str(conversation["_id"]),
"name": conversation["name"],
"agent_id": (
str(conversation["agent_id"])
if conversation.get("agent_id")
else None
),
"agent_id": conversation.get("agent_id", None),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
@@ -125,74 +134,38 @@ class GetSingleConversation(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conversation = repo.get_any(conversation_id, user_id)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
conv_pg_id = str(conversation["id"])
messages = repo.get_messages(conv_pg_id)
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
# Resolve attachment details (id, fileName) for each message.
attachments_repo = AttachmentsRepository(conn)
queries = []
for msg in messages:
metadata = msg.get("metadata") or {}
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"timestamp": msg.get("timestamp"),
"model_id": msg.get("model_id"),
# Lets the client distinguish placeholder rows from
# finalised answers and tail-poll in-flight ones.
"message_id": str(msg["id"]) if msg.get("id") else None,
"status": msg.get("status"),
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
}
if metadata:
query["metadata"] = metadata
# Feedback on conversation_messages is a JSONB blob with
# shape {"text": <str>, "timestamp": <iso>}. The legacy
# frontend consumed a flat scalar feedback string, so
# unwrap the ``text`` field for compat.
feedback = msg.get("feedback")
if feedback is not None:
if isinstance(feedback, dict):
query["feedback"] = feedback.get("text")
if feedback.get("timestamp"):
query["feedback_timestamp"] = feedback["timestamp"]
else:
query["feedback"] = feedback
attachments = msg.get("attachments") or []
if attachments:
attachment_details = []
for attachment_id in attachments:
try:
att = attachments_repo.get_any(
str(attachment_id), user_id
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
if att:
attachment_details.append(
{
"id": str(att["id"]),
"fileName": att.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
queries.append(query)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
except Exception as err:
current_app.logger.error(
f"Error retrieving conversation: {err}", exc_info=True
@@ -200,9 +173,7 @@ class GetSingleConversation(Resource):
return make_response(jsonify({"success": False}), 400)
data = {
"queries": queries,
"agent_id": (
str(conversation["agent_id"]) if conversation.get("agent_id") else None
),
"agent_id": conversation.get("agent_id"),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
@@ -236,16 +207,22 @@ class UpdateConversationName(Resource):
return missing_fields
user_id = decoded_token.get("sub")
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["id"], user_id)
if conv is not None:
repo.rename(str(conv["id"]), user_id, data["name"])
conversations_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user_id},
{"$set": {"name": data["name"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating conversation name: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
def _pg_rename(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(data["id"])
if conv is not None:
repo.rename(conv["id"], user_id, data["name"])
dual_write(ConversationsRepository, _pg_rename)
return make_response(jsonify({"success": True}), 200)
@@ -283,111 +260,61 @@ class SubmitFeedback(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
feedback_value = data["feedback"]
question_index = int(data["question_index"])
# Normalize string feedback to lowercase so analytics queries
# (which match 'like'/'dislike') count rows correctly. Tolerate
# legacy uppercase clients on ingest. Non-string values pass through.
if isinstance(feedback_value, str):
feedback_value = feedback_value.lower()
feedback_payload = (
None
if feedback_value is None
else {
"text": feedback_value,
"timestamp": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
)
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["conversation_id"], user_id)
if conv is None:
return make_response(
jsonify({"success": False, "message": "Not found"}), 404
)
repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
if data["feedback"] is None:
# Remove feedback and feedback_timestamp if feedback is null
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
else:
# Set feedback and feedback_timestamp if feedback has a value
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$set": {
f"queries.{data['question_index']}.feedback": data[
"feedback"
],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
datetime.timezone.utc
),
}
},
)
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/messages/<string:message_id>/tail")
class GetMessageTail(Resource):
@api.doc(
description=(
"Current state of one conversation_messages row, scoped to the "
"authenticated user. Used to reconnect to an in-flight stream "
"after a refresh."
),
params={"message_id": "Message UUID"},
)
def get(self, message_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if not looks_like_uuid(message_id):
return make_response(
jsonify({"success": False, "message": "Invalid message id"}), 400
)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
# Owner-or-shared, matching ``ConversationsRepository.get``.
row = conn.execute(
sql_text(
"SELECT m.* FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
"WHERE m.id = CAST(:mid AS uuid) "
"AND (c.user_id = :uid OR :uid = ANY(c.shared_with))"
),
{"mid": message_id, "uid": user_id},
).fetchone()
if row is None:
return make_response(jsonify({"status": "not found"}), 404)
msg = row_to_dict(row)
# Mid-stream the row's response is the placeholder; rebuild
# the live partial from the journal so /tail mirrors SSE.
status = msg.get("status")
response = msg.get("response")
thought = msg.get("thought")
sources = msg.get("sources") or []
tool_calls = msg.get("tool_calls") or []
if status in ("pending", "streaming") and (
response == TERMINATED_RESPONSE_PLACEHOLDER
):
partial = MessageEventsRepository(conn).reconstruct_partial(
message_id
)
response = partial["response"]
thought = partial["thought"] or thought
if partial["sources"]:
sources = partial["sources"]
if partial["tool_calls"]:
tool_calls = partial["tool_calls"]
except Exception as err:
current_app.logger.error(
f"Error tailing message {message_id}: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
metadata = msg.get("message_metadata") or {}
return make_response(
jsonify(
{
"message_id": str(msg["id"]),
"status": status,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
"error": metadata.get("error"),
}
),
200,
# Dual-write to Postgres: mirror the per-message feedback set/unset.
feedback_value = data["feedback"]
question_index = int(data["question_index"])
feedback_payload = (
None if feedback_value is None
else {"text": feedback_value, "timestamp": datetime.datetime.now(
datetime.timezone.utc
).isoformat()}
)
def _pg_feedback(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(data["conversation_id"])
if conv is not None:
repo.set_feedback(conv["id"], question_index, feedback_payload)
dual_write(ConversationsRepository, _pg_feedback)
return make_response(jsonify({"success": True}), 200)
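
Review note: the feedback handling in this hunk follows three rules: string values are lowercased, `None` clears the stored feedback, and everything else is stored with a UTC timestamp. A minimal sketch of those rules as one pure function follows; the name `build_feedback_payload` and the injectable `now` parameter are assumptions made for illustration and do not appear in the diff.

```python
import datetime
from typing import Any, Optional


def build_feedback_payload(
    feedback_value: Any,
    now: Optional[datetime.datetime] = None,
) -> Optional[dict]:
    """Return the payload to store, or None to clear existing feedback.

    Hypothetical helper: it mirrors the inline logic in the hunk above.
    """
    if feedback_value is None:
        return None
    # Lowercase strings so 'LIKE' and 'like' are counted together;
    # non-string values pass through unchanged.
    if isinstance(feedback_value, str):
        feedback_value = feedback_value.lower()
    # ``now`` is injectable only so the function is easy to test.
    stamp = now or datetime.datetime.now(datetime.timezone.utc)
    return {"text": feedback_value, "timestamp": stamp.isoformat()}
```

Keeping the normalization in one place would let the Postgres path and the dual-write mirror share it, so both stores see the same lowercase value.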

View File

@@ -1,294 +0,0 @@
"""Per-Celery-task idempotency wrapper backed by ``task_dedup``."""
from __future__ import annotations
import functools
import inspect
import logging
import threading
import uuid
from typing import Any, Callable, Optional
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# Poison-loop cap; transient-failure headroom without infinite retry.
MAX_TASK_ATTEMPTS = 5
# 30s heartbeat / 60s TTL → ~2 missed ticks of slack before reclaim.
LEASE_TTL_SECONDS = 60
LEASE_HEARTBEAT_INTERVAL = 30
# 10 × 60s = 600s ≈ 10 min of deferral before giving up on a held lease.
LEASE_RETRY_MAX = 10
def with_idempotency(
task_name: str,
*,
on_poison: Optional[Callable[[str, dict], None]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Short-circuit on completed key; gate concurrent runs via a lease.
The guard key is the caller's ``idempotency_key``, or one synthesized
from ``source_id`` so a keyless dispatch is still poison-guarded.
Entry short-circuits:
- completed row → return cached result
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
Success writes ``completed``; exceptions leave ``pending`` for
autoretry until the poison-loop guard trips.
"""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
explicit_key = (
idempotency_key
if isinstance(idempotency_key, str) and idempotency_key
else None
)
# A keyless dispatch still gets the guard via a synthesized key;
# None means no anchor exists — run unguarded, as before.
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
if key is None:
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
cached = _lookup_completed(key)
if cached is not None:
logger.info(
"idempotency hit for task=%s key=%s — returning cached result",
task_name, key,
)
return cached
owner_id = str(uuid.uuid4())
attempt = _try_claim_lease(
key, task_name, _safe_task_id(self), owner_id,
)
if attempt is None:
# Live lease held by another worker. Re-queue and bail
# quickly: by the time the retry fires (after LEASE_TTL_SECONDS),
# the lease holder has either finalised (we'll hit
# ``_lookup_completed`` and return cached) or its lease
# has expired and we can claim.
logger.info(
"idempotency: live lease held; deferring task=%s key=%s",
task_name, key,
)
raise self.retry(
countdown=LEASE_TTL_SECONDS,
max_retries=LEASE_RETRY_MAX,
)
if attempt > MAX_TASK_ATTEMPTS:
logger.error(
"idempotency poison-loop guard: task=%s key=%s attempts=%s",
task_name, key, attempt,
extra={
"alert": "idempotency_poison_loop",
"task_name": task_name,
"idempotency_key": key,
"attempts": attempt,
},
)
poisoned = {
"success": False,
"error": "idempotency poison-loop guard tripped",
"attempts": attempt,
}
_finalize(key, poisoned, status="failed")
_run_poison_hook(
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
)
return poisoned
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
key, owner_id,
)
try:
result = fn(self, *args, idempotency_key=idempotency_key, **kwargs)
_finalize(key, result, status="completed")
return result
except Exception:
# Drop the lease so the next retry doesn't wait LEASE_TTL.
_release_lease(key, owner_id)
raise
finally:
_stop_lease_heartbeat(heartbeat_thread, heartbeat_stop)
return wrapper
return decorator
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
``source_id`` is stable across broker redeliveries and unique per
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
``None`` when absent — the dispatch then runs unguarded as before.
"""
source_id = kwargs.get("source_id")
if source_id:
return f"auto:{task_name}:{source_id}"
return None
def _run_poison_hook(
on_poison: Optional[Callable[[str, dict], None]],
task_name: str,
fn: Callable[..., Any],
task_self: Any,
args: tuple,
kwargs: dict,
idempotency_key: Any,
) -> None:
"""Invoke a task's poison-path hook with named call args; swallow failures.
A hook failure must never change the poison-guard outcome.
"""
if on_poison is None:
return
try:
bound = inspect.signature(fn).bind_partial(
task_self, *args, idempotency_key=idempotency_key, **kwargs,
)
on_poison(task_name, dict(bound.arguments))
except Exception:
logger.exception(
"idempotency: poison hook failed for task=%s", task_name,
)
def _lookup_completed(key: str) -> Any:
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
with db_readonly() as conn:
row = IdempotencyRepository(conn).get_task(key)
if row is None:
return None
if row.get("status") != "completed":
return None
return row.get("result_json")
def _try_claim_lease(
key: str, task_name: str, task_id: str, owner_id: str,
) -> Optional[int]:
"""Atomic CAS; returns ``attempt_count`` or ``None`` when held.
DB outage → treated as ``attempt=1`` so transient failures don't
block all task execution; reconciler repairs the lease columns.
"""
try:
with db_session() as conn:
return IdempotencyRepository(conn).try_claim_lease(
key=key,
task_name=task_name,
task_id=task_id,
owner_id=owner_id,
ttl_seconds=LEASE_TTL_SECONDS,
)
except Exception:
logger.exception(
"idempotency lease-claim failed for key=%s task=%s", key, task_name,
)
return 1
def _finalize(key: str, result_json: Any, *, status: str) -> None:
"""Best-effort terminal write. Never let DB outage fail the task."""
try:
with db_session() as conn:
IdempotencyRepository(conn).finalize_task(
key=key, result_json=result_json, status=status,
)
except Exception:
logger.exception(
"idempotency finalize failed for key=%s status=%s", key, status,
)
def _release_lease(key: str, owner_id: str) -> None:
"""Best-effort lease release on the wrapper's exception path."""
try:
with db_session() as conn:
IdempotencyRepository(conn).release_lease(key, owner_id)
except Exception:
logger.exception("idempotency release-lease failed for key=%s", key)
def _start_lease_heartbeat(
key: str, owner_id: str,
) -> tuple[threading.Thread, threading.Event]:
"""Spawn a daemon thread that bumps ``lease_expires_at`` every
:data:`LEASE_HEARTBEAT_INTERVAL` seconds until ``stop_event`` fires.
Mirrors ``application.worker._start_ingest_heartbeat`` so the two
durability heartbeats share shape and cadence.
"""
stop_event = threading.Event()
thread = threading.Thread(
target=_lease_heartbeat_loop,
args=(key, owner_id, stop_event, LEASE_HEARTBEAT_INTERVAL),
daemon=True,
name=f"idempotency-lease-heartbeat:{key[:32]}",
)
thread.start()
return thread, stop_event
def _stop_lease_heartbeat(
thread: threading.Thread, stop_event: threading.Event,
) -> None:
"""Signal the heartbeat thread to exit and join with a short timeout."""
stop_event.set()
thread.join(timeout=10)
def _lease_heartbeat_loop(
key: str,
owner_id: str,
stop_event: threading.Event,
interval: int,
) -> None:
"""Refresh the lease until ``stop_event`` is set or ownership is lost.
A failed refresh (rowcount 0) means another worker stole the lease
after expiry — at that point the damage is already possible, so we
log and keep ticking. Don't escalate to thread death; the main task
body needs to keep running so its outcome is at least *recorded*.
"""
while not stop_event.wait(interval):
try:
with db_session() as conn:
still_owned = IdempotencyRepository(conn).refresh_lease(
key=key, owner_id=owner_id, ttl_seconds=LEASE_TTL_SECONDS,
)
if not still_owned:
logger.warning(
"idempotency lease lost mid-task for key=%s "
"(another worker may have taken over)",
key,
)
except Exception:
logger.exception(
"idempotency lease-heartbeat tick failed for key=%s", key,
)
def _safe_task_id(task_self: Any) -> str:
"""Best-effort extraction of ``self.request.id`` from a Celery task."""
try:
request = getattr(task_self, "request", None)
task_id: Optional[str] = (
getattr(request, "id", None) if request is not None else None
)
except Exception:
task_id = None
return task_id or "unknown"
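
Review note: the wrapper relies on three lease operations (claim, refresh, release) whose SQL lives in `IdempotencyRepository` and is not shown in this diff. The sketch below models the semantics the docstrings describe, using an in-memory dictionary instead of the `task_dedup` table. `InMemoryLeaseStore`, its method names, and the injectable clock are all assumptions for illustration; the real claim is an atomic database operation, which this single-threaded stand-in does not attempt to reproduce.

```python
import time
from typing import Callable, Dict, Optional

LEASE_TTL_SECONDS = 60
LEASE_RETRY_MAX = 10
# Upper bound on how long a deferred task keeps retrying a held lease.
MAX_DEFERRAL_SECONDS = LEASE_RETRY_MAX * LEASE_TTL_SECONDS


class InMemoryLeaseStore:
    """Hypothetical stand-in for ``task_dedup``; not the real repository."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._rows: Dict[str, dict] = {}

    def try_claim(
        self, key: str, owner_id: str, ttl: int = LEASE_TTL_SECONDS
    ) -> Optional[int]:
        """Return the new attempt count, or None while a live lease is held."""
        now = self._clock()
        row = self._rows.setdefault(
            key, {"owner": None, "expires": 0.0, "attempts": 0}
        )
        if row["owner"] is not None and row["expires"] > now:
            return None
        row.update(
            owner=owner_id, expires=now + ttl, attempts=row["attempts"] + 1
        )
        return row["attempts"]

    def refresh(
        self, key: str, owner_id: str, ttl: int = LEASE_TTL_SECONDS
    ) -> bool:
        """Extend the lease; False means ownership has been lost."""
        row = self._rows.get(key)
        if row is None or row["owner"] != owner_id:
            return False
        row["expires"] = self._clock() + ttl
        return True

    def release(self, key: str, owner_id: str) -> None:
        """Drop the lease early so a retry does not wait for the TTL."""
        row = self._rows.get(key)
        if row is not None and row["owner"] == owner_id:
            row["owner"] = None
```

Because the attempt counter survives a release, a task that keeps failing still climbs toward `MAX_TASK_ATTEMPTS`, which is what lets the poison-loop guard eventually trip.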

View File

@@ -1,135 +1,18 @@
"""Model routes.
- ``GET /api/models`` — list available models for the current user.
Combines the built-in catalog with the user's BYOM records.
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
user's own OpenAI-compatible model registrations (BYOM).
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
endpoint with a tiny request.
Every BYOM endpoint is user-scoped at the repository layer
(every query filters on ``user_id`` from ``request.decoded_token``).
"""
from __future__ import annotations
import logging
import requests
from flask import current_app, jsonify, make_response, request
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.api import api
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_post,
validate_user_base_url,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
_CONTEXT_WINDOW_MIN = 1_000
_CONTEXT_WINDOW_MAX = 10_000_000
def _user_id_or_401():
decoded_token = request.decoded_token
if not decoded_token:
return None, make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
if not user_id:
return None, make_response(jsonify({"success": False}), 401)
return user_id, None
def _normalize_capabilities(raw) -> dict:
"""Coerce + bound the user-supplied capabilities payload."""
raw = raw or {}
out = {}
if "supports_tools" in raw:
out["supports_tools"] = bool(raw["supports_tools"])
if "supports_structured_output" in raw:
out["supports_structured_output"] = bool(raw["supports_structured_output"])
if "supports_streaming" in raw:
out["supports_streaming"] = bool(raw["supports_streaming"])
if "attachments" in raw:
atts = raw["attachments"] or []
if not isinstance(atts, list):
raise ValueError("'capabilities.attachments' must be a list")
coerced = [str(a) for a in atts]
# Reject unknown aliases at the API boundary so bad payloads
# never reach the registry layer (where lenient expansion just
# drops them). Raw MIME types (containing ``/``) pass through
# unchanged for parity with the built-in YAML schema.
from application.core.model_yaml import builtin_attachment_aliases
aliases = builtin_attachment_aliases()
for entry in coerced:
if "/" in entry:
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ValueError(
f"unknown attachment alias '{entry}' in "
f"'capabilities.attachments'. Valid aliases: {valid}, "
f"or use a raw MIME type like 'image/png'."
)
out["attachments"] = coerced
if "context_window" in raw:
try:
cw = int(raw["context_window"])
except (TypeError, ValueError):
raise ValueError("'capabilities.context_window' must be an integer")
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
raise ValueError(
f"'capabilities.context_window' must be between "
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
)
out["context_window"] = cw
return out
def _row_to_response(row: dict) -> dict:
"""Wire-format projection — never includes the API key."""
return {
"id": str(row["id"]),
"upstream_model_id": row["upstream_model_id"],
"display_name": row["display_name"],
"description": row.get("description") or "",
"base_url": row["base_url"],
"capabilities": row.get("capabilities") or {},
"enabled": bool(row.get("enabled", True)),
"source": "user",
}
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities.
When the request is authenticated, the response includes the
user's own BYOM registrations alongside the built-in catalog.
"""
"""Get list of available models with their capabilities."""
try:
user_id = None
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models(user_id=user_id)
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
@@ -140,382 +23,3 @@ class ModelsListResource(Resource):
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
@api.doc(description="List the current user's BYOM custom models")
def get(self):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
rows = UserCustomModelsRepository(conn).list_for_user(user_id)
return make_response(
jsonify({"models": [_row_to_response(r) for r in rows]}), 200
)
except Exception as e:
current_app.logger.error(
f"Error listing user custom models: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
@api.doc(description="Register a new BYOM custom model")
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data,
["upstream_model_id", "display_name", "base_url", "api_key"],
)
if missing:
return missing
# SECURITY: reject blank api_key — would leak instance API key
# to the user-supplied base_url via LLMCreator fallback.
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
"api_key",
):
value = data.get(required_nonblank)
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' must be a non-empty string",
}
),
400,
)
# SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
# as defense in depth against DNS rebinding and pre-guard rows.
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
capabilities = _normalize_capabilities(data.get("capabilities"))
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
with db_session() as conn:
row = UserCustomModelsRepository(conn).create(
user_id=user_id,
upstream_model_id=data["upstream_model_id"],
display_name=data["display_name"],
description=data.get("description") or "",
base_url=data["base_url"],
api_key_plaintext=data["api_key"],
capabilities=capabilities,
enabled=bool(data.get("enabled", True)),
)
except Exception as e:
current_app.logger.error(
f"Error creating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify(_row_to_response(row)), 201)
@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
@api.doc(description="Get one BYOM custom model")
def get(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error fetching user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if row is None:
return make_response(jsonify({"success": False}), 404)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Update a BYOM custom model (partial)")
def patch(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Reject present-but-blank values for fields where blank doesn't
# mean "no change". (The api_key special case — blank means "keep
# existing" — is handled below.)
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
):
if required_nonblank in data:
value = data[required_nonblank]
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' cannot be blank",
}
),
400,
)
if "base_url" in data and data["base_url"]:
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
update_fields: dict = {}
for k in (
"upstream_model_id",
"display_name",
"description",
"base_url",
"enabled",
):
if k in data:
update_fields[k] = data[k]
if "capabilities" in data:
try:
update_fields["capabilities"] = _normalize_capabilities(
data["capabilities"]
)
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
# PATCH semantics: blank/missing api_key → keep the existing
# ciphertext; non-empty api_key → re-encrypt and replace.
if data.get("api_key"):
update_fields["api_key_plaintext"] = data["api_key"]
if not update_fields:
return make_response(
jsonify({"success": False, "error": "no updatable fields"}), 400
)
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).update(
model_id, user_id, update_fields
)
except Exception as e:
current_app.logger.error(
f"Error updating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Delete a BYOM custom model")
def delete(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error deleting user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify({"success": True}), 200)
def _run_connection_test(
base_url: str, api_key: str, upstream_model_id: str
):
"""Send a 1-token chat-completion to verify a BYOM endpoint.
Returns ``(body, http_status)``. Upstream errors return 200 with
``ok=False`` so the UI can render inline errors; only local SSRF
rejection returns 400.
"""
url = base_url.rstrip("/") + "/chat/completions"
payload = {
"model": upstream_model_id,
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
try:
# pinned_post closes the DNS-rebinding window. Redirects off
# because 3xx could bounce to an internal address (the SSRF
# guard only validates the supplied URL).
resp = pinned_post(
url,
json=payload,
headers=headers,
timeout=5,
allow_redirects=False,
)
except UnsafeUserUrlError as e:
return {"ok": False, "error": str(e)}, 400
except requests.RequestException as e:
return {"ok": False, "error": f"connection error: {e}"}, 200
if 300 <= resp.status_code < 400:
return (
{
"ok": False,
"error": (
f"upstream returned HTTP {resp.status_code} "
"redirect; refusing to follow"
),
},
200,
)
if resp.status_code >= 400:
# Cap and only reflect JSON to avoid body-exfil via non-API responses.
content_type = (resp.headers.get("Content-Type") or "").lower()
if "application/json" in content_type:
text = (resp.text or "")[:500]
error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
else:
error_msg = f"upstream returned HTTP {resp.status_code}"
return {"ok": False, "error": error_msg}, 200
return {"ok": True}, 200
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
@api.doc(
description=(
"Test an arbitrary BYOM payload (display_name / model id / "
"base_url / api_key) without saving. Used by the UI's 'Test "
"connection' button so the user can validate before they "
"Save. Same SSRF guard, same 1-token request, same 5s "
"timeout as the by-id variant."
)
)
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data, ["base_url", "api_key", "upstream_model_id"]
)
if missing:
return missing
body, status = _run_connection_test(
data["base_url"], data["api_key"], data["upstream_model_id"]
)
return make_response(jsonify(body), status)
@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
@api.doc(
description=(
"Test a saved BYOM record. Defaults to the stored "
"base_url / upstream_model_id / encrypted api_key, but "
"any of those can be overridden via the request body so "
"the UI can test in-flight edits before saving. Used by "
"the 'Test connection' button in edit mode."
)
)
def post(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Per-field overrides; blank/missing falls back to stored value.
override_base_url = (data.get("base_url") or "").strip() or None
override_upstream_model_id = (
data.get("upstream_model_id") or ""
).strip() or None
override_api_key = (data.get("api_key") or "").strip() or None
try:
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
row = repo.get(model_id, user_id)
if row is None:
return make_response(jsonify({"success": False}), 404)
stored_api_key = (
repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not override_api_key
else None
)
except Exception as e:
current_app.logger.error(
f"Error loading user custom model for test: {e}", exc_info=True
)
return make_response(
jsonify({"ok": False, "error": "internal error loading model"}),
500,
)
api_key = override_api_key or stored_api_key
if not api_key:
return make_response(
jsonify(
{
"ok": False,
"error": (
"Stored API key could not be decrypted. The "
"encryption secret may have rotated. Re-save "
"the model with the API key to recover."
),
}
),
400,
)
base_url = override_base_url or row["base_url"]
upstream_model_id = (
override_upstream_model_id or row["upstream_model_id"]
)
body, status = _run_connection_test(
base_url, api_key, upstream_model_id
)
return make_response(jsonify(body), status)
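
Review note: `_run_connection_test` maps every upstream outcome to HTTP 200 with an `ok` flag, and reflects the upstream body only when it is JSON and only up to 500 characters. The sketch below isolates that mapping as a pure function so it can be checked without a network call. The name `classify_upstream_response` and the `max_body` parameter are assumptions for illustration; the SSRF rejection path (HTTP 400) and the connection-error path are deliberately left out because they depend on `pinned_post`.

```python
from typing import Tuple


def classify_upstream_response(
    status_code: int, content_type: str, text: str, max_body: int = 500
) -> Tuple[dict, int]:
    """Map an upstream response to the (body, http_status) pair for the UI.

    Hypothetical helper mirroring the branches in ``_run_connection_test``.
    """
    if 300 <= status_code < 400:
        # Redirects are never followed: the target was not validated.
        message = (
            f"upstream returned HTTP {status_code} "
            "redirect; refusing to follow"
        )
        return {"ok": False, "error": message}, 200
    if status_code >= 400:
        error = f"upstream returned HTTP {status_code}"
        # Reflect only capped JSON bodies; other content types are dropped.
        if "application/json" in (content_type or "").lower():
            error += f": {(text or '')[:max_body]}"
        return {"ok": False, "error": error}, 200
    return {"ok": True}, 200
```

Returning 200 for upstream failures keeps transport errors (the request to this API failed) distinct from test results (the request succeeded and the endpoint is unhealthy).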

View File

@@ -2,13 +2,14 @@
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir
from application.api.user.base import current_dir, prompts_collection
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
prompts_ns = Namespace(
@@ -41,9 +42,21 @@ class CreatePrompt(Resource):
return missing_fields
user = decoded_token.get("sub")
try:
with db_session() as conn:
prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
new_id = str(prompt["id"])
resp = prompts_collection.insert_one(
{
"name": data["name"],
"content": data["content"],
"user": user,
}
)
new_id = str(resp.inserted_id)
dual_write(
PromptsRepository,
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
u, n, c, legacy_mongo_id=mid,
),
)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -59,17 +72,17 @@ class GetPrompts(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
prompts = PromptsRepository(conn).list_for_user(user)
prompts = prompts_collection.find({"user": user})
list_prompts = [
{"id": "default", "name": "default", "type": "public"},
{"id": "creative", "name": "creative", "type": "public"},
{"id": "strict", "name": "strict", "type": "public"},
]
for prompt in prompts:
list_prompts.append(
{
"id": str(prompt["id"]),
"id": str(prompt["_id"]),
"name": prompt["name"],
"type": "private",
}
@@ -114,12 +127,9 @@ class GetSinglePrompt(Resource):
) as f:
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
with db_readonly() as conn:
prompt = PromptsRepository(conn).get_any(prompt_id, user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}), 404
)
prompt = prompts_collection.find_one(
{"_id": ObjectId(prompt_id), "user": user}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -146,15 +156,11 @@ class DeletePrompt(Resource):
if missing_fields:
return missing_fields
try:
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.delete(str(prompt["id"]), user)
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -187,15 +193,16 @@ class UpdatePrompt(Resource):
if missing_fields:
return missing_fields
try:
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.update(str(prompt["id"]), user, data["name"], data["content"])
prompts_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
pid, u, n, c,
),
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
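
Review note: the prompt routes write to MongoDB first and then mirror the change through `dual_write`, passing a lambda whose default arguments capture the request values. The implementation of `dual_write` is not part of this diff, so the sketch below is a hypothetical stand-in that assumes the mirror is best-effort (failures are logged and swallowed). `dual_write_sketch`, `FakePromptsRepo`, and `make_mirror_callbacks` are illustrative names only.

```python
import logging
from typing import Any, Callable, List

logger = logging.getLogger(__name__)


def dual_write_sketch(
    repo_factory: Callable[[], Any], fn: Callable[[Any], None]
) -> bool:
    """Run ``fn`` against a secondary store; log and swallow any failure.

    Assumption: the real ``dual_write`` may behave differently.
    """
    try:
        fn(repo_factory())
        return True
    except Exception:
        logger.exception("secondary write failed; primary result is kept")
        return False


class FakePromptsRepo:
    """Illustrative in-memory repository, not ``PromptsRepository``."""

    def __init__(self) -> None:
        self.deleted: List[tuple] = []

    def delete_by_legacy_id(self, legacy_id: str, user: str) -> None:
        self.deleted.append((legacy_id, user))


def make_mirror_callbacks(
    requests_data: List[dict],
) -> List[Callable[[Any], None]]:
    """Build one callback per request, binding values via default arguments."""
    callbacks = []
    for data in requests_data:
        # ``pid=...`` and ``u=...`` are evaluated now, so each callback keeps
        # its own values even though ``data`` is rebound on the next pass.
        callbacks.append(
            lambda repo, pid=data["id"], u=data["user"]: repo.delete_by_legacy_id(
                pid, u
            )
        )
    return callbacks
```

The default-argument form matters whenever the callback may run after the enclosing variables change; a plain closure over `data` would see only the last value.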

View File

@@ -1,292 +0,0 @@
"""Reconciler tick: sweep stuck rows and escalate to terminal status + alert."""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional, TYPE_CHECKING
from sqlalchemy import Connection
from application.api.user.idempotency import MAX_TASK_ATTEMPTS
from application.core.settings import settings
from application.storage.db.engine import get_engine
from application.storage.db.repositories.reconciliation import (
ReconciliationRepository,
)
from application.storage.db.repositories.stack_logs import StackLogsRepository
if TYPE_CHECKING:
from application.storage.db.repositories.schedules import SchedulesRepository
logger = logging.getLogger(__name__)
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3
def run_reconciliation() -> Dict[str, Any]:
"""Single tick of the reconciler. Five sweeps, FOR UPDATE SKIP LOCKED.
Stuck ``executed`` tool calls always flip to ``failed`` — operators
handle cleanup manually via the structured alert. The side effect is
assumed to have committed; no automated rollback is attempted.
Stuck ``task_dedup`` rows (lease expired AND attempts >= max)
promote to ``failed`` so a same-key retry can re-claim instead of
sitting in ``pending`` until 24 h TTL.
"""
if not settings.POSTGRES_URI:
return {
"messages_failed": 0,
"tool_calls_failed": 0,
"skipped": "POSTGRES_URI not set",
}
engine = get_engine()
summary = {
"messages_failed": 0,
"tool_calls_failed": 0,
"ingests_stalled": 0,
"idempotency_pending_failed": 0,
"schedule_runs_failed": 0,
}
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for msg in repo.find_and_lock_stuck_messages():
new_count = repo.increment_message_reconcile_attempts(msg["id"])
if new_count >= MAX_MESSAGE_RECONCILE_ATTEMPTS:
repo.mark_message_failed(
msg["id"],
error=(
"reconciler: stuck in pending/streaming for >5 min "
f"after {new_count} attempts"
),
)
summary["messages_failed"] += 1
_emit_alert(
conn,
name="reconciler_message_failed",
user_id=msg.get("user_id"),
detail={
"message_id": str(msg["id"]),
"attempts": new_count,
},
)
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_proposed_tool_calls():
repo.mark_tool_call_failed(
row["call_id"],
error=(
"reconciler: stuck in 'proposed' for >5 min; "
"side effect status unknown"
),
)
summary["tool_calls_failed"] += 1
_emit_alert(
conn,
name="reconciler_tool_call_failed_proposed",
user_id=None,
detail={
"call_id": row["call_id"],
"tool_name": row.get("tool_name"),
},
)
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_executed_tool_calls():
repo.mark_tool_call_failed(
row["call_id"],
error=(
"reconciler: executed-not-confirmed; side effect "
"assumed committed, manual cleanup required"
),
)
summary["tool_calls_failed"] += 1
_emit_alert(
conn,
name="reconciler_tool_call_failed_executed",
user_id=None,
detail={
"call_id": row["call_id"],
"tool_name": row.get("tool_name"),
"action_name": row.get("action_name"),
},
)
# Q4: ingest checkpoints whose heartbeat has gone silent. Each is
# escalated to terminal ``status='stalled'`` and alerted once — no
# worker kill, no rollback of the partial embed. The 'stalled' flag
# ends the re-alert loop and drives the "indexing failed" badge the
# sources list derives from this row.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_stalled_ingests():
summary["ingests_stalled"] += 1
_emit_alert(
conn,
name="reconciler_ingest_stalled",
user_id=None,
detail={
"source_id": str(row.get("source_id")),
"embedded_chunks": row.get("embedded_chunks"),
"total_chunks": row.get("total_chunks"),
"last_updated": str(row.get("last_updated")),
},
)
repo.mark_ingest_stalled(str(row["source_id"]))
# Q5: idempotency rows whose lease expired with attempts exhausted.
# The wrapper's poison-loop guard normally finalises these, but if
# the wrapper itself died mid-task (worker SIGKILL, OOM during
# heartbeat) the row sits in ``pending`` blocking same-key retries
# via ``_lookup_completed`` returning None for the whole 24 h TTL.
# Promote to ``failed`` so a retry can re-claim and either resume
# or fail loudly.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_stuck_idempotency_pending(
max_attempts=MAX_TASK_ATTEMPTS,
):
error_msg = (
"reconciler: idempotency lease expired with attempts "
f"({row['attempt_count']}) >= {MAX_TASK_ATTEMPTS}; "
"task abandoned"
)
repo.mark_idempotency_pending_failed(
row["idempotency_key"], error=error_msg,
)
summary["idempotency_pending_failed"] += 1
_emit_alert(
conn,
name="reconciler_idempotency_pending_failed",
user_id=None,
detail={
"idempotency_key": row["idempotency_key"],
"task_name": row.get("task_name"),
"task_id": row.get("task_id"),
"attempts": row.get("attempt_count"),
},
)
# Q6: scheduler runs stuck in 'running' past the soft-time-limit window.
from application.storage.db.repositories.schedule_runs import (
ScheduleRunsRepository,
)
from application.storage.db.repositories.schedules import SchedulesRepository
stuck_age = max(
15, int(settings.SCHEDULE_RUN_TIMEOUT // 60) + 5,
)
with engine.begin() as conn:
runs_repo = ScheduleRunsRepository(conn)
schedules_repo = SchedulesRepository(conn)
for run in runs_repo.list_stuck_running(age_minutes=stuck_age):
runs_repo.update(
run["id"],
{
"status": "timeout",
"finished_at": datetime.now(timezone.utc),
"error_type": "timeout",
"error": (
"reconciler: schedule_run stuck in 'running' past "
f"{stuck_age} min"
),
},
)
schedules_repo.bump_failure_count(str(run["schedule_id"]))
_terminal_flip_once_schedule(
schedules_repo, str(run["schedule_id"]),
)
summary["schedule_runs_failed"] += 1
_emit_alert(
conn,
name="reconciler_schedule_run_timeout",
user_id=run.get("user_id"),
detail={
"run_id": str(run["id"]),
"schedule_id": str(run["schedule_id"]),
},
)
# Q7: scheduler runs orphaned in 'pending' — dispatcher committed but
# apply_async failed (broker outage / crash mid-dispatch).
with engine.begin() as conn:
runs_repo = ScheduleRunsRepository(conn)
schedules_repo = SchedulesRepository(conn)
for run in runs_repo.list_stuck_pending(age_minutes=stuck_age):
runs_repo.update(
run["id"],
{
"status": "failed",
"finished_at": datetime.now(timezone.utc),
"error_type": "internal",
"error": (
"reconciler: schedule_run stuck in 'pending' past "
f"{stuck_age} min (worker_never_started)"
),
},
)
schedules_repo.bump_failure_count(str(run["schedule_id"]))
_terminal_flip_once_schedule(
schedules_repo, str(run["schedule_id"]),
)
summary["schedule_runs_failed"] += 1
_emit_alert(
conn,
name="reconciler_schedule_run_pending",
user_id=run.get("user_id"),
detail={
"run_id": str(run["id"]),
"schedule_id": str(run["schedule_id"]),
},
)
return summary
def _terminal_flip_once_schedule(
    schedules_repo: "SchedulesRepository", schedule_id: str,
) -> None:
    """Flip a once-schedule to 'completed' after its run terminates.

    Recurring schedules keep firing; once-schedules would otherwise read
    'active forever' since next_run_at is already NULL.
    """
    schedule = schedules_repo.get_internal(schedule_id)
    if schedule is None or schedule.get("trigger_type") != "once":
        return
    if schedule.get("status") in {"completed", "cancelled"}:
        return
    schedules_repo.update_internal(
        schedule_id, {"status": "completed", "next_run_at": None},
    )
def _emit_alert(
conn: Connection,
*,
name: str,
user_id: Optional[str],
detail: Dict[str, Any],
) -> None:
"""Structured ``logger.error`` plus a ``stack_logs`` row for operators."""
extra = {"alert": name, **detail}
logger.error("reconciler alert: %s", name, extra=extra)
try:
StackLogsRepository(conn).insert(
activity_id=str(uuid.uuid4()),
endpoint="reconciliation_worker",
level="ERROR",
user_id=user_id,
query=name,
stacks=[extra],
)
except Exception:
logger.exception("reconciler: failed to write stack_logs row for %s", name)
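
The two thresholds the reconciler relies on can be sketched without a database. This is a minimal illustration, not the repository code: `stuck_age_minutes` mirrors the `max(15, timeout // 60 + 5)` expression used by the schedule-run sweeps, and `sweep_messages` is a hypothetical in-memory stand-in for the message sweep, where a plain dict replaces the locked rows that `ReconciliationRepository` would return.

```python
from typing import Dict, List

# Same value as the module constant above.
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3


def stuck_age_minutes(run_timeout_seconds: int) -> int:
    """Run timeout in whole minutes plus a 5-minute margin, never below 15."""
    return max(15, int(run_timeout_seconds // 60) + 5)


def sweep_messages(attempts: Dict[str, int]) -> List[str]:
    """Hypothetical stand-in for one tick of the stuck-message sweep.

    ``attempts`` maps a message id to its prior reconcile attempts. Each tick
    bumps the counter and reports the ids that reached the cap.
    """
    failed: List[str] = []
    for message_id in list(attempts):
        attempts[message_id] += 1
        if attempts[message_id] >= MAX_MESSAGE_RECONCILE_ATTEMPTS:
            failed.append(message_id)
    return failed


if __name__ == "__main__":
    state = {"a": 0, "b": 2}
    print(stuck_age_minutes(1800), sweep_messages(state), state)
```

A message therefore survives two ticks and is failed on the third, while a schedule run is only considered stuck once it has outlived the configured timeout by a margin.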


@@ -11,7 +11,6 @@ from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .schedules import schedules_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
@@ -41,9 +40,6 @@ api.add_namespace(agents_folders_ns)
# Prompts
api.add_namespace(prompts_ns)
# Schedules
api.add_namespace(schedules_ns)
# Sharing
api.add_namespace(sharing_ns)


@@ -1,186 +0,0 @@
"""Schedule dispatcher: poll Postgres, claim due rows under FOR UPDATE SKIP LOCKED,
advance next_run_at atomically with the run claim, then enqueue.
Per-schedule IANA tz semantics (croniter+zoneinfo) outside Celery's app-wide tz,
plus Postgres-native dedup avoid Redis visibility_timeout double-fires.
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from application.agents.scheduler_utils import next_cron_run
from application.core.settings import settings
from application.storage.db.engine import get_engine
from application.storage.db.repositories.schedule_runs import (
ScheduleRunsRepository,
)
from application.storage.db.repositories.schedules import SchedulesRepository
logger = logging.getLogger(__name__)
def _normalize_dt(value: Any) -> Optional[datetime]:
    """Accept a datetime / ISO string / None and return a tz-aware UTC dt."""
    if value is None:
        return None
    if isinstance(value, datetime):
        return value.astimezone(timezone.utc) if value.tzinfo else (
            value.replace(tzinfo=timezone.utc)
        )
    if isinstance(value, str):
        try:
            parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None
        return parsed.astimezone(timezone.utc) if parsed.tzinfo else (
            parsed.replace(tzinfo=timezone.utc)
        )
    return None
def _compute_next(
    schedule: Dict[str, Any],
    *,
    after: datetime,
) -> Optional[datetime]:
    """Return the next ``next_run_at`` for a recurring schedule.

    Returns None when the schedule has no cron or the next tick falls after
    ``end_at``.
    """
    cron = schedule.get("cron")
    if not cron:
        return None
    end_at = _normalize_dt(schedule.get("end_at"))
    candidate = next_cron_run(cron, schedule.get("timezone"), after=after)
    if end_at is not None and candidate > end_at:
        return None
    return candidate
def dispatch_due_runs() -> Dict[str, int]:
"""One dispatcher tick; returns counts for schedule_syncs-style logging."""
if not settings.POSTGRES_URI:
return {"enqueued": 0, "skipped": 0, "advanced": 0}
from application.api.user.tasks import execute_scheduled_run
now = datetime.now(timezone.utc)
grace = timedelta(seconds=max(0, settings.SCHEDULE_MISFIRE_GRACE))
engine = get_engine()
counts = {"enqueued": 0, "skipped": 0, "advanced": 0}
enqueue_args: List[str] = []
with engine.begin() as conn:
schedules_repo = SchedulesRepository(conn)
runs_repo = ScheduleRunsRepository(conn)
for schedule in schedules_repo.list_due():
scheduled_for = _normalize_dt(schedule.get("next_run_at"))
if scheduled_for is None:
continue
trigger_type = schedule.get("trigger_type")
agent_id_raw = schedule.get("agent_id")
agent_id = str(agent_id_raw) if agent_id_raw else None
# Misfire grace applies to recurring only — once-tasks fire late, not vanish.
if (
trigger_type == "recurring"
and grace > timedelta(0)
and (now - scheduled_for) > grace
):
runs_repo.record_skipped(
str(schedule["id"]),
schedule["user_id"],
agent_id,
scheduled_for,
error_type="missed",
error="misfire grace exceeded",
)
counts["skipped"] += 1
nxt = _compute_next(schedule, after=now)
if nxt is None:
schedules_repo.update_internal(
str(schedule["id"]),
{"status": "completed", "next_run_at": None,
"last_run_at": now},
)
else:
schedules_repo.update_internal(
str(schedule["id"]),
{"next_run_at": nxt, "last_run_at": now},
)
counts["advanced"] += 1
continue
# Overlap guard: never enqueue while a previous run is active.
if runs_repo.has_active_run(str(schedule["id"])):
runs_repo.record_skipped(
str(schedule["id"]),
schedule["user_id"],
agent_id,
scheduled_for,
error_type="overlap",
error="previous run still active",
)
counts["skipped"] += 1
if trigger_type == "recurring":
nxt = _compute_next(schedule, after=scheduled_for)
schedules_repo.update_internal(
str(schedule["id"]),
{"next_run_at": nxt, "last_run_at": now},
)
else:
# Once: null next_run_at so we don't re-pick; the in-flight
# run will terminal-flip the schedule when it finishes.
schedules_repo.update_internal(
str(schedule["id"]),
{"next_run_at": None, "last_run_at": now},
)
continue
# Dedup primitive: two racing dispatchers see exactly one row.
run = runs_repo.record_pending(
str(schedule["id"]),
schedule["user_id"],
agent_id,
scheduled_for,
trigger_source="cron",
)
if run is None:
counts["skipped"] += 1
else:
enqueue_args.append(str(run["id"]))
counts["enqueued"] += 1
# Advance: recurring picks next tick, once nulls next_run_at
# (worker terminal-flips status on completion).
if trigger_type == "recurring":
nxt = _compute_next(schedule, after=scheduled_for)
if nxt is None:
schedules_repo.update_internal(
str(schedule["id"]),
{"status": "completed", "next_run_at": None,
"last_run_at": now},
)
else:
schedules_repo.update_internal(
str(schedule["id"]),
{"next_run_at": nxt, "last_run_at": now},
)
else:
schedules_repo.update_internal(
str(schedule["id"]),
{"next_run_at": None, "last_run_at": now},
)
counts["advanced"] += 1
# Enqueue after commit so the worker sees the schedule_runs row on pick-up.
for run_id in enqueue_args:
try:
execute_scheduled_run.apply_async(args=[run_id], queue="docsgpt")
except Exception:
logger.exception(
"dispatcher: failed to enqueue execute_scheduled_run for %s",
run_id,
)
return counts
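
The per-schedule branching in `dispatch_due_runs` reduces to a three-way decision, sketched here as a pure function. `classify_due` and its return labels are hypothetical names introduced for illustration; the conditions are copied from the dispatcher above, with the repository call `has_active_run` replaced by a boolean argument.

```python
from datetime import datetime, timedelta, timezone


def classify_due(
    trigger_type: str,
    scheduled_for: datetime,
    now: datetime,
    grace: timedelta,
    has_active_run: bool,
) -> str:
    """Return 'missed', 'overlap', or 'enqueue' for one due schedule."""
    # Misfire grace only applies to recurring schedules; a zero grace disables it.
    if (
        trigger_type == "recurring"
        and grace > timedelta(0)
        and (now - scheduled_for) > grace
    ):
        return "missed"
    # Overlap guard: skip while a previous run is still active.
    if has_active_run:
        return "overlap"
    return "enqueue"


if __name__ == "__main__":
    now = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)
    late = now - timedelta(minutes=10)
    print(classify_due("recurring", late, now, timedelta(seconds=60), False))
    print(classify_due("once", late, now, timedelta(seconds=60), False))
```

The order matters: a recurring schedule that is both late and overlapping is recorded as missed, because the grace check runs first.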


@@ -1,433 +0,0 @@
"""Body of ``execute_scheduled_run`` — runs a single agent execution.
Not a DURABLE_TASK: agent runs have side effects (messages, CRM writes)
and blind auto-retry would double them. Failures after agent.gen starts
are terminal and recorded; only the pre-start load is retry-safe.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy import text as sql_text
from application.agents.headless_runner import run_agent_headless
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.storage.db.base_repository import row_to_dict
from application.storage.db.engine import get_engine
from application.storage.db.repositories.conversations import (
ConversationsRepository,
)
from application.storage.db.repositories.schedule_runs import (
ScheduleRunsRepository,
)
from application.storage.db.repositories.schedules import SchedulesRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
logger = logging.getLogger(__name__)
# Cap output verbatim in the run log; beyond the cap we keep the head and stamp output_truncated.
_OUTPUT_CAP_CHARS = 24_000
def _agent_config_for_schedule(schedule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Resolve the agent row (agent-bound) or build an ephemeral classic config.
For agentless schedules (``agent_id IS NULL``), the worker constructs an
in-memory agent shape carrying just enough fields for ``run_agent_headless``:
classic agent type, system-default retriever/chunks/prompt, no source, and
the optional ``model_id`` override. The runtime toolset is rebuilt by
``ToolExecutor`` at fire time (current ``user_tools`` + non-disabled,
non-headless-excluded defaults), so a snapshot here would be dead code.
"""
if schedule.get("agent_id"):
engine = get_engine()
with engine.connect() as conn:
row = conn.execute(
sql_text("SELECT * FROM agents WHERE id = CAST(:id AS uuid)"),
{"id": str(schedule["agent_id"])},
).fetchone()
return row_to_dict(row) if row is not None else None
return _ephemeral_agent_for_agentless(schedule)
def _ephemeral_agent_for_agentless(
schedule: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
"""Build an agent-shaped config for a schedule with no parent agent."""
# ``agent_config["tools"]`` is intentionally omitted: ``run_agent_headless``
# never reads it. The runtime toolset is rebuilt by
# ``ToolExecutor._get_user_tools(owner)`` at fire time — same dereference
# the agent-bound path uses, so a tool added/disabled after creation is
# reflected. Headless mode there filters chat-only tools (``scheduler``).
user_id = schedule.get("user_id")
if not user_id:
return None
return {
"id": None,
"user_id": user_id,
"agent_type": "classic",
"retriever": "classic",
"chunks": 2,
"prompt_id": "default",
"source_id": None,
"default_model_id": schedule.get("model_id") or "",
}
def _load_chat_history(schedule: Dict[str, Any]) -> list:
"""Originating conversation history (one-time only; recurring has none)."""
origin = schedule.get("origin_conversation_id")
if not origin or schedule.get("trigger_type") != "once":
return []
user_id = schedule.get("user_id")
if not user_id:
return []
try:
engine = get_engine()
with engine.connect() as conn:
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
if conv is None:
return []
messages = ConversationsRepository(conn).get_messages(str(conv["id"]))
except Exception:
logger.exception("scheduler: failed loading chat history")
return []
history: list = []
for msg in messages:
if msg.get("prompt") and msg.get("response"):
history.append({
"prompt": msg["prompt"],
"response": msg["response"],
})
return history
def _publish_run_event(
event_type: str, run: Dict[str, Any], schedule: Dict[str, Any], **extra: Any
) -> None:
"""Best-effort SSE publish for a scheduler run state transition."""
user_id = run.get("user_id") or schedule.get("user_id")
if not user_id:
return
agent_id_raw = schedule.get("agent_id")
payload = {
"run_id": str(run["id"]),
"schedule_id": str(schedule["id"]),
"agent_id": str(agent_id_raw) if agent_id_raw else None,
"trigger_type": schedule.get("trigger_type"),
"status": run.get("status"),
**extra,
}
try:
publish_user_event(
user_id,
event_type,
payload,
scope={"kind": "schedule", "id": str(schedule["id"])},
)
except Exception:
logger.exception(
"scheduler: SSE publish failed event=%s run=%s",
event_type, run.get("id"),
)
def _publish_message_appended(
user_id: str,
conversation_id: str,
message: Dict[str, Any],
schedule_id: str,
run_id: str,
) -> None:
"""SSE message-appended event for a one-time run's chat turn."""
try:
publish_user_event(
user_id,
"schedule.message.appended",
{
"conversation_id": str(conversation_id),
"message_id": str(message["id"]),
"schedule_id": str(schedule_id),
"run_id": str(run_id),
"position": int(message.get("position", 0)),
},
scope={"kind": "conversation", "id": str(conversation_id)},
)
except Exception:
logger.exception(
"scheduler: message.appended publish failed run=%s", run_id,
)
def _append_one_time_turn(
schedule: Dict[str, Any],
run: Dict[str, Any],
outcome: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
"""Insert an assistant turn in the originating conversation (once only)."""
origin = schedule.get("origin_conversation_id")
if not origin:
return None
engine = get_engine()
user_id = schedule.get("user_id")
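metadata = {
"scheduled": True,
"schedule_id": str(schedule["id"]),
"run_id": str(run["id"]),
# Keep the timestamp whether the row carries a datetime or an ISO string.
"scheduled_run_at": (
run["scheduled_for"].isoformat()
if isinstance(run.get("scheduled_for"), datetime)
else run.get("scheduled_for")
),
}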
with engine.begin() as conn:
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
if conv is None:
return None
message = ConversationsRepository(conn).append_message(
str(conv["id"]),
{
"prompt": schedule.get("instruction") or "",
"response": outcome.get("answer") or "",
"thought": outcome.get("thought") or "",
"sources": outcome.get("sources") or [],
"tool_calls": outcome.get("tool_calls") or [],
"model_id": outcome.get("model_id"),
"metadata": metadata,
},
)
return message
def execute_scheduled_run_body(run_id: str, celery_task_id: Optional[str]) -> Dict[str, Any]:
"""Execute one scheduled run by id; returns a result dict for tracing."""
if not settings.POSTGRES_URI:
return {"status": "skipped", "reason": "POSTGRES_URI not set"}
engine = get_engine()
with engine.connect() as conn:
run = ScheduleRunsRepository(conn).get_internal(run_id)
if run is None:
return {"status": "skipped", "reason": "run not found"}
schedule = SchedulesRepository(conn).get_internal(str(run["schedule_id"]))
if schedule is None:
return {"status": "skipped", "reason": "schedule not found"}
# Only pending runs execute. A cancelled or completed schedule also blocks
# the run unless it was triggered manually (run-now).
if run.get("status") != "pending":
return {"status": "skipped", "reason": f"run status={run.get('status')}"}
if schedule.get("status") in {"cancelled", "completed"} and run.get(
"trigger_source"
) != "manual":
with engine.begin() as conn:
ScheduleRunsRepository(conn).update(
run_id,
{
"status": "skipped",
"finished_at": datetime.now(timezone.utc),
"error_type": "internal",
"error": "schedule no longer active",
},
)
return {"status": "skipped", "reason": "schedule terminal"}
agent_config = _agent_config_for_schedule(schedule)
if agent_config is None:
with engine.begin() as conn:
updated = ScheduleRunsRepository(conn).update(
run_id,
{
"status": "failed",
"finished_at": datetime.now(timezone.utc),
"error_type": "internal",
"error": "agent missing",
},
)
SchedulesRepository(conn).bump_failure_count(str(schedule["id"]))
_publish_run_event("schedule.run.failed", updated or run, schedule,
error="agent missing")
return {"status": "failed", "reason": "agent missing"}
with engine.begin() as conn:
if not ScheduleRunsRepository(conn).mark_running(run_id, celery_task_id):
return {"status": "skipped", "reason": "lost race to mark_running"}
started = datetime.now(timezone.utc)
instruction = schedule.get("instruction") or ""
allowlist = schedule.get("tool_allowlist") or []
chat_history = _load_chat_history(schedule)
outcome: Dict[str, Any]
error_type: Optional[str] = None
error_text: Optional[str] = None
timed_out = False
try:
outcome = run_agent_headless(
agent_config,
instruction,
tool_allowlist=allowlist,
model_id_override=schedule.get("model_id"),
endpoint="schedule",
conversation_id=schedule.get("origin_conversation_id"),
chat_history=chat_history,
)
except SoftTimeLimitExceeded:
timed_out = True
outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
error_type = "timeout"
error_text = "run exceeded soft time limit"
except Exception as exc:
outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
error_type = "agent_error"
error_text = str(exc)
logger.exception("scheduler: agent run failed run=%s", run_id)
finished = datetime.now(timezone.utc)
# Headless denial with no usable output → tool_not_allowed.
if (
error_type is None
and (outcome.get("denied") or [])
and not (outcome.get("answer") or "").strip()
):
error_type = "tool_not_allowed"
error_text = "headless allowlist blocked required tool"
prompt_tokens = int(outcome.get("prompt_tokens", 0) or 0)
generated_tokens = int(outcome.get("generated_tokens", 0) or 0)
used_tokens = prompt_tokens + generated_tokens
if (
schedule.get("token_budget") is not None
and int(schedule["token_budget"]) > 0
and used_tokens > int(schedule["token_budget"])
):
error_type = "budget_exceeded"
error_text = (
f"used {used_tokens} tokens exceeds budget "
f"{schedule['token_budget']}"
)
answer = outcome.get("answer") or ""
truncated = False
if len(answer) > _OUTPUT_CAP_CHARS:
answer = answer[:_OUTPUT_CAP_CHARS]
truncated = True
new_status = (
"timeout" if timed_out else ("failed" if error_type else "success")
)
with engine.begin() as conn:
update_fields: Dict[str, Any] = {
"status": new_status,
"started_at": started,
"finished_at": finished,
"output": answer or None,
"output_truncated": truncated,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
}
if error_type:
update_fields["error_type"] = error_type
update_fields["error"] = error_text
updated_run = ScheduleRunsRepository(conn).update(run_id, update_fields)
if used_tokens > 0:
agent_id_raw = schedule.get("agent_id")
try:
TokenUsageRepository(conn).insert(
user_id=schedule.get("user_id"),
api_key=None,
prompt_tokens=prompt_tokens,
generated_tokens=generated_tokens,
timestamp=finished,
agent_id=str(agent_id_raw) if agent_id_raw else None,
source="schedule",
request_id=str(run_id),
)
except Exception:
logger.exception(
"scheduler: token_usage insert failed run=%s", run_id,
)
schedules_repo = SchedulesRepository(conn)
autopaused = False
if new_status == "success":
schedules_repo.reset_failure_count(str(schedule["id"]))
elif new_status in ("failed", "timeout"):
count = schedules_repo.bump_failure_count(str(schedule["id"]))
if (
settings.SCHEDULE_AUTOPAUSE_FAILURES > 0
and count >= settings.SCHEDULE_AUTOPAUSE_FAILURES
and schedule.get("trigger_type") == "recurring"
):
autopaused = schedules_repo.autopause(str(schedule["id"]))
# Once: terminal-flip on cron-fired runs only; manual runs on a
# still-active once-schedule leave the future cadence intact.
if (
schedule.get("trigger_type") == "once"
and run.get("trigger_source") != "manual"
and schedule.get("status") == "active"
):
schedules_repo.update_internal(
str(schedule["id"]),
{"status": "completed", "next_run_at": None},
)
appended: Optional[Dict[str, Any]] = None
if (
schedule.get("trigger_type") == "once"
and new_status == "success"
and schedule.get("origin_conversation_id")
):
try:
appended = _append_one_time_turn(schedule, updated_run or run, outcome)
except Exception:
logger.exception(
"scheduler: append turn failed run=%s", run_id,
)
if appended is not None:
with engine.begin() as conn:
ScheduleRunsRepository(conn).update(
run_id,
{
"conversation_id": str(appended["conversation_id"]),
"message_id": str(appended["id"]),
},
)
_publish_message_appended(
schedule.get("user_id"),
str(appended["conversation_id"]),
appended,
str(schedule["id"]),
run_id,
)
if new_status == "success":
_publish_run_event("schedule.run.completed", updated_run or run, schedule)
else:
_publish_run_event(
"schedule.run.failed",
updated_run or run,
schedule,
error_type=error_type,
error=error_text,
)
if autopaused:
_publish_run_event(
"schedule.autopaused",
updated_run or run,
schedule,
consecutive_failure_count=settings.SCHEDULE_AUTOPAUSE_FAILURES,
)
return {
"status": new_status,
"run_id": run_id,
"error_type": error_type,
}
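
How `execute_scheduled_run_body` turns an agent outcome into a stored status can be sketched as a pure function. `classify_outcome` is a hypothetical helper written for illustration; the rules and their order follow the body above, with `error_type` standing in for whatever the `try`/`except` around `run_agent_headless` set.

```python
from typing import List, Optional, Tuple

# Same cap as _OUTPUT_CAP_CHARS above.
OUTPUT_CAP_CHARS = 24_000


def classify_outcome(
    answer: str,
    denied: List[str],
    used_tokens: int,
    token_budget: Optional[int],
    error_type: Optional[str] = None,
) -> Tuple[str, Optional[str], str, bool]:
    """Return (status, error_type, stored_answer, truncated)."""
    timed_out = error_type == "timeout"
    # A denied tool with no usable answer counts as a failure.
    if error_type is None and denied and not answer.strip():
        error_type = "tool_not_allowed"
    # The budget check runs last, so it overwrites any earlier error_type.
    if token_budget is not None and token_budget > 0 and used_tokens > token_budget:
        error_type = "budget_exceeded"
    truncated = len(answer) > OUTPUT_CAP_CHARS
    if truncated:
        answer = answer[:OUTPUT_CAP_CHARS]
    status = "timeout" if timed_out else ("failed" if error_type else "success")
    return status, error_type, answer, truncated


if __name__ == "__main__":
    print(classify_outcome("done", [], 120, 100)[:2])
```

Two details follow from the ordering: a budget of `0` or `None` disables the check, and a run that timed out keeps the `timeout` status even if the budget rule later replaces its `error_type`.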


@@ -1,5 +0,0 @@
"""Schedules module."""
from .routes import schedules_ns
__all__ = ["schedules_ns"]


@@ -1,550 +0,0 @@
"""Schedules REST API (owner-scoped via request.decoded_token)."""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.agents.scheduler_utils import (
ScheduleValidationError,
clamp_once_horizon,
cron_interval_seconds,
next_cron_run,
parse_cron,
parse_run_at,
resolve_timezone,
)
from application.api import api
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.schedule_runs import (
ScheduleRunsRepository,
)
from application.storage.db.repositories.schedules import SchedulesRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
schedules_ns = Namespace(
"schedules", description="Agent schedule management", path="/api",
)
def _ok(data: Any, status: int = 200):
    return make_response(jsonify(data), status)


def _err(message: str, status: int = 400):
    return make_response(jsonify({"success": False, "message": message}), status)


def _format_schedule(row: Dict[str, Any]) -> Dict[str, Any]:
    """Render a schedule row for the API (id-as-string + ISO timestamps)."""
    if not row:
        return {}
    out = dict(row)
    for key in (
        "id", "agent_id", "origin_conversation_id",
    ):
        if out.get(key) is not None:
            out[key] = str(out[key])
    out.pop("_id", None)  # drop dual-id legacy mirror
    return out


def _format_run(row: Dict[str, Any]) -> Dict[str, Any]:
    """Render a schedule_run row for the API."""
    if not row:
        return {}
    out = dict(row)
    for key in (
        "id", "schedule_id", "agent_id", "conversation_id", "message_id",
    ):
        if out.get(key) is not None:
            out[key] = str(out[key])
    out.pop("_id", None)
    return out


def _agent_owned(agent_id: str, user_id: str) -> Optional[Dict[str, Any]]:
    if not looks_like_uuid(str(agent_id)):
        return None
    with db_readonly() as conn:
        return AgentsRepository(conn).get_any(agent_id, user_id)


def _user_id() -> Optional[str]:
    decoded = getattr(request, "decoded_token", None)
    if not decoded:
        return None
    return decoded.get("sub")
@schedules_ns.route("/agents/<string:agent_id>/schedules")
class AgentSchedules(Resource):
@api.doc(description="List schedules for an agent (recurring + one-time).")
def get(self, agent_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
agent = _agent_owned(agent_id, user_id)
if agent is None:
return _err("agent not found", 404)
try:
with db_readonly() as conn:
rows = SchedulesRepository(conn).list_for_agent(
str(agent["id"]), user_id,
)
except Exception as exc:
current_app.logger.error("list schedules failed: %s", exc, exc_info=True)
return _err("internal error", 500)
return _ok({"schedules": [_format_schedule(r) for r in rows]})
create_model = api.model(
"ScheduleCreate",
{
"instruction": fields.String(required=True),
"trigger_type": fields.String(
required=False,
description="'recurring' (default) or 'once'",
),
"cron": fields.String(
required=False,
description="Required when trigger_type == 'recurring'",
),
"run_at": fields.String(
required=False,
description="ISO 8601 — required when trigger_type == 'once'",
),
"timezone": fields.String(required=False),
"name": fields.String(required=False),
"end_at": fields.String(required=False, description="ISO 8601"),
"tool_allowlist": fields.List(fields.String, required=False),
"model_id": fields.String(required=False),
"token_budget": fields.Integer(required=False),
},
)
@api.expect(create_model)
@api.doc(description="Create a schedule (recurring or one-time) for an agent.")
def post(self, agent_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
agent = _agent_owned(agent_id, user_id)
if agent is None:
return _err("agent not found", 404)
data = request.get_json(silent=True) or {}
instruction = (data.get("instruction") or "").strip()
tz_name = (data.get("timezone") or "UTC").strip() or "UTC"
trigger_type = (data.get("trigger_type") or "recurring").strip().lower()
if trigger_type not in ("recurring", "once"):
return _err("trigger_type must be 'recurring' or 'once'")
if not instruction:
return _err("instruction is required")
try:
resolve_timezone(tz_name)
except ScheduleValidationError as exc:
return _err(str(exc))
token_budget = data.get("token_budget")
if token_budget is not None:
try:
token_budget = int(token_budget)
if token_budget < 0:
raise ValueError
except (TypeError, ValueError):
return _err("token_budget must be a non-negative integer")
with db_readonly() as conn:
count = SchedulesRepository(conn).count_active_for_user(user_id)
if (
settings.SCHEDULE_MAX_PER_USER > 0
and count >= settings.SCHEDULE_MAX_PER_USER
):
return _err("max schedules per user reached", 429)
if trigger_type == "once":
run_at_raw = (data.get("run_at") or "").strip()
if not run_at_raw:
return _err("run_at is required for trigger_type 'once'")
try:
fire = parse_run_at(run_at_raw, tz_name)
clamp_once_horizon(
fire, settings.SCHEDULE_ONCE_MAX_HORIZON,
)
except ScheduleValidationError as exc:
return _err(str(exc))
try:
with db_session() as conn:
created = SchedulesRepository(conn).create(
user_id=user_id,
agent_id=str(agent["id"]),
trigger_type="once",
instruction=instruction,
run_at=fire,
next_run_at=fire,
timezone=tz_name,
name=(data.get("name") or "").strip() or None,
tool_allowlist=data.get("tool_allowlist") or [],
model_id=(data.get("model_id") or None),
token_budget=token_budget,
created_via="ui",
)
except Exception as exc:
current_app.logger.error(
"create one-time schedule failed: %s", exc, exc_info=True,
)
return _err("internal error", 500)
return _ok({"schedule": _format_schedule(created)}, status=201)
cron = (data.get("cron") or "").strip()
if not cron:
return _err("cron is required")
try:
parse_cron(cron)
except ScheduleValidationError as exc:
return _err(str(exc))
min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
if min_interval > 0:
try:
cadence = cron_interval_seconds(cron, tz_name)
except ScheduleValidationError as exc:
return _err(str(exc))
if cadence < min_interval:
return _err(
"cadence below minimum interval "
f"({cadence}s < {min_interval}s)",
)
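end_at = None
if data.get("end_at"):
try:
end_at = datetime.fromisoformat(
str(data["end_at"]).replace("Z", "+00:00"),
)
# A naive end_at is read as UTC; comparing naive with aware raises TypeError.
if end_at.tzinfo is None:
end_at = end_at.replace(tzinfo=timezone.utc)
except ValueError:
return _err("invalid end_at")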
try:
next_run = next_cron_run(cron, tz_name, after=datetime.now(timezone.utc))
except ScheduleValidationError as exc:
return _err(str(exc))
if end_at is not None and next_run > end_at:
return _err("end_at is before the first cron tick")
try:
with db_session() as conn:
created = SchedulesRepository(conn).create(
user_id=user_id,
agent_id=str(agent["id"]),
trigger_type="recurring",
instruction=instruction,
cron=cron,
timezone=tz_name,
next_run_at=next_run,
end_at=end_at,
name=(data.get("name") or "").strip() or None,
tool_allowlist=data.get("tool_allowlist") or [],
model_id=(data.get("model_id") or None),
token_budget=token_budget,
created_via="ui",
)
except Exception as exc:
current_app.logger.error(
"create schedule failed: %s", exc, exc_info=True,
)
return _err("internal error", 500)
return _ok({"schedule": _format_schedule(created)}, status=201)
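
The create handler above parses `end_at` with `datetime.fromisoformat`, which returns a naive datetime when the client omits a UTC offset, and a naive value cannot be compared with a timezone-aware `next_run`. A minimal sketch of a normalizing helper follows; `parse_end_at` is a hypothetical name, and reading naive input as UTC is an assumption borrowed from `_normalize_dt` in the dispatcher.

```python
from datetime import datetime, timezone
from typing import Optional


def parse_end_at(raw: Optional[str]) -> Optional[datetime]:
    """Parse an ISO 8601 string into an aware UTC datetime.

    Raises ValueError on malformed input so the caller can answer 400.
    """
    if not raw:
        return None
    # Older Python versions reject a trailing "Z", hence the replace.
    parsed = datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        # Assumption: a value without an offset is meant as UTC.
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


if __name__ == "__main__":
    print(parse_end_at("2026-05-01T12:00:00+02:00"))
```

With every `end_at` aware and in UTC, the `next_run > end_at` comparison is well defined regardless of how the client formatted the timestamp.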
@schedules_ns.route("/schedules/<string:schedule_id>")
class ScheduleResource(Resource):
@api.doc(description="Get schedule by id.")
def get(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
with db_readonly() as conn:
row = SchedulesRepository(conn).get(schedule_id, user_id)
if row is None:
return _err("schedule not found", 404)
return _ok({"schedule": _format_schedule(row)})
@api.doc(description="Edit a schedule's editable fields.")
def put(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
data = request.get_json(silent=True) or {}
fields_in: Dict[str, Any] = {}
if "instruction" in data:
inst = (data["instruction"] or "").strip()
if not inst:
return _err("instruction must not be empty")
fields_in["instruction"] = inst
if "cron" in data:
cron = (data["cron"] or "").strip()
try:
parse_cron(cron)
except ScheduleValidationError as exc:
return _err(str(exc))
fields_in["cron"] = cron
if "timezone" in data:
tz_name = (data["timezone"] or "UTC").strip() or "UTC"
try:
resolve_timezone(tz_name)
except ScheduleValidationError as exc:
return _err(str(exc))
fields_in["timezone"] = tz_name
if "tool_allowlist" in data:
fields_in["tool_allowlist"] = data["tool_allowlist"] or []
if "name" in data:
fields_in["name"] = (data["name"] or "").strip() or None
if "model_id" in data:
fields_in["model_id"] = (data["model_id"] or None)
if "token_budget" in data:
tb = data["token_budget"]
if tb is not None:
try:
tb = int(tb)
if tb < 0:
raise ValueError
except (TypeError, ValueError):
return _err("token_budget must be a non-negative integer")
fields_in["token_budget"] = tb
if "end_at" in data:
if data["end_at"]:
try:
fields_in["end_at"] = datetime.fromisoformat(
str(data["end_at"]).replace("Z", "+00:00"),
)
except ValueError:
return _err("invalid end_at")
else:
fields_in["end_at"] = None
# Recompute next_run_at when cron/tz changes.
with db_session() as conn:
existing = SchedulesRepository(conn).get(schedule_id, user_id)
if existing is None:
return _err("schedule not found", 404)
if (
("cron" in fields_in or "timezone" in fields_in)
and existing.get("trigger_type") == "recurring"
):
cron_eff = fields_in.get("cron") or existing.get("cron")
tz_eff = fields_in.get("timezone") or existing.get("timezone")
if cron_eff:
min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
if min_interval > 0:
try:
cadence = cron_interval_seconds(cron_eff, tz_eff)
except ScheduleValidationError as exc:
return _err(str(exc))
if cadence < min_interval:
return _err(
"cadence below minimum interval "
f"({cadence}s < {min_interval}s)",
)
try:
fields_in["next_run_at"] = next_cron_run(
cron_eff, tz_eff, after=datetime.now(timezone.utc),
)
except ScheduleValidationError as exc:
return _err(str(exc))
updated = SchedulesRepository(conn).update(
schedule_id, user_id, fields_in,
)
return _ok({"schedule": _format_schedule(updated or {})})
@api.doc(description="Pause / resume a schedule.")
def patch(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
data = request.get_json(silent=True) or {}
action = (data.get("action") or "").lower().strip()
if action not in {"pause", "resume"}:
return _err("action must be 'pause' or 'resume'")
with db_session() as conn:
existing = SchedulesRepository(conn).get(schedule_id, user_id)
if existing is None:
return _err("schedule not found", 404)
if existing.get("status") in ("cancelled", "completed"):
return _err("schedule is terminal", 409)
if action == "pause":
fields_in: Dict[str, Any] = {"status": "paused", "next_run_at": None}
else:
# Resume: recurring recomputes from now; once honours run_at if still future.
fields_in = {"status": "active"}
if existing.get("trigger_type") == "recurring":
try:
fields_in["next_run_at"] = next_cron_run(
existing["cron"],
existing["timezone"],
after=datetime.now(timezone.utc),
)
except ScheduleValidationError as exc:
return _err(str(exc))
else:
new_run_at = data.get("run_at")
if new_run_at:
try:
run_at_dt = datetime.fromisoformat(
str(new_run_at).replace("Z", "+00:00"),
)
except ValueError:
return _err("invalid run_at")
if run_at_dt <= datetime.now(timezone.utc):
return _err(
"run_at must be in the future to resume", 409,
)
fields_in["next_run_at"] = run_at_dt
fields_in["run_at"] = run_at_dt
else:
run_at = existing.get("run_at")
if run_at:
if isinstance(run_at, str):
try:
run_at_dt = datetime.fromisoformat(
run_at.replace("Z", "+00:00"),
)
except ValueError:
return _err("schedule run_at is invalid")
else:
run_at_dt = run_at
if run_at_dt <= datetime.now(timezone.utc):
return _err(
"the once schedule has elapsed; recreate "
"it or supply a new run_at",
409,
)
fields_in["next_run_at"] = run_at_dt
updated = SchedulesRepository(conn).update(
schedule_id, user_id, fields_in,
)
if action == "resume":
SchedulesRepository(conn).reset_failure_count(schedule_id)
return _ok({"schedule": _format_schedule(updated or {})})
@api.doc(description="Cancel / delete a schedule.")
def delete(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
with db_session() as conn:
ok = SchedulesRepository(conn).delete(schedule_id, user_id)
if not ok:
return _err("schedule not found", 404)
return _ok({"success": True})
@schedules_ns.route("/schedules/<string:schedule_id>/run")
class ScheduleRunNow(Resource):
@api.doc(description="Run a schedule immediately (trigger_source='manual').")
def post(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
# FOR UPDATE serializes concurrent Run-Now POSTs (timestamp-unique
# scheduled_for values would otherwise sneak past the unique index).
with db_session() as conn:
schedule = SchedulesRepository(conn).get_for_update(
schedule_id, user_id,
)
if schedule is None:
return _err("schedule not found", 404)
if schedule.get("status") == "cancelled":
return _err("schedule is cancelled", 409)
if ScheduleRunsRepository(conn).has_active_run(schedule_id):
return _err("a run is already in flight", 409)
scheduled_for = datetime.now(timezone.utc)
agent_id_raw = schedule.get("agent_id")
run = ScheduleRunsRepository(conn).record_pending(
schedule_id,
user_id,
str(agent_id_raw) if agent_id_raw else None,
scheduled_for,
trigger_source="manual",
)
if run is None:
return _err("could not claim run (concurrent dispatch)", 409)
# Import inside the handler to avoid a circular tasks <-> routes import.
try:
from application.api.user.tasks import execute_scheduled_run
execute_scheduled_run.apply_async(args=[str(run["id"])], queue="docsgpt")
except Exception as exc:
current_app.logger.error(
"run-now enqueue failed: %s", exc, exc_info=True,
)
return _err("enqueue failed", 500)
return _ok({"run": _format_run(run)}, status=202)
@schedules_ns.route("/schedules/<string:schedule_id>/runs")
class ScheduleRunList(Resource):
@api.doc(
description="Paginated run log for a schedule.",
params={"limit": "Page size (default 50)", "offset": "Page offset"},
)
def get(self, schedule_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id):
return _err("invalid schedule id", 400)
try:
limit = max(1, min(int(request.args.get("limit", 50)), 200))
except (TypeError, ValueError):
limit = 50
try:
offset = max(0, int(request.args.get("offset", 0)))
except (TypeError, ValueError):
offset = 0
with db_readonly() as conn:
schedule = SchedulesRepository(conn).get(schedule_id, user_id)
if schedule is None:
return _err("schedule not found", 404)
rows = ScheduleRunsRepository(conn).list_runs(
schedule_id, user_id, limit=limit, offset=offset,
)
return _ok(
{
"runs": [_format_run(r) for r in rows],
"limit": limit,
"offset": offset,
}
)
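Reviewer note: the `limit` / `offset` handling in `ScheduleRunList.get` parses a query-string value, falls back to a default on bad input, and clamps the result. A minimal sketch of that pattern as one helper follows; `clamp_int` is a hypothetical name used only for illustration and is not part of this changeset.

```python
def clamp_int(raw, default, lo, hi=None):
    """Parse raw as an int and clamp it to [lo, hi]; return default on bad input.

    Hypothetical helper, not present in the diff.
    """
    try:
        value = int(raw)
    except (TypeError, ValueError):
        # Missing or non-numeric query parameters fall back to the default.
        return default
    if hi is not None:
        value = min(value, hi)
    return max(lo, value)


# Mirrors the handler's bounds: page size in 1..200, offset never negative.
limit = clamp_int("500", default=50, lo=1, hi=200)
offset = clamp_int("-3", default=0, lo=0)
```

With this shape, an oversized `limit` is capped at the upper bound and a negative `offset` is raised to zero, the same as the inline `max(...)` / `min(...)` calls in the handler.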
@schedules_ns.route("/schedules/<string:schedule_id>/runs/<string:run_id>")
class ScheduleRunDetail(Resource):
@api.doc(description="Full output / error for a single run.")
def get(self, schedule_id, run_id):
user_id = _user_id()
if not user_id:
return _err("unauthorized", 401)
if not looks_like_uuid(schedule_id) or not looks_like_uuid(run_id):
return _err("invalid id", 400)
with db_readonly() as conn:
schedule = SchedulesRepository(conn).get(schedule_id, user_id)
if schedule is None:
return _err("schedule not found", 404)
run = ScheduleRunsRepository(conn).get(run_id, user_id)
if run is None or str(run.get("schedule_id")) != str(
schedule["id"]
):
return _err("run not found", 404)
return _ok({"run": _format_run(run)})
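Reviewer note: the handlers above parse `end_at` and `run_at` with `datetime.fromisoformat(str(value).replace("Z", "+00:00"))`, then compare the result against timezone-aware values such as `datetime.now(timezone.utc)`. If a client sends a timestamp without an offset, the parsed value is naive, and ordering comparisons between naive and aware datetimes raise `TypeError`. The sketch below shows one way to normalise input; `parse_iso_utc` is a hypothetical helper, and treating offset-less input as UTC is an assumption made here rather than behaviour confirmed by the diff.

```python
from datetime import datetime, timezone


def parse_iso_utc(raw):
    """Parse an ISO-8601 string into an aware datetime, or return None if invalid.

    Hypothetical helper: interpreting offset-less input as UTC is an
    assumption of this sketch, not something the handlers are known to do.
    """
    if not raw:
        return None
    try:
        # Same "Z" rewrite as the handlers; older Python versions reject a bare "Z".
        parsed = datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        # Attach UTC so later comparisons with aware datetimes cannot raise.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
```

A caller could then treat `None` as the "invalid end_at" / "invalid run_at" case and compare the result directly against `datetime.now(timezone.utc)`.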


@@ -2,126 +2,89 @@
import uuid
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, inputs, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.api.user.base import (
agents_collection,
attachments_collection,
conversations_collection,
shared_conversations_collections,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import (
SharedConversationsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
def _dual_write_share(
mongo_conv_id: str,
share_uuid: str,
user: str,
*,
is_promptable: bool,
first_n_queries: int,
api_key: str | None,
prompt_id: str | None = None,
chunks: int | None = None,
) -> None:
"""Mirror a Mongo share-record insert into Postgres.
Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
resolve from both stores during cutover.
"""
def _write(repo: SharedConversationsRepository) -> None:
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
mongo_conv_id, user_id=user,
)
if conv is None:
return
# prompt_id / chunks are only meaningful for promptable shares;
# prompt_id is often the string "default" or an ObjectId that
# hasn't been migrated — pass as-is and let the repo drop
# non-UUID values. Scope the prompt lookup by user_id so an
# authenticated caller can't link another user's prompt into
# their share record.
resolved_prompt_id = None
if prompt_id and len(str(prompt_id)) == 24:
from sqlalchemy import text as _text
row = repo._conn.execute(
_text(
"SELECT id FROM prompts "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": str(prompt_id), "user_id": user},
).fetchone()
if row:
resolved_prompt_id = str(row[0])
# get_or_create is race-free on the PG side thanks to the
# composite partial unique index on the dedup tuple
# (migration 0008). It converges concurrent share requests to
# a single row.
repo.get_or_create(
conv["id"],
user,
is_promptable=is_promptable,
first_n_queries=first_n_queries,
api_key=api_key,
prompt_id=resolved_prompt_id,
chunks=chunks,
share_uuid=share_uuid,
)
dual_write(SharedConversationsRepository, _write)
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)
def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.
Scoped by ``user_id`` so a caller can't link another user's prompt
into their share record. Returns ``None`` for sentinel values
(``"default"``) or unresolved ids.
"""
if not prompt_id_raw or prompt_id_raw == "default":
return None
value = str(prompt_id_raw)
# Already UUID — trust it but still require ownership. A shape-gate
# (rather than a loose ``len == 36 and '-' in value`` check) keeps
# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
# and poison the readonly transaction otherwise.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None
# Legacy Mongo ObjectId fallback.
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None
def _resolve_source_pg_id(conn, source_raw):
"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
if not source_raw:
return None
value = str(source_raw)
# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
),
{"sid": value},
).fetchone()
return str(row[0]) if row else None
row = conn.execute(
_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
{"sid": value},
).fetchone()
return str(row[0]) if row else None
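Reviewer note: both resolver functions rely on `looks_like_uuid` to keep non-UUID input out of `CAST(... AS uuid)`. Its implementation lives in `application.storage.db.base_repository` and is not shown in this diff. The stand-in below, under the hypothetical name `is_uuid_shaped`, illustrates what a strict shape gate could look like; it may differ from the project's actual check.

```python
import uuid


def is_uuid_shaped(value):
    """Return True only for canonical 36-character hyphenated UUID strings.

    Hypothetical stand-in for the project's looks_like_uuid helper.
    """
    if not isinstance(value, str) or len(value) != 36:
        return False
    try:
        # uuid.UUID tolerates misplaced hyphens, so require a canonical round trip.
        return str(uuid.UUID(value)) == value.lower()
    except ValueError:
        return False
```

Under this gate, a 24-character legacy ObjectId string or the sentinel `"default"` never reaches the cast and falls through to the `legacy_mongo_id` lookup instead.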
def _find_reusable_share_agent(
conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
):
"""Find an existing share-as-agent key row matching these parameters.
Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
check. Used to reuse an api key across repeated shares of the same
conversation with the same prompt/chunks/source/retriever.
"""
clauses = ["user_id = :uid", "key IS NOT NULL"]
params: dict = {"uid": user_id}
if prompt_pg_id is None:
clauses.append("prompt_id IS NULL")
else:
clauses.append("prompt_id = CAST(:pid AS uuid)")
params["pid"] = prompt_pg_id
if chunks is None:
clauses.append("chunks IS NULL")
else:
clauses.append("chunks = :chunks")
params["chunks"] = int(chunks)
if source_pg_id is None:
clauses.append("source_id IS NULL")
else:
clauses.append("source_id = CAST(:sid AS uuid)")
params["sid"] = source_pg_id
if retriever is None:
clauses.append("retriever IS NULL")
else:
clauses.append("retriever = :retr")
params["retr"] = retriever
sql = (
"SELECT * FROM agents WHERE "
+ " AND ".join(clauses)
+ " LIMIT 1"
)
row = conn.execute(_sql_text(sql), params).fetchone()
if row is None:
return None
mapping = dict(row._mapping)
mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
return mapping
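Reviewer note: `_find_reusable_share_agent` builds its `WHERE` clause column by column because `col = NULL` never matches in SQL, so each optional parameter needs either `IS NULL` or a bound equality. The sketch below reproduces that pattern against an in-memory SQLite table so it can be run standalone. The helper name `build_null_safe_where`, the table layout, and the use of SQLite in place of Postgres are all assumptions for illustration; the `CAST(... AS uuid)` parts of the original are omitted.

```python
import sqlite3


def build_null_safe_where(filters):
    """Return (sql_fragment, params) matching each column exactly, NULLs included.

    Hypothetical helper. Column names are interpolated into the SQL, so they
    must come from trusted code, never from request data.
    """
    clauses, params = [], {}
    for column, value in filters.items():
        if value is None:
            # "col = NULL" is never true in SQL, so NULL needs IS NULL.
            clauses.append(f"{column} IS NULL")
        else:
            clauses.append(f"{column} = :{column}")
            params[column] = value
    return " AND ".join(clauses), params


# Illustrative table; the real agents table has more columns.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE agents (api_key TEXT, chunks INTEGER, retriever TEXT)")
conn.execute("INSERT INTO agents VALUES ('k1', 2, NULL)")
conn.execute("INSERT INTO agents VALUES ('k2', 2, 'classic')")

where, params = build_null_safe_where({"chunks": 2, "retriever": None})
row = conn.execute(
    f"SELECT api_key FROM agents WHERE {where} LIMIT 1", params
).fetchone()
```

Only the row whose `retriever` is actually NULL matches, which is the behaviour the reuse check depends on when a share request omits the retriever.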
@sharing_ns.route("/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
@@ -156,93 +119,173 @@ class ShareConversation(Resource):
conversation_id = data["conversation_id"]
try:
with db_session() as conn:
conv_repo = ConversationsRepository(conn)
shared_repo = SharedConversationsRepository(conn)
agents_repo = AgentsRepository(conn)
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": user}
)
if conversation is None:
return make_response(
jsonify(
{
"status": "error",
"message": "Conversation does not exist",
}
),
404,
)
current_n_queries = len(conversation["queries"])
explicit_binary = Binary.from_uuid(
uuid.uuid4(), UuidRepresentation.STANDARD
)
conversation = conv_repo.get_any(conversation_id, user)
if conversation is None:
return make_response(
jsonify(
{
"status": "error",
"message": "Conversation does not exist",
}
),
404,
if is_promptable:
prompt_id = data.get("prompt_id", "default")
chunks = data.get("chunks", "2")
name = conversation["name"] + "(shared)"
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
conv_pg_id = str(conversation["id"])
current_n_queries = conv_repo.message_count(conv_pg_id)
if is_promptable:
prompt_id_raw = data.get("prompt_id", "default")
chunks_raw = data.get("chunks", "2")
try:
chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
except (TypeError, ValueError):
chunks_int = None
prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
retriever = data.get("retriever")
reusable = _find_reusable_share_agent(
conn, user,
prompt_pg_id=prompt_pg_id,
chunks=chunks_int,
source_pg_id=source_pg_id,
retriever=retriever,
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
if pre_existing_api_document:
api_uuid = pre_existing_api_document["key"]
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
if reusable:
api_uuid = reusable.get("key")
else:
api_uuid = str(uuid.uuid4())
name = (conversation.get("name") or "") + "(shared)"
agents_repo.create(
user,
name,
"published",
key=api_uuid,
retriever=retriever,
chunks=chunks_int,
prompt_id=prompt_pg_id,
source_id=source_pg_id,
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=api_uuid,
prompt_id=prompt_id,
chunks=int(chunks) if chunks else None,
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
else:
api_uuid = str(uuid.uuid4())
new_api_key_data["key"] = api_uuid
new_api_key_data["name"] = name
share = shared_repo.get_or_create(
conv_pg_id,
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
agents_collection.insert_one(new_api_key_data)
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=True,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=api_uuid,
prompt_id=prompt_pg_id,
chunks=chunks_int,
prompt_id=prompt_id,
chunks=int(chunks) if chunks else None,
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(share["uuid"]),
"identifier": str(explicit_binary.as_uuid()),
}
),
201 if reusable is None else 200,
201,
)
# Non-promptable share path.
share = shared_repo.get_or_create(
conv_pg_id,
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=False,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=None,
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(share["uuid"]),
}
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
@@ -258,13 +301,37 @@ class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
try:
with db_readonly() as conn:
shared_repo = SharedConversationsRepository(conn)
conv_repo = ConversationsRepository(conn)
attach_repo = AttachmentsRepository(conn)
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
shared = shared_repo.find_by_uuid(identifier)
if not shared or not shared.get("conversation_id"):
if (
shared
and "conversation_id" in shared
):
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)
if conversation is None:
return make_response(
jsonify(
{
@@ -274,60 +341,22 @@ class GetPubliclySharedConversations(Resource):
),
404,
)
conv_pg_id = str(shared["conversation_id"])
owner_user = shared.get("user_id")
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
if conversation is None:
# Fall back to any-user lookup in case shared row's
# user_id is missing — still keyed by PG UUID.
row = conn.execute(
_sql_text(
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
),
{"id": conv_pg_id},
).fetchone()
if row is None:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
conversation = dict(row._mapping)
messages = conv_repo.get_messages(conv_pg_id)
first_n = shared.get("first_n_queries") or 0
conversation_queries = []
for msg in messages[:first_n]:
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"timestamp": (
msg["timestamp"].isoformat()
if hasattr(msg.get("timestamp"), "isoformat")
else msg.get("timestamp")
),
"feedback": msg.get("feedback"),
}
attachments = msg.get("attachments") or []
if attachments:
for query in conversation_queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in attachments:
for attachment_id in query["attachments"]:
try:
attachment = attach_repo.get_any(
str(attachment_id), owner_user,
) if owner_user else None
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["id"]),
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
@@ -339,23 +368,26 @@ class GetPubliclySharedConversations(Resource):
exc_info=True,
)
query["attachments"] = attachment_details
conversation_queries.append(query)
created = conversation.get("created_at") or conversation.get("date")
date_iso = (
created.isoformat()
if hasattr(created, "isoformat")
else (str(created) if created is not None else None)
else:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
res = {
"success": True,
"queries": conversation_queries,
"title": conversation.get("name"),
"timestamp": date_iso,
}
if shared.get("is_promptable") and shared.get("api_key"):
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
date = conversation["_id"].generation_time.isoformat()
res = {
"success": True,
"queries": conversation_queries,
"title": conversation["name"],
"timestamp": date,
}
if shared["isPromptable"] and "api_key" in shared:
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
except Exception as err:
current_app.logger.error(
f"Error getting shared conversation: {err}", exc_info=True


@@ -1,12 +1,11 @@
"""Source document management chunk management."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import get_vector_store
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly
from application.api.user.base import get_vector_store, sources_collection
from application.utils import check_required_fields, num_tokens_from_string
sources_chunks_ns = Namespace(
@@ -14,15 +13,6 @@ sources_chunks_ns = Namespace(
)
def _resolve_source(doc_id: str, user: str):
"""Resolve a source (UUID or legacy ObjectId) for the caller.
Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
"""
with db_readonly() as conn:
return SourcesRepository(conn).get_any(doc_id, user)
@sources_chunks_ns.route("/get_chunks")
class GetChunks(Resource):
@api.doc(
@@ -46,34 +36,36 @@ class GetChunks(Resource):
path = request.args.get("path")
search_term = request.args.get("search", "").strip().lower()
if not doc_id:
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
resolved_id = str(doc["id"])
try:
store = get_vector_store(resolved_id)
store = get_vector_store(doc_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
# Filter by path if provided
if path:
chunk_source = metadata.get("source", "")
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
if not (source_match or file_path_match):
continue
# Filter by search term if provided
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
@@ -140,17 +132,15 @@ class AddChunk(Resource):
token_count = num_tokens_from_string(text)
metadata["token_count"] = token_count
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(str(doc["id"]))
store = get_vector_store(doc_id)
chunk_id = store.add_chunk(text, metadata)
return make_response(
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
@@ -175,17 +165,15 @@ class DeleteChunk(Resource):
doc_id = request.args.get("id")
chunk_id = request.args.get("chunk_id")
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(str(doc["id"]))
store = get_vector_store(doc_id)
deleted = store.delete_chunk(chunk_id)
if deleted:
return make_response(
@@ -244,17 +232,15 @@ class UpdateChunk(Resource):
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(str(doc["id"]))
store = get_vector_store(doc_id)
chunks = store.get_chunks()
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
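Reviewer note: the path filter in `GetChunks` keeps a chunk when either `metadata["source"]` or `metadata["file_path"]` ends with the requested path. A minimal sketch of that predicate follows; `chunk_matches_path` is a hypothetical name, and the sample metadata is invented for illustration.

```python
def chunk_matches_path(metadata, path):
    """Return True when the chunk's source or file_path ends with path.

    Hypothetical extraction of the inline filter; not part of the diff.
    """
    if not path:
        return True  # no path filter requested
    source = metadata.get("source", "")
    file_path = metadata.get("file_path", "")
    return bool(source and source.endswith(path)) or bool(
        file_path and file_path.endswith(path)
    )


# Invented sample metadata for an uploaded file and a crawled page.
upload = {"source": "inputs/user/file.pdf"}
crawled = {"file_path": "guides/setup.md"}
matches = [chunk_matches_path(upload, "file.pdf"), chunk_matches_path(crawled, "setup.md")]
```

One consequence of plain suffix matching is that a request for `file.pdf` also matches a chunk whose source is `inputs/myfile.pdf`; whether that matters depends on how callers construct `path`.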


@@ -3,18 +3,14 @@
import json
import math
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import reingest_source_task, sync_source
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.parser.remote.remote_creator import normalize_remote_data
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
@@ -60,20 +56,11 @@ class CombinedJson(Resource):
]
try:
with db_readonly() as conn:
indexes = SourcesRepository(conn).list_for_user(user)
# list_for_user sorts by created_at DESC; legacy shape sorted by
# "date" DESC. Both are monotonic on creation so the ordering is
# equivalent for dev; re-sort defensively.
indexes = sorted(
indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
reverse=True,
)
for index in indexes:
for index in sources_collection.find({"user": user}).sort("date", -1):
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["id"]),
"id": str(index["_id"]),
"name": index.get("name"),
"date": index.get("date"),
"model": settings.EMBEDDINGS_NAME,
@@ -83,7 +70,9 @@ class CombinedJson(Resource):
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get("type", "file"),
"type": index.get(
"type", "file"
), # Add type field with default "file"
}
)
except Exception as err:
@@ -100,57 +89,61 @@ class PaginatedSources(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
sort_field = request.args.get("sort", "date")
sort_order = request.args.get("order", "desc")
page = max(1, int(request.args.get("page", 1)))
rows_per_page = max(1, int(request.args.get("rows", 10)))
search_term = request.args.get("search", "").strip() or None
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# Strip leading and trailing whitespace from the search term.
search_term = request.args.get(
"search", ""
).strip() # Optional filter on document name
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i", # using case-insensitive search
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(
max(1, page), total_pages
) # Clamp page to the valid range [1, total_pages]
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
try:
with db_readonly() as conn:
repo = SourcesRepository(conn)
total_documents = repo.count_for_user(
user, search_term=search_term,
)
# Prior in-Python implementation returned ``totalPages = 1``
# for empty result sets (``max(1, ceil(0/rows))``); we
# preserve that contract so the frontend pager stays stable.
total_pages = max(1, math.ceil(total_documents / rows_per_page))
effective_page = min(page, total_pages)
offset = (effective_page - 1) * rows_per_page
window = repo.list_for_user(
user,
limit=rows_per_page,
offset=offset,
search_term=search_term,
sort_field=sort_field,
sort_order=sort_order,
)
documents = (
sources_collection.find(query)
.sort(sort_field, sort_order)
.skip(skip)
.limit(rows_per_page)
)
paginated_docs = []
for doc in window:
for doc in documents:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
paginated_docs.append(
{
"id": str(doc["id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
# Derived in SourcesRepository.list_for_user.
"ingestStatus": doc.get("ingest_status"),
}
)
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
response = {
"total": total_documents,
"totalPages": total_pages,
"currentPage": effective_page,
"currentPage": page,
"paginated": paginated_docs,
}
return make_response(jsonify(response), 200)
@@ -161,6 +154,28 @@ class PaginatedSources(Resource):
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
params={"path": "Comma-separated list of IDs"},
)
def get(self):
ids = request.args.get("path")
if not ids:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
result = sources_collection.delete_index(ids=ids)
if result:
return make_response(jsonify({"success": True}), 200)
except Exception as err:
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
@@ -171,33 +186,30 @@ class DeleteOldIndexes(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
resolved_id = str(doc["id"])
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{resolved_id}"
index_path = f"indexes/{str(doc['_id'])}"
if storage.file_exists(f"{index_path}/index.faiss"):
storage.delete_file(f"{index_path}/index.faiss")
if storage.file_exists(f"{index_path}/index.pkl"):
storage.delete_file(f"{index_path}/index.pkl")
else:
vectorstore = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id=resolved_id
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
@@ -215,14 +227,7 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
with db_session() as conn:
SourcesRepository(conn).delete(resolved_id, user)
except Exception as err:
current_app.logger.error(
f"Error deleting source row: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -267,16 +272,15 @@ class ManageSync(Resource):
return make_response(
jsonify({"success": False, "message": "Invalid frequency"}), 400
)
update_data = {"$set": {"sync_frequency": sync_frequency}}
try:
with db_session() as conn:
repo = SourcesRepository(conn)
doc = repo.get_any(source_id, user)
if doc is None:
return make_response(
jsonify({"success": False, "message": "Source not found"}),
404,
)
repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
sources_collection.update_one(
{
"_id": ObjectId(source_id),
"user": user,
},
update_data,
)
except Exception as err:
current_app.logger.error(
f"Error updating sync frequency: {err}", exc_info=True
@@ -305,20 +309,19 @@ class SyncSource(Resource):
if missing_fields:
return missing_fields
source_id = data["source_id"]
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
if not ObjectId.is_valid(source_id):
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type and source_type.startswith("connector"):
if source_type.startswith("connector"):
return make_response(
jsonify(
{
@@ -328,7 +331,7 @@ class SyncSource(Resource):
),
400,
)
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
@@ -341,7 +344,7 @@ class SyncSource(Resource):
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=str(doc["id"]),
doc_id=source_id,
)
except Exception as err:
current_app.logger.error(
@@ -352,70 +355,6 @@ class SyncSource(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/sources/reingest")
class ReingestSource(Resource):
reingest_source_model = api.model(
"ReingestSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(reingest_source_model)
@api.doc(
description="Re-run ingestion for a source — e.g. to recover a "
"stalled embed flagged by the reconciler."
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
missing_fields = check_required_fields(data, ["source_id"])
if missing_fields:
return missing_fields
source_id = data["source_id"]
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(
f"Error looking up source: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
resolved_source_id = str(doc["id"])
# Drop the stale chunk-progress row so the sources list stops
# deriving a 'failed' status; reingest never rewrites it itself.
try:
with db_session() as conn:
IngestChunkProgressRepository(conn).delete(resolved_source_id)
except Exception as err:
current_app.logger.warning(
f"Could not clear ingest progress for {resolved_source_id}: "
f"{err}",
exc_info=True,
)
try:
# Scoped key so repeated clicks collapse onto one reingest.
task = reingest_source_task.delay(
source_id=resolved_source_id,
user=user,
idempotency_key=f"reingest-source:{user}:{resolved_source_id}",
)
except Exception as err:
current_app.logger.error(
f"Error starting reingest for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
@@ -431,9 +370,10 @@ class DirectoryStructure(Resource):
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(doc_id, user)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
@@ -447,8 +387,6 @@ class DirectoryStructure(Resource):
if isinstance(remote_data, str) and remote_data:
remote_data_obj = json.loads(remote_data)
provider = remote_data_obj.get("provider")
elif isinstance(remote_data, dict):
provider = remote_data.get("provider")
except Exception as e:
current_app.logger.warning(
f"Failed to parse remote_data for doc {doc_id}: {e}"
@@ -468,7 +406,4 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
500,
)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
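
Note on the `PaginatedSources` hunk above: both sides of the diff clamp the requested page into range, and the repository-backed side states that an empty result set must still report `totalPages = 1`. The following is a minimal sketch of that arithmetic only. The helper name `clamp_page` is hypothetical and does not exist in the codebase; it only mirrors the `max(1, ceil(...))` / `min(page, total_pages)` lines shown in the hunk.

```python
import math


def clamp_page(total_documents: int, rows_per_page: int, page: int) -> tuple[int, int, int]:
    """Return (total_pages, effective_page, offset) for a paginated listing.

    Hypothetical helper, written only to illustrate the arithmetic in the hunk.
    """
    # Guard against zero or negative inputs, as the max(1, int(...)) parsing does.
    rows_per_page = max(1, rows_per_page)
    page = max(1, page)
    # An empty result set still reports one page so a pager has something to show.
    total_pages = max(1, math.ceil(total_documents / rows_per_page))
    # Requests past the last page are pulled back to the last page.
    effective_page = min(page, total_pages)
    offset = (effective_page - 1) * rows_per_page
    return total_pages, effective_page, offset
```

With 25 documents and 10 rows per page, a request for page 99 resolves to page 3 at offset 20, which is the value the response would report as `currentPage`.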

View File

@@ -3,22 +3,18 @@
import json
import os
import tempfile
import uuid
import zipfile
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.storage.db.source_ids import derive_source_id as _derive_source_id
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.stt.upload_limits import (
AudioFileTooLargeError,
@@ -34,91 +30,6 @@ sources_upload_ns = Namespace(
)
_IDEMPOTENCY_KEY_MAX_LEN = 256
def _read_idempotency_key():
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
key = request.headers.get("Idempotency-Key")
if not key:
return None, None
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
return None, make_response(
jsonify(
{
"success": False,
"message": (
f"Idempotency-Key exceeds maximum length of "
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
),
}
),
400,
)
return key, None
def _scoped_idempotency_key(idempotency_key, scope):
"""``{scope}:{key}`` so different users can't collide on the same key."""
if not idempotency_key or not scope:
return None
return f"{scope}:{idempotency_key}"
def _claim_task_or_get_cached(key, task_name):
"""Claim ``key`` for this request OR return the winner's cached payload.
Pre-generates the celery task_id so a losing writer sees the same
id immediately. Returns ``(task_id, cached_response)``; non-None
cached means the caller should return without enqueuing. The
cached payload mirrors the fresh-request response shape (including
``source_id``) so the frontend can correlate SSE ingest events to
the cached upload task without an extra round-trip — but only when
the cached row actually exists; the "deduplicated" sentinel
deliberately omits ``source_id`` so the frontend doesn't bind to a
phantom source.
"""
predetermined_id = str(uuid.uuid4())
with db_session() as conn:
claimed = IdempotencyRepository(conn).claim_task(
key=key, task_name=task_name, task_id=predetermined_id,
)
if claimed is not None:
return claimed["task_id"], None
with db_readonly() as conn:
existing = IdempotencyRepository(conn).get_task(key)
cached_id = existing.get("task_id") if existing else None
payload: dict = {
"success": True,
"task_id": cached_id or "deduplicated",
}
# Only surface ``source_id`` when there's a real winner whose worker
# is publishing SSE events tagged with that id. The "deduplicated"
# branch means the lock row vanished — we have nothing to correlate.
if cached_id is not None:
payload["source_id"] = str(_derive_source_id(key))
return None, payload
def _release_claim(key):
"""Drop a pending claim so a client retry can re-claim it."""
try:
with db_session() as conn:
conn.execute(
sql_text(
"DELETE FROM task_dedup WHERE idempotency_key = :k "
"AND status = 'pending'"
),
{"k": key},
)
except Exception:
current_app.logger.exception(
"Failed to release task_dedup claim for key=%s", key,
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
@@ -138,38 +49,17 @@ class UploadFile(Resource):
)
)
@api.doc(
description=(
"Uploads a file to be vectorized and indexed. Honors an optional "
"``Idempotency-Key`` header: a repeat request with the same key "
"within 24h returns the original cached response without re-enqueuing."
),
description="Uploads a file to be vectorized and indexed",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
# User-scoped to avoid cross-user collisions; also feeds
# ``_derive_source_id`` so uuid5 stays user-disjoint.
scoped_key = _scoped_idempotency_key(idempotency_key, user)
# Claim before enqueue; the loser returns the winner's task_id.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "ingest",
)
if cached is not None:
return make_response(jsonify(cached), 200)
data = request.form
files = request.files.getlist("file")
required_fields = ["user", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields or not files or all(file.filename == "" for file in files):
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -179,6 +69,7 @@ class UploadFile(Resource):
),
400,
)
user = decoded_token.get("sub")
job_name = request.form["name"]
# Create safe versions for filesystem operations
@@ -249,37 +140,16 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. With an idempotency
# key we reuse the deterministic uuid5 (retried task lands on
# the same source row); without a key we fall back to uuid4.
# The worker is told to use this id verbatim — see
# ``ingest_worker(source_id=...)``.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
task = ingest.delay(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
ingest_kwargs = dict(
args=(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
),
kwargs={
"file_path": base_path,
"filename": dir_name,
"file_name_map": file_name_map,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
)
if predetermined_task_id is not None:
ingest_kwargs["task_id"] = predetermined_task_id
task = ingest.apply_async(**ingest_kwargs)
except AudioFileTooLargeError:
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -291,21 +161,8 @@ class UploadFile(Resource):
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
if scoped_key:
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
# Predetermined id matches the dedup-claim row; loser GET sees same.
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the worker as
# ``source_id``; the worker uses it verbatim for every SSE event,
# so the frontend can correlate inbound ``source.ingest.*`` to
# this upload regardless of whether an idempotency key was set.
response_payload: dict = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/remote")
@@ -325,50 +182,17 @@ class UploadRemote(Resource):
)
)
@api.doc(
description=(
"Uploads remote source for vectorization. Honors an optional "
"``Idempotency-Key`` header: a repeat request with the same key "
"within 24h returns the original cached response without re-enqueuing."
),
description="Uploads remote source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
scoped_key = _scoped_idempotency_key(idempotency_key, user)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
task_name_for_dedup = (
"ingest_connector_task"
if data.get("source") in ConnectorCreator.get_supported_connectors()
else "ingest_remote"
)
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, task_name_for_dedup,
)
if cached is not None:
return make_response(jsonify(cached), 200)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. Same pattern as
# ``UploadFile.post``: with an idempotency key we reuse the
# deterministic uuid5 (retried task lands on the same source
# row); without a key we fall back to uuid4. The worker is told
# to use this id verbatim — see ``remote_worker`` and
# ``ingest_connector``. Without this the no-key path would mint
# a random uuid4 inside the worker that the frontend has no way
# to correlate SSE events to.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
try:
config = json.loads(data["data"])
source_data = None
@@ -384,8 +208,6 @@ class UploadRemote(Resource):
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -414,62 +236,31 @@ class UploadRemote(Resource):
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
connector_kwargs = {
"kwargs": {
"job_name": data["name"],
"user": user,
"source_type": data["source"],
"session_token": session_token,
"file_ids": file_ids,
"folder_ids": folder_ids,
"recursive": config.get("recursive", False),
"retriever": config.get("retriever", "classic"),
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
connector_kwargs["task_id"] = predetermined_task_id
task = ingest_connector_task.apply_async(**connector_kwargs)
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the
# worker as ``source_id``; the worker uses it verbatim
# for every SSE event, so the frontend can correlate
# inbound ``source.ingest.*`` regardless of whether an
# idempotency key was set.
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
remote_kwargs = {
"kwargs": {
"source_data": source_data,
"job_name": data["name"],
"user": user,
"loader": data["source"],
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
remote_kwargs["task_id"] = predetermined_task_id
task = ingest_remote.apply_async(**remote_kwargs)
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
)
return make_response(
jsonify({"success": True, "task_id": task.id}), 200
)
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
)
except Exception as err:
current_app.logger.error(
f"Error uploading remote source: {err}", exc_info=True
)
if scoped_key:
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
response_task_id = predetermined_task_id or task.id
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/manage_source_files")
@@ -514,10 +305,6 @@ class ManageSourceFiles(Resource):
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
scoped_key = _scoped_idempotency_key(idempotency_key, user)
source_id = request.form.get("source_id")
operation = request.form.get("operation")
@@ -542,8 +329,15 @@ class ManageSourceFiles(Resource):
400,
)
try:
with db_readonly() as conn:
source = SourcesRepository(conn).get_any(source_id, user)
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not source:
return make_response(
jsonify(
@@ -559,13 +353,6 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
resolved_source_id = str(source["id"])
# Flips to True after each branch's ``apply_async`` returns
# successfully — at that point the worker owns the predetermined
# task_id. The outer ``except`` only releases the claim while
# this is False, so a post-``apply_async`` failure (jsonify,
# make_response, etc.) doesn't double-enqueue on the next retry.
claim_transferred = False
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -598,34 +385,6 @@ class ManageSourceFiles(Resource):
),
400,
)
# Claim before any storage mutation so a duplicate request
# short-circuits without touching the filesystem. Mirrors
# the pattern in ``UploadFile.post`` / ``UploadRemote.post``
# — without it ``.delay()`` would enqueue twice for two
# racing same-key POSTs (the worker decorator only
# deduplicates *after* completion).
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
# Frontend keys reingest polling on
# ``reingest_task_id``; the shared cache helper
# writes ``task_id``. Alias here so a dedup
# response doesn't silently break FileTree's
# poller. Override ``source_id`` too — the
# helper derives it from the scoped key, which
# is correct for upload but wrong for reingest
# (the worker publishes events scoped to the
# actual source row id).
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
added_files = []
map_updated = False
@@ -652,24 +411,15 @@ class ManageSourceFiles(Resource):
map_updated = True
if map_updated:
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
)
claim_transferred = True
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
@@ -679,12 +429,6 @@ class ManageSourceFiles(Resource):
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
# ``source_id`` lets the frontend correlate
# inbound ``source.ingest.*`` SSE events
# (emitted by ``reingest_source_worker``)
# back to the reingest task — matches the
# upload route's source-id contract.
"source_id": resolved_source_id,
}
),
200,
@@ -714,8 +458,10 @@ class ManageSourceFiles(Resource):
),
400,
)
# Path-traversal guard runs *before* the claim so a 400
# for an invalid path doesn't leave a pending dedup row.
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
@@ -727,31 +473,6 @@ class ManageSourceFiles(Resource):
),
400,
)
# Claim before any storage mutation. See ``add`` branch
# comment for rationale.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Override the helper's synthetic source_id (uuid5
# of the scoped key) with the real source row id
# — the reingest worker publishes SSE events
# scoped to ``resolved_source_id`` and FileTree
# correlates on it.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
@@ -764,24 +485,15 @@ class ManageSourceFiles(Resource):
map_updated = True
if map_updated and isinstance(file_name_map, dict):
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
)
claim_transferred = True
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
@@ -790,7 +502,6 @@ class ManageSourceFiles(Resource):
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
@@ -841,24 +552,6 @@ class ManageSourceFiles(Resource):
),
404,
)
# Claim before mutation. See ``add`` branch for rationale.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Same source_id override as the ``remove`` /
# ``add`` cached branches — the helper's synthetic
# id doesn't match what reingest_source_worker
# tags its SSE events with.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
success = storage.remove_directory(full_directory_path)
if not success:
@@ -867,11 +560,6 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
# Release so a client retry can reclaim — otherwise
# the next request would silently 200-cache to the
# task_id that never enqueued.
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{"success": False, "message": "Failed to remove directory"}
@@ -893,25 +581,16 @@ class ManageSourceFiles(Resource):
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
)
claim_transferred = True
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
@@ -920,20 +599,11 @@ class ManageSourceFiles(Resource):
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
)
except Exception as err:
# Release the dedup claim only if it wasn't transferred to
# a worker. Without this, a same-key retry within the 24h
# TTL would 200-cache to a predetermined task_id whose
# ``apply_async`` never ran (or ran but the response builder
# blew up afterward — only the first case matters in
# practice; the flag protects both).
if scoped_key and not claim_transferred:
_release_claim(scoped_key)
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
directory_path = request.form.get("directory_path", "")
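
Note on the idempotency helpers in the upload hunks above: the comments describe three ideas, namely scoping the client key by user, deriving a deterministic uuid5 source id from the scoped key, and claiming the key before enqueuing so a duplicate request gets the winner's task id. The sketch below illustrates that flow under stated assumptions. It uses an in-memory dict instead of the `task_dedup` table, and the `_SOURCE_NAMESPACE` value and the `InMemoryClaims` class are hypothetical. The real `derive_source_id` lives in `application.storage.db.source_ids` and its namespace is not shown in this diff.

```python
import uuid

# Hypothetical namespace, chosen only for this sketch.
_SOURCE_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "example.invalid/source-ids")


def scoped_idempotency_key(idempotency_key, scope):
    """Prefix the client key with the user so two users cannot collide."""
    if not idempotency_key or not scope:
        return None
    return f"{scope}:{idempotency_key}"


def derive_source_id(scoped_key):
    """Deterministic id: the same scoped key always maps to the same UUID."""
    return uuid.uuid5(_SOURCE_NAMESPACE, scoped_key)


class InMemoryClaims:
    """Stand-in for the dedup table; not safe across processes."""

    def __init__(self):
        self._rows = {}

    def claim(self, key):
        """Return (task_id, cached); cached is None for the winning request."""
        new_id = str(uuid.uuid4())
        winner = self._rows.setdefault(key, new_id)
        if winner is new_id:
            return new_id, None
        # Losing request: echo the winner's task id and the derived source id.
        payload = {
            "success": True,
            "task_id": winner,
            "source_id": str(derive_source_id(key)),
        }
        return None, payload

    def release(self, key):
        """Drop a claim so a client retry can claim the key again."""
        self._rows.pop(key, None)
```

A caller would claim first, enqueue only when `cached` is `None`, and call `release` on any failure that happens before the task is handed to the worker, which is the role `_release_claim` and the `claim_transferred` flag play in the hunks above.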

View File

@@ -1,79 +1,21 @@
from datetime import timedelta
from application.api.user.idempotency import with_idempotency
from application.celery_init import celery
from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
# Shared decorator config for long-running, side-effecting tasks. ``acks_late``
# is also the celeryconfig default but stays explicit here so each task's
# durability story is grep-able next to the body. Combined with
# ``autoretry_for=(Exception,)`` and a bounded ``max_retries`` so a poison
# message can't loop forever.
DURABLE_TASK = dict(
bind=True,
acks_late=True,
autoretry_for=(Exception,),
retry_kwargs={"max_retries": 3, "countdown": 60},
retry_backoff=True,
)
# operation tag for the poison-path source.ingest.failed event, per task.
_INGEST_POISON_OPERATION = {
"ingest": "upload",
"ingest_remote": "upload",
"ingest_connector_task": "upload",
"reingest_source_task": "reingest",
}
def _emit_ingest_poison_event(task_name, bound):
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
The guard returns before the worker runs, so the worker's own failed
event never fires — without this the upload toast spins on "training".
"""
user = bound.get("user")
source_id = bound.get("source_id")
if not user or not source_id:
return
from application.events.publisher import publish_user_event
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": str(source_id),
"filename": bound.get("filename") or "",
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
"error": "Ingestion stopped after repeated failures.",
},
scope={"kind": "source", "id": str(source_id)},
)
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
@celery.task(bind=True)
def ingest(
self,
directory,
formats,
job_name,
user,
file_path,
filename,
file_name_map=None,
idempotency_key=None,
source_id=None,
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
@@ -84,42 +26,25 @@ def ingest(
filename,
user,
file_name_map=file_name_map,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
):
resp = remote_worker(
self, source_data, job_name, user, loader,
idempotency_key=idempotency_key,
source_id=source_id,
)
@celery.task(bind=True)
def ingest_remote(self, source_data, job_name, user, loader):
resp = remote_worker(self, source_data, job_name, user, loader)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
)
def reingest_source_task(self, source_id, user, idempotency_key=None):
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
# Beat-driven dispatch tasks default to ``acks_late=False``: a SIGKILL
# of a beat tick is harmless to redeliver only if the dispatch itself is
# idempotent. We keep these early-ACK so the broker doesn't replay a
# dispatch that already enqueued downstream work.
@celery.task(bind=True, acks_late=False)
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
return resp
@@ -149,24 +74,19 @@ def sync_source(
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="store_attachment")
def store_attachment(self, file_info, user, idempotency_key=None):
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="process_agent_webhook")
def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
@celery.task(bind=True)
def process_agent_webhook(self, agent_id, payload):
resp = agent_webhook_worker(self, agent_id, payload)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
)
@celery.task(bind=True)
def ingest_connector_task(
self,
job_name,
@@ -180,8 +100,6 @@ def ingest_connector_task(
operation_mode="upload",
doc_id=None,
sync_frequency="never",
idempotency_key=None,
source_id=None,
):
from application.worker import ingest_connector
@@ -198,70 +116,12 @@ def ingest_connector_task(
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(bind=True, acks_late=False)
def dispatch_scheduled_runs(self):
"""Beat-driven scheduler poller (body in scheduler_dispatcher)."""
from application.api.user.scheduler_dispatcher import dispatch_due_runs
return dispatch_due_runs()
@celery.task(
bind=True,
acks_late=True,
# Not DURABLE_TASK: agent runs have side effects; blind retry would double them.
autoretry_for=(),
max_retries=0,
)
def execute_scheduled_run(self, run_id):
"""Execute one scheduled run; soft-time-limit honors SCHEDULE_RUN_TIMEOUT."""
from application.api.user.scheduler_worker import execute_scheduled_run_body
return execute_scheduled_run_body(run_id, getattr(self.request, "id", None))
# Bind runtime soft-time-limit so the prefork worker can raise mid-agent.
try:
from application.core.settings import settings as _scheduler_settings
execute_scheduled_run.soft_time_limit = max(
30, int(_scheduler_settings.SCHEDULE_RUN_TIMEOUT),
)
execute_scheduled_run.time_limit = (
execute_scheduled_run.soft_time_limit + 60
)
except Exception:
pass
@celery.task(bind=True, acks_late=False)
def cleanup_schedule_runs(self):
"""Trim ``schedule_runs`` per ``SCHEDULE_RUN_OUTPUT_RETENTION_DAYS``."""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.schedule_runs import (
ScheduleRunsRepository,
)
ttl_days = settings.SCHEDULE_RUN_OUTPUT_RETENTION_DAYS
engine = get_engine()
with engine.begin() as conn:
deleted = ScheduleRunsRepository(conn).cleanup_older_than(ttl_days)
return {"deleted": deleted, "ttl_days": ttl_days}
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
from application.core.settings import settings
sender.add_periodic_task(
timedelta(days=1),
schedule_syncs.s("daily"),
@@ -280,49 +140,6 @@ def setup_periodic_tasks(sender, **kwargs):
cleanup_pending_tool_state.s(),
name="cleanup-pending-tool-state",
)
# Pure housekeeping for ``task_dedup`` / ``webhook_dedup`` — the
# upsert paths already handle stale rows, so cadence only bounds
# table size. Hourly is plenty for typical traffic.
sender.add_periodic_task(
timedelta(hours=1),
cleanup_idempotency_dedup.s(),
name="cleanup-idempotency-dedup",
)
sender.add_periodic_task(
timedelta(seconds=30),
reconciliation_task.s(),
name="reconciliation",
)
sender.add_periodic_task(
timedelta(hours=7),
version_check_task.s(),
name="version-check",
)
# Bound ``message_events`` growth — every streamed SSE chunk writes
# one row, so retained chats accumulate hundreds of rows per
# message. Reconnect-replay is only meaningful for streams the user
# could plausibly still be waiting on, so 14 days is generous.
sender.add_periodic_task(
timedelta(hours=24),
cleanup_message_events.s(),
name="cleanup-message-events",
)
sender.add_periodic_task(
timedelta(hours=24),
cleanup_orphan_memories.s(),
name="cleanup-orphan-memories",
)
# Scheduler dispatcher and run-log trim.
sender.add_periodic_task(
timedelta(seconds=max(15, settings.SCHEDULE_DISPATCHER_INTERVAL)),
dispatch_scheduled_runs.s(),
name="dispatch-scheduled-runs",
)
sender.add_periodic_task(
timedelta(hours=24),
cleanup_schedule_runs.s(),
name="cleanup-schedule-runs",
)
@celery.task(bind=True)
@@ -331,12 +148,24 @@ def mcp_oauth_task(self, config, user):
return resp
@celery.task(bind=True, acks_late=False)
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
@celery.task(bind=True)
def cleanup_pending_tool_state(self):
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
"""Delete pending_tool_state rows past their TTL.
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
no native TTL, so this task runs every 60 seconds to keep
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
configured (keeps the task runnable in Mongo-only environments).
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "reverted": 0, "skipped": "POSTGRES_URI not set"}
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.pending_tool_state import (
@@ -345,103 +174,5 @@ def cleanup_pending_tool_state(self):
engine = get_engine()
with engine.begin() as conn:
repo = PendingToolStateRepository(conn)
reverted = repo.revert_stale_resuming(grace_seconds=600)
deleted = repo.cleanup_expired()
return {"deleted": deleted, "reverted": reverted}
@celery.task(bind=True, acks_late=False)
def cleanup_idempotency_dedup(self):
"""Delete TTL-expired rows from ``task_dedup`` and ``webhook_dedup``.
Pure housekeeping — the upsert paths already ignore stale rows
(TTL-aware ``ON CONFLICT DO UPDATE``), so this only bounds table
growth and keeps SELECT planning tight on large deployments.
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {
"task_dedup_deleted": 0,
"webhook_dedup_deleted": 0,
"skipped": "POSTGRES_URI not set",
}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.idempotency import (
IdempotencyRepository,
)
engine = get_engine()
with engine.begin() as conn:
return IdempotencyRepository(conn).cleanup_expired()
@celery.task(bind=True, acks_late=False)
def reconciliation_task(self):
"""Sweep stuck durability rows and escalate them to terminal status + alert."""
from application.api.user.reconciliation import run_reconciliation
return run_reconciliation()
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
"""Delete ``message_events`` rows older than the retention window.
Streamed answer responses write one journal row per SSE yield,
so unbounded growth would dominate Postgres for any retained-
conversations deployment. The reconnect-replay path only needs
rows for in-flight streams; 14 days covers paused/tool-action
flows comfortably.
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
engine = get_engine()
with engine.begin() as conn:
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
return {"deleted": deleted, "ttl_days": ttl_days}
@celery.task(bind=True, acks_late=False)
def cleanup_orphan_memories(self):
"""Sweep orphan memories left by the 0009 FK-to-trigger orphan window.
A ``memories`` INSERT for a real ``tool_id`` racing a ``user_tools``
DELETE leaves a permanent orphan the dropped FK would have rejected.
Default-tool synthetic ids are preserved (legitimate built-in data).
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.agents.default_tools import default_tool_ids
from application.storage.db.engine import get_engine
from application.storage.db.repositories.memories import MemoriesRepository
keep_tool_ids = list(default_tool_ids().values())
engine = get_engine()
with engine.begin() as conn:
deleted = MemoriesRepository(conn).delete_orphans(keep_tool_ids)
deleted = PendingToolStateRepository(conn).cleanup_expired()
return {"deleted": deleted}
@celery.task(bind=True, acks_late=False)
def version_check_task(self):
"""Periodic anonymous version check.
Complements the ``worker_ready`` boot trigger so long-running
deployments (>6h cache TTL) still refresh advisories. ``run_check``
is fail-silent and coordinates across replicas via Redis lock +
cache (see ``application.updates.version_check``).
"""
from application.updates.version_check import run_check
run_check()
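
Reviewer note: the task hunks above wrap each durable task in `with_idempotency(task_name=..., on_poison=...)`, and the `_emit_ingest_poison_event` docstring says the guard returns before the worker runs. The decorator's body is not part of this diff, so the following is only a minimal in-memory sketch of how such a guard could behave. The name `with_idempotency_sketch`, the `max_attempts` parameter, the dict-backed store, and the `{"status": "poisoned"}` return value are all assumptions made for illustration, not the project's implementation.

```python
import functools
import inspect


def with_idempotency_sketch(task_name, on_poison=None, max_attempts=3):
    """Hypothetical stand-in for a poison guard; state lives in a plain dict."""
    store = {}  # (task_name, idempotency_key) -> attempt bookkeeping

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Resolve arguments by name so the callback can read user/source_id.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = dict(bound.arguments)
            key = arguments.get("idempotency_key")
            if key is None:
                # No key supplied: run the task without any dedup bookkeeping.
                return fn(*args, **kwargs)
            entry = store.setdefault(
                (task_name, key), {"attempts": 0, "done": False, "result": None}
            )
            if entry["done"]:
                # A completed run with the same key returns the stored result.
                return entry["result"]
            if entry["attempts"] >= max_attempts:
                # Poison path: the task body is skipped, so the callback is the
                # only place a terminal "failed" event can come from.
                if on_poison is not None:
                    on_poison(task_name, arguments)
                return {"status": "poisoned"}
            entry["attempts"] += 1
            result = fn(*args, **kwargs)
            entry["done"] = True
            entry["result"] = result
            return result

        return wrapper

    return decorator


events = []


@with_idempotency_sketch(
    "ingest",
    on_poison=lambda name, bound: events.append((name, bound["source_id"])),
)
def flaky_ingest(user, source_id=None, idempotency_key=None):
    # Always fails, to exercise the poison path.
    raise RuntimeError("boom")
```

Calling `flaky_ingest` three times with the same `idempotency_key` raises each time; a fourth call skips the body, invokes the callback once, and returns the sentinel. A real guard would presumably keep its attempt counter in the `task_dedup` table mentioned above rather than in process memory.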

View File

@@ -1,25 +1,29 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import Namespace, Resource, fields
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
@@ -225,9 +229,7 @@ class MCPServerSave(Resource):
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(
config["oauth_task_id"], user
)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
@@ -250,18 +252,15 @@ class MCPServerSave(Resource):
storage_config = config.copy()
tool_id = data.get("id")
existing_doc = None
existing_encrypted = None
if tool_id:
with db_readonly() as conn:
repo = UserToolsRepository(conn)
existing_doc = repo.get_any(tool_id, user)
if existing_doc and existing_doc.get("name") == "mcp_tool":
existing_encrypted = (existing_doc.get("config") or {}).get(
existing_doc = user_tools_collection.find_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
)
if existing_doc:
existing_encrypted = existing_doc.get("config", {}).get(
"encrypted_credentials"
)
else:
existing_doc = None
if auth_credentials:
if existing_encrypted:
@@ -284,88 +283,47 @@ class MCPServerSave(Resource):
]:
storage_config.pop(field, None)
transformed_actions = transform_actions(actions_metadata)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
"customName": data["displayName"],
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
"config": storage_config,
"actions": transformed_actions,
"status": data.get("status", True),
"user": user,
}
display_name = data["displayName"]
description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
status_bool = bool(data.get("status", True))
with db_session() as conn:
repo = UserToolsRepository(conn)
if existing_doc:
repo.update(
str(existing_doc["id"]), user,
{
"display_name": display_name,
"custom_name": display_name,
"description": description,
"config": storage_config,
"actions": transformed_actions,
"status": status_bool,
},
)
saved_id = str(existing_doc["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
# Fall back to find_by_user_and_name — the original
# dual-write path also ran an existence check before
# deciding between insert and update.
existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
if tool_id is None and existing_by_name and (
(existing_by_name.get("config") or {}).get("server_url")
== storage_config.get("server_url")
):
repo.update(
str(existing_by_name["id"]), user,
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"display_name": display_name,
"custom_name": display_name,
"description": description,
"config": storage_config,
"actions": transformed_actions,
"status": status_bool,
},
)
saved_id = str(existing_by_name["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
created = repo.create(
user, "mcp_tool",
config=storage_config,
custom_name=display_name,
display_name=display_name,
description=description,
config_requirements={},
actions=transformed_actions,
status=status_bool,
)
saved_id = str(created["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
if tool_id and existing_doc is None:
# Client requested update on a non-existent tool id.
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
result = user_tools_collection.insert_one(tool_data)
tool_id = str(result.inserted_id)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
@@ -439,6 +397,56 @@ class MCPOAuthCallback(Resource):
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
@@ -451,59 +459,49 @@ class MCPAuthStatus(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
tools_repo = UserToolsRepository(conn)
sessions_repo = ConnectorSessionsRepository(conn)
all_tools = tools_repo.list_for_user(user)
mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
if not mcp_tools:
return make_response(
jsonify({"success": True, "statuses": {}}), 200
)
mcp_tools = list(
user_tools_collection.find(
{"user": user, "name": "mcp_tool"},
{"_id": 1, "config": 1},
)
)
if not mcp_tools:
return make_response(jsonify({"success": True, "statuses": {}}), 200)
oauth_server_urls: dict = {}
statuses: dict = {}
for tool in mcp_tools:
tool_id = str(tool["id"])
config = tool.get("config") or {}
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
oauth_server_urls = {}
statuses = {}
for tool in mcp_tools:
tool_id = str(tool["_id"])
config = tool.get("config", {})
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "configured"
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
# Look up a session per distinct base URL. MCP sessions
# are stored with ``provider = "mcp:<server_url>"``
# and the URL in ``server_url``; reuse the repo's
# per-URL accessor rather than an ad-hoc $in query.
url_has_tokens: dict = {}
for base_url in set(oauth_server_urls.values()):
session = sessions_repo.get_by_user_and_server_url(
user, base_url,
)
tokens = (
(session or {}).get("session_data", {}) or {}
).get("tokens", {}) or {}
# MCP code also stashes tokens into token_info on
# the row; consider either present as "connected".
token_info = (session or {}).get("token_info") or {}
url_has_tokens[base_url] = bool(
tokens.get("access_token")
or token_info.get("access_token")
)
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
if oauth_server_urls:
unique_urls = list(set(oauth_server_urls.values()))
sessions = list(
_connector_sessions.find(
{"user_id": user, "server_url": {"$in": unique_urls}},
{"server_url": 1, "tokens": 1},
)
)
url_has_tokens = {
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
for doc in sessions
}
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
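
Reviewer note: both versions of `MCPAuthStatus` in the hunk above apply the same status rules, and only the storage lookup differs. As a rough, storage-free sketch of those rules, assuming the token lookup has already been reduced to a `url_has_tokens` mapping keyed by base URL (the function name `mcp_auth_statuses` and the plain-dict inputs are hypothetical):

```python
from urllib.parse import urlparse


def mcp_auth_statuses(tools, url_has_tokens):
    """Map each MCP tool id to 'configured', 'connected', or 'needs_auth'."""
    statuses = {}
    for tool in tools:
        tool_id = str(tool["id"])
        config = tool.get("config") or {}
        if config.get("auth_type", "none") != "oauth":
            # Non-OAuth tools have no session to check.
            statuses[tool_id] = "configured"
            continue
        server_url = config.get("server_url", "")
        if not server_url:
            statuses[tool_id] = "needs_auth"
            continue
        # Sessions are matched on scheme + host, not on the full server URL.
        parsed = urlparse(server_url)
        base_url = f"{parsed.scheme}://{parsed.netloc}"
        statuses[tool_id] = (
            "connected" if url_has_tokens.get(base_url) else "needs_auth"
        )
    return statuses
```

Keeping the rules in a pure function like this would let the Mongo and Postgres variants share one code path and be unit-tested without a database, although the diff itself inlines the logic in the route handler.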

View File

@@ -1,69 +1,23 @@
"""Tool management routes."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.default_tools import (
builtin_agent_tools_for_management,
BUILTIN_AGENT_TOOLS,
default_tool_name_for_id,
default_tools_for_management,
is_builtin_agent_tool_id,
is_default_tool_id,
is_synthesized_tool_id,
)
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
tool_config = {}
tool_manager = ToolManager(config=tool_config)
# ---------------------------------------------------------------------------
# Shape translation helpers
# ---------------------------------------------------------------------------
# The frontend speaks camelCase (``displayName`` / ``customName`` /
# ``configRequirements``). The PG ``user_tools`` table stores snake_case
# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
# translation localized to this module so repositories stay pure.
_CAMEL_TO_SNAKE = {
"displayName": "display_name",
"customName": "custom_name",
"configRequirements": "config_requirements",
}
_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}
def _row_to_api(row: dict) -> dict:
"""Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
out = dict(row)
for snake, camel in _SNAKE_TO_CAMEL.items():
if snake in out:
out[camel] = out.pop(snake)
# ``user_id`` is exposed as ``user`` in the legacy API shape.
if "user_id" in out:
out["user"] = out.pop("user_id")
return out
def _api_to_update_fields(data: dict) -> dict:
"""Rename incoming camelCase update keys to the repo's snake_case columns."""
fields_out: dict = {}
for key, value in data.items():
fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
return fields_out
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
@@ -216,12 +170,12 @@ class GetTools(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
with db_readonly() as conn:
rows = UserToolsRepository(conn).list_for_user(user)
user_doc = UsersRepository(conn).get(user)
tools = user_tools_collection.find({"user": user})
user_tools = []
for row in rows:
tool_copy = _row_to_api(row)
for tool in tools:
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
@@ -238,29 +192,6 @@ class GetTools(Resource):
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
# ``scheduler`` is dual-registered (default chat tool + agent-
# selectable builtin) and resolves to the same synthetic uuid5 id.
# Surface a single row with both flags so the frontend can show it
# in the management page (toggle) and the agent picker.
seen_ids: set = set()
for default_row in default_tools_for_management(user_doc):
default_copy = _row_to_api(default_row)
default_copy["default"] = True
if default_copy.get("name") in BUILTIN_AGENT_TOOLS:
default_copy["builtin"] = True
seen_ids.add(str(default_copy["id"]))
user_tools.append(default_copy)
# Builtins (e.g. scheduler) hidden from Add-Tool catalog, visible
# to the agent picker. Skip ones already added via the default
# path — both registries share ``_DEFAULT_TOOL_NAMESPACE``.
for builtin_row in builtin_agent_tools_for_management():
builtin_copy = _row_to_api(builtin_row)
if str(builtin_copy["id"]) in seen_ids:
continue
builtin_copy["builtin"] = True
builtin_copy["default"] = False
user_tools.append(builtin_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -352,19 +283,26 @@ class CreateTool(Resource):
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
with db_session() as conn:
created = UserToolsRepository(conn).create(
user,
data["name"],
config=storage_config,
custom_name=data.get("customName", ""),
display_name=data["displayName"],
description=data["description"],
config_requirements=config_requirements,
actions=transformed_actions,
status=bool(data.get("status", True)),
)
new_id = str(created["id"])
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": storage_config,
"configRequirements": config_requirements,
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
dual_write(
UserToolsRepository,
lambda repo, u=user, t=new_tool: repo.create(
u, t["name"], config=t.get("config"),
custom_name=t.get("customName"), display_name=t.get("displayName"),
),
)
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -401,51 +339,18 @@ class UpdateTool(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
# Default-tool branch first: a dual-registered tool (e.g. ``scheduler``)
# matches BOTH ``is_default_tool_id`` and ``is_builtin_agent_tool_id``.
# The toggle in Tools settings is the per-user opt-out for the
# agentless default — it must reach the ``set_default_tool_enabled``
# path, not the builtin "not editable" reject.
if is_default_tool_id(data["id"]):
if "status" not in data:
return make_response(
jsonify(
{
"success": False,
"message": "Default tools are not editable; "
"only their on/off status can be changed.",
}
),
400,
)
tool_name = default_tool_name_for_id(data["id"])
try:
with db_session() as conn:
UsersRepository(conn).set_default_tool_enabled(
user, tool_name, bool(data["status"])
)
except Exception as err:
current_app.logger.error(
f"Error updating default tool: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
if is_builtin_agent_tool_id(data["id"]):
return make_response(
jsonify(
{
"success": False,
"message": "Built-in agent tools are not editable; "
"add them to an agent via the agent picker.",
}
),
400,
)
try:
update_data: dict = {}
for key in ("name", "displayName", "customName", "description", "actions"):
if key in data:
update_data[key] = data[key]
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "customName" in data:
update_data["customName"] = data["customName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
if "actions" in data["config"]:
for action_name in list(data["config"]["actions"].keys()):
@@ -460,61 +365,46 @@ class UpdateTool(Resource):
),
400,
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements()
if tool_instance
else {}
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
existing_config = tool_doc.get("config", {}) or {}
has_existing_secrets = "encrypted_credentials" in existing_config
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if "status" in data:
update_data["status"] = bool(data["status"])
repo.update(
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
)
else:
if "status" in data:
update_data["status"] = bool(data["status"])
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
if validation_errors:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
repo.update(
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": update_data},
)
except Exception as err:
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -545,62 +435,54 @@ class UpdateToolConfig(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
if is_synthesized_tool_id(data["id"]):
return make_response(
jsonify(
{
"success": False,
"message": "Default and built-in tools are config-free "
"and cannot be configured.",
}
),
400,
)
try:
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {}) or {}
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
repo.update(str(tool_doc["id"]), user, {"config": final_config})
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": final_config}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool config: {err}", exc_info=True
@@ -635,28 +517,11 @@ class UpdateToolActions(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
if is_synthesized_tool_id(data["id"]):
return make_response(
jsonify(
{
"success": False,
"message": "Default and built-in tools' actions are not editable.",
}
),
400,
)
try:
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"actions": data["actions"]},
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool actions: {err}", exc_info=True
@@ -690,38 +555,10 @@ class UpdateToolStatus(Resource):
if missing_fields:
return missing_fields
try:
# Default branch first so a dual-registered id (e.g. ``scheduler``)
# writes the per-user opt-out instead of being rejected as a
# not-editable builtin (both predicates match the same uuid5).
if is_default_tool_id(data["id"]):
tool_name = default_tool_name_for_id(data["id"])
with db_session() as conn:
UsersRepository(conn).set_default_tool_enabled(
user, tool_name, bool(data["status"])
)
return make_response(jsonify({"success": True}), 200)
if is_builtin_agent_tool_id(data["id"]):
return make_response(
jsonify(
{
"success": False,
"message": "Built-in agent tools have no per-user "
"toggle; add them to an agent via the agent picker.",
}
),
400,
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"status": bool(data["status"])},
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool status: {err}", exc_info=True
@@ -749,25 +586,18 @@ class DeleteTool(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
if is_synthesized_tool_id(data["id"]):
return make_response(
jsonify(
{
"success": False,
"message": "Built-in tools cannot be deleted; disable them instead.",
}
),
400,
)
try:
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
repo.delete(str(tool_doc["id"]), user)
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
dual_write(
UserToolsRepository,
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -836,88 +666,70 @@ class GetArtifact(Resource):
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
notes_repo = NotesRepository(conn)
todos_repo = TodosRepository(conn)
# Artifact IDs may be PG UUIDs (post-cutover) or legacy
# Mongo ObjectIds embedded in older conversation history.
# Both repos' ``get_any`` handles the id-shape branching
# internally so a non-UUID input never reaches
# ``CAST(:id AS uuid)`` (which would poison the readonly
# transaction and break the fallback below).
note_doc = notes_repo.get_any(artifact_id, user_id)
if note_doc:
content = note_doc.get("note", "") or note_doc.get("content", "")
line_count = len(content.split("\n")) if content else 0
updated = note_doc.get("updated_at")
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
updated.isoformat()
if hasattr(updated, "isoformat")
else updated
),
},
}
return make_response(
jsonify({"success": True, "artifact": artifact}), 200
)
todo_doc = todos_repo.get_any(artifact_id, user_id)
if todo_doc:
tool_id = todo_doc.get("tool_id")
all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
items = []
open_count = 0
completed_count = 0
for t in all_todos:
# PG ``todos`` stores a ``completed BOOLEAN`` column;
# the legacy Mongo shape used a ``status`` string.
# Keep the response shape stable by translating here.
status = "completed" if t.get("completed") else "open"
if status == "open":
open_count += 1
else:
completed_count += 1
created = t.get("created_at")
updated = t.get("updated_at")
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
created.isoformat()
if hasattr(created, "isoformat")
else created
),
"updated_at": (
updated.isoformat()
if hasattr(updated, "isoformat")
else updated
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(
jsonify({"success": True, "artifact": artifact}), 200
)
except Exception as err:
current_app.logger.error(
f"Error retrieving artifact: {err}", exc_info=True
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
)
return make_response(jsonify({"success": False}), 400)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404
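Review note: the comment in the `GetArtifact` hunk above says `get_any` branches on the shape of the id so that a non-UUID string never reaches `CAST(:id AS uuid)`. The repository code is not part of this diff, so the following is only a minimal sketch of that kind of check. The function name `classify_id` and both regular expressions are illustrative assumptions, not the project's actual helpers.

```python
import re

# Assumption: canonical hyphenated UUID text and 24-hex-character Mongo
# ObjectId text are the only two id shapes worth distinguishing.
_UUID_RE = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)
_OBJECT_ID_RE = re.compile(r"[0-9a-fA-F]{24}")


def classify_id(raw):
    """Return "uuid", "legacy_object_id", or "invalid" for an incoming id."""
    if not isinstance(raw, str) or not raw:
        return "invalid"
    if _UUID_RE.fullmatch(raw):
        return "uuid"
    if _OBJECT_ID_RE.fullmatch(raw):
        return "legacy_object_id"
    return "invalid"
```

Deciding the shape in Python before any SQL runs means a malformed id can be answered with a 404 or 400 without ever raising inside the database transaction.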

View File

@@ -1,61 +1,290 @@
"""Centralized utilities for API routes.
Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
"""
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Callable, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import (
Response,
current_app,
has_app_context,
jsonify,
make_response,
request,
)
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""Extract user ID from decoded JWT token, or None if unauthenticated."""
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""Decorator to require authentication. Returns 401 when absent."""
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data=None, message: Optional[str] = None, status: int = 200
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""Shape a successful JSON response."""
body = {"success": True}
if data is not None:
body["data"] = data
if message is not None:
body["message"] = message
return make_response(jsonify(body), status)
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""Shape an error JSON response; any kwargs are merged into the body."""
body = {"success": False, "error": message, **kwargs}
return make_response(jsonify(body), status)
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def require_fields(required: list) -> Callable:
"""Decorator: return 400 if any listed field is missing/falsy in the JSON body."""
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
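Review note: `validate_pagination` above clamps an oversized `limit` to `max_limit` silently, but rejects a `limit` below 1, a negative `skip`, or a non-numeric value. A framework-free sketch of the same rules follows, for checking edge cases without a Flask request. `parse_pagination` is a hypothetical name, and it returns an error string where the original returns a Flask response.

```python
from typing import Mapping, Optional, Tuple

_ERROR = "Invalid pagination parameters"


def parse_pagination(
    args: Mapping[str, str], default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[str]]:
    """Mirror the limit/skip rules: clamp high limits, reject bad values."""
    try:
        # An oversized limit is clamped, not rejected.
        limit = min(int(args.get("limit", default_limit)), max_limit)
        skip = int(args.get("skip", 0))
    except ValueError:
        return 0, 0, _ERROR
    if limit < 1 or skip < 0:
        return 0, 0, _ERROR
    return limit, skip, None
```

For instance, a request with `limit=500` would be served with a limit of 100, whereas `limit=0` yields the error tuple.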
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
@@ -65,11 +294,94 @@ def require_fields(required: list) -> Callable:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(
f"Missing required fields: {', '.join(missing)}"
)
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as err:
if has_app_context():
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
return None, error_response(error_message)
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order
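Review note: this file's diff contains two `success_response` bodies that produce different JSON envelopes. One nests the payload under a `data` key and accepts an optional `message`; the other merges the payload into the top level. The sketch below reproduces only the dictionary-building step of each variant, without Flask, so the shapes can be compared side by side. The function names `nested_success_body` and `merged_success_body` are made up for illustration.

```python
from typing import Any, Dict, Optional


def nested_success_body(
    data: Any = None, message: Optional[str] = None
) -> Dict[str, Any]:
    """Variant that keeps the payload under a dedicated "data" key."""
    body: Dict[str, Any] = {"success": True}
    if data is not None:
        body["data"] = data
    if message is not None:
        body["message"] = message
    return body


def merged_success_body(data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Variant that merges the payload into the top-level object."""
    body: Dict[str, Any] = {"success": True}
    if data:
        # A payload key named "success" would overwrite the flag here.
        body.update(data)
    return body
```

Clients that read `response["id"]` work only with the merged shape, and clients that read `response["data"]["id"]` work only with the nested one, so callers and frontend code need to agree on which variant is in effect.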

View File

@@ -1,26 +1,34 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.storage.db.base_repository import looks_like_uuid
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
@@ -31,15 +39,109 @@ def _workflow_error_response(message: str, err: Exception):
return error_response(message)
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
"""Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
if not workflow_id:
return None
if looks_like_uuid(workflow_id):
row = repo.get(workflow_id, user_id)
if row is not None:
return row
return repo.get_by_legacy_id(workflow_id, user_id)
# ---------------------------------------------------------------------------
# Postgres dual-write helpers
#
# Workflows are unusual relative to other Phase 3 tables: a single user
# action (create / update) writes to three collections in concert
# (workflows + workflow_nodes + workflow_edges) and the edges reference
# nodes by user-provided string ids. The Postgres mirror needs to:
#
# 1. Run all three writes inside one PG transaction (so the just-created
# nodes are visible when we resolve their UUIDs for the edge insert).
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
# after the bulk_create returns them.
#
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
# the connection from whichever repo it was instantiated with to spin up
# any sibling repos it needs.
# ---------------------------------------------------------------------------
def _dual_write_workflow_create(
mongo_workflow_id: str,
user_id: str,
name: str,
description: str,
nodes_data: List[Dict],
edges_data: List[Dict],
graph_version: int = 1,
) -> None:
"""Mirror a Mongo workflow create into Postgres."""
def _do(repo: WorkflowsRepository) -> None:
conn = repo._conn
wf = repo.create(
user_id,
name,
description=description,
legacy_mongo_id=mongo_workflow_id,
)
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
dual_write(WorkflowsRepository, _do)
def _dual_write_workflow_update(
mongo_workflow_id: str,
user_id: str,
name: str,
description: str,
nodes_data: List[Dict],
edges_data: List[Dict],
next_graph_version: int,
) -> None:
"""Mirror a Mongo workflow update into Postgres.
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
bump the workflow's name/description/current_graph_version, then drop
every other graph_version's nodes/edges.
"""
def _do(repo: WorkflowsRepository) -> None:
conn = repo._conn
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
if wf is None:
return
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
repo.update(wf["id"], user_id, {
"name": name,
"description": description,
"current_graph_version": next_graph_version,
})
WorkflowNodesRepository(conn).delete_other_versions(
wf["id"], next_graph_version,
)
WorkflowEdgesRepository(conn).delete_other_versions(
wf["id"], next_graph_version,
)
dual_write(WorkflowsRepository, _do)
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
"""Mirror a Mongo workflow delete into Postgres.
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
care of the children automatically.
"""
def _do(repo: WorkflowsRepository) -> None:
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
if wf is not None:
repo.delete(wf["id"], user_id)
dual_write(WorkflowsRepository, _do)
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
"""Look up a Postgres workflow by its Mongo ObjectId string."""
from sqlalchemy import text as _text
row = conn.execute(
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
{"legacy_id": mongo_workflow_id},
).fetchone()
return {"id": str(row[0])} if row else None
def _write_graph(
@@ -48,13 +150,14 @@ def _write_graph(
graph_version: int,
nodes_data: List[Dict],
edges_data: List[Dict],
) -> List[Dict]:
"""Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.
) -> None:
"""Bulk-create nodes + edges for one graph version inside one txn.
Edges arrive with source/target as user-provided node-id strings. We
insert nodes first, capture their ``node_id → UUID`` map, then
translate edges before insertion. Edges referencing missing nodes are
dropped with a warning.
Edges arrive with source/target as user-provided node-id strings
(the same shape the Mongo route stores). We bulk-insert nodes first,
capture their ``node_id → UUID`` map from the returned rows, then
translate edge source/target strings to those UUIDs before the edge
bulk insert. Edges referencing missing nodes are dropped (logged).
"""
nodes_repo = WorkflowNodesRepository(conn)
edges_repo = WorkflowEdgesRepository(conn)
@@ -70,13 +173,13 @@ def _write_graph(
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
"legacy_mongo_id": n.get("legacy_mongo_id"),
}
for n in nodes_data
],
)
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
else:
created_nodes = []
node_uuid_by_str = {}
if edges_data:
@@ -88,7 +191,7 @@ def _write_graph(
to_uuid = node_uuid_by_str.get(tgt)
if not from_uuid or not to_uuid:
current_app.logger.warning(
"Workflow graph write: dropping edge %s; node refs unresolved "
"PG dual-write: dropping edge %s; node refs unresolved "
"(source=%s, target=%s)",
e.get("id"), src, tgt,
)
@@ -101,42 +204,36 @@ def _write_graph(
"target_handle": e.get("targetHandle"),
})
if translated_edges:
edges_repo.bulk_create(
pg_workflow_id, graph_version, translated_edges,
)
return created_nodes
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
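Review note: the `_write_graph` docstring describes translating each edge's user-provided source and target strings into node UUIDs and dropping edges whose nodes cannot be resolved. Part of that loop is elided between hunks, so the following is a sketch of the technique rather than the function's actual body. The output keys `edge_id`, `from_node_id`, and `to_node_id` are assumptions, and the sketch collects dropped edge ids instead of logging them.

```python
from typing import Dict, List, Optional, Tuple


def translate_edges(
    edges: List[Dict], node_uuid_by_str: Dict[str, str]
) -> Tuple[List[Dict], List[Optional[str]]]:
    """Map edge endpoints to node UUIDs; report edges that cannot be mapped."""
    translated: List[Dict] = []
    dropped: List[Optional[str]] = []
    for e in edges:
        from_uuid = node_uuid_by_str.get(e.get("source"))
        to_uuid = node_uuid_by_str.get(e.get("target"))
        if not from_uuid or not to_uuid:
            # Unresolved endpoint: skip the edge rather than fail the write.
            dropped.append(e.get("id"))
            continue
        translated.append({
            "edge_id": e.get("id"),
            "from_node_id": from_uuid,
            "to_node_id": to_uuid,
            "source_handle": e.get("sourceHandle"),
            "target_handle": e.get("targetHandle"),
        })
    return translated, dropped
```

Returning the dropped ids alongside the translated edges makes the silent-drop behavior easy to assert on in a unit test.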
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow row to API response format."""
created_at = w.get("created_at")
updated_at = w.get("updated_at")
"""Serialize workflow document to API response format."""
return {
"id": str(w["id"]),
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node row to API response format."""
"""Serialize workflow node document to API response format."""
return {
"id": n["node_id"],
"type": n["node_type"],
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}) or {},
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge row to API response format."""
"""Serialize workflow edge document to API response format."""
return {
"id": e["edge_id"],
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
@@ -145,7 +242,7 @@ def serialize_edge(e: Dict) -> Dict:
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with fallback."""
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
@@ -154,6 +251,22 @@ def get_workflow_graph_version(workflow: Dict) -> int:
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
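Review note: `fetch_graph_documents` above prefers documents tagged with the active `graph_version` and, only when version 1 is requested and nothing matches, falls back to documents that have no `graph_version` field at all. The same selection rule can be sketched over an in-memory list, which may help when reasoning about legacy data. `select_graph_docs` is a hypothetical stand-in for the two Mongo queries.

```python
from typing import Dict, List


def select_graph_docs(
    docs: List[Dict], workflow_id: str, graph_version: int
) -> List[Dict]:
    """Pick docs for one workflow version, with the unversioned fallback."""
    versioned = [
        d for d in docs
        if d.get("workflow_id") == workflow_id
        and d.get("graph_version") == graph_version
    ]
    if versioned:
        return versioned
    if graph_version == 1:
        # Legacy rows written before versioning carry no graph_version key.
        return [
            d for d in docs
            if d.get("workflow_id") == workflow_id and "graph_version" not in d
        ]
    return versioned
```

Note that the fallback applies only to version 1, so a workflow at version 2 or later with no matching documents yields an empty list.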
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
@@ -198,14 +311,8 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
return normalized_nodes
def validate_workflow_structure(
nodes: List[Dict], edges: List[Dict], user_id: str | None = None
) -> List[str]:
"""Validate workflow graph structure.
``user_id`` is required so per-user BYOM custom-model UUIDs resolve
when checking each agent node's structured-output capability.
"""
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
@@ -349,7 +456,7 @@ def validate_workflow_structure(
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
@@ -380,6 +487,53 @@ def _can_reach_end(
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> List[Dict]:
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
if nodes_data:
mongo_nodes = [
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
result = workflow_nodes_collection.insert_many(mongo_nodes)
return [
{**node, "legacy_mongo_id": str(inserted_id)}
for node, inserted_id in zip(nodes_data, result.inserted_ids)
]
return []
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@@ -391,29 +545,54 @@ class WorkflowList(Resource):
data = request.get_json()
name = data.get("name", "").strip()
description = data.get("description", "")
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
try:
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = repo.create(user_id, name, description=description)
pg_workflow_id = str(workflow["id"])
_write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
except Exception as err:
return _workflow_error_response("Failed to create workflow", err)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
return success_response({"id": pg_workflow_id}, status=201)
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as err:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return _workflow_error_response("Failed to create workflow structure", err)
_dual_write_workflow_create(
workflow_id,
user_id,
name,
data.get("description", ""),
created_nodes,
edges_data,
)
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
@@ -423,22 +602,23 @@ class WorkflowDetail(Resource):
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
try:
with db_readonly() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
pg_workflow_id = str(workflow["id"])
graph_version = get_workflow_graph_version(workflow)
nodes = WorkflowNodesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
edges = WorkflowEdgesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
except Exception as err:
return _workflow_error_response("Failed to fetch workflow", err)
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
@@ -453,51 +633,89 @@ class WorkflowDetail(Resource):
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
description = data.get("description", "")
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
pg_workflow_id = str(workflow["id"])
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
_write_graph(
conn, pg_workflow_id, next_graph_version,
nodes_data, edges_data,
)
repo.update(
pg_workflow_id, user_id,
{
"name": name,
"description": description,
"current_graph_version": next_graph_version,
},
)
WorkflowNodesRepository(conn).delete_other_versions(
pg_workflow_id, next_graph_version,
)
WorkflowEdgesRepository(conn).delete_other_versions(
pg_workflow_id, next_graph_version,
)
created_nodes = create_workflow_nodes(
workflow_id, nodes_data, next_graph_version,
)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as err:
return _workflow_error_response("Failed to update workflow", err)
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return _workflow_error_response("Failed to update workflow structure", err)
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
_dual_write_workflow_update(
workflow_id,
user_id,
name,
data.get("description", ""),
created_nodes,
edges_data,
next_graph_version,
)
return success_response()
@@ -505,15 +723,23 @@ class WorkflowDetail(Resource):
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
# ON DELETE CASCADE on workflow_nodes/edges cleans children.
repo.delete(str(workflow["id"]), user_id)
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
_dual_write_workflow_delete(workflow_id, user_id)
return success_response()

@@ -9,7 +9,6 @@ import json
import logging
import time
import traceback
from datetime import datetime
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
@@ -21,8 +20,8 @@ from application.api.v1.translator import (
translate_response,
translate_stream_event,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
@@ -40,8 +39,9 @@ def _extract_bearer_token() -> Optional[str]:
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
with db_readonly() as conn:
return AgentsRepository(conn).find_by_key(api_key)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
return db["agents"].find_one({"key": api_key})
except Exception:
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
@@ -90,14 +90,8 @@ def chat_completions():
)
# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity. The PG
# ``agents`` row exposes the owner via ``user_id`` (``user`` is the
# legacy Mongo field name kept in ``row_to_dict`` only for the
# mapping ``id``/``_id``).
agent_user = (
(agent_doc.get("user_id") or agent_doc.get("user"))
if agent_doc else None
)
# logs, and tool execution use the correct user identity.
agent_user = agent_doc.get("user") if agent_doc else None
decoded_token = {"sub": agent_user or "api_key_user"}
try:
@@ -214,7 +208,6 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -222,26 +215,13 @@ def _stream_response(
for line in internal_stream:
if not line.strip():
continue
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
# before the ``data:`` line. Extract just the data line so JSON
# decode doesn't choke on the SSE framing.
event_str = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_str = raw[len("data:") :].lstrip()
break
if not event_str:
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Skip the informational ``message_id`` event — it has no v1 /
# OpenAI-compatible analog.
if event_data.get("type") == "message_id":
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")
@@ -272,7 +252,6 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -311,41 +290,39 @@ def list_models():
)
try:
with db_readonly() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.find_by_key(api_key)
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
# Repository rows now go through ``coerce_pg_native`` at SELECT
# time, so timestamps arrive as ISO 8601 strings. Parse before
# taking ``.timestamp()``; fall back to ``time.time()`` only when
# the value is genuinely missing or unparseable.
created = agent.get("created_at") or agent.get("createdAt")
if isinstance(created, str):
try:
created = datetime.fromisoformat(created)
except (ValueError, TypeError):
created = None
created_ts = (
int(created.timestamp()) if hasattr(created, "timestamp")
else int(time.time())
)
model_id = str(agent.get("id") or agent.get("_id") or "")
model = {
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": agent.get("name", ""),
"description": agent.get("description", ""),
}
# Find the agent for this api_key
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)
user = agent.get("user")
# Return all agents belonging to this user
user_agents = list(agents_collection.find({"user": user}))
models = []
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
model_id = str(ag.get("_id") or ag.get("id") or "")
models.append({
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": ag.get("name", ""),
"description": ag.get("description", ""),
})
return make_response(
jsonify({"object": "list", "data": [model]}),
jsonify({"object": "list", "data": models}),
200,
)
except Exception as e:
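
Note on the `_stream_response` hunk above: one side of the diff strips `data: ` with `str.replace`, the other scans for the first `data:` line because each frame may start with an `id: <seq>` line. The difference can be sketched in isolation. This is a minimal illustration, assuming frames shaped like `id: <seq>\ndata: <json>`; `extract_data_payload` is a hypothetical helper name, not a function in the repository.

```python
import json
from typing import Any, Optional


def extract_data_payload(frame: str) -> Optional[Any]:
    """Decode the JSON carried by the first ``data:`` line of an SSE frame.

    Returns None when the frame has no ``data:`` line or the payload is
    not valid JSON, mirroring the ``continue`` branches in the diff.
    """
    for raw in frame.split("\n"):
        if raw.startswith("data:"):
            payload = raw[len("data:"):].lstrip()
            try:
                return json.loads(payload)
            except (json.JSONDecodeError, TypeError):
                return None
    return None


# A frame with an ``id:`` prefix line, as described in the diff comment.
frame = 'id: 7\ndata: {"type": "id", "id": "abc"}\n'
print(extract_data_payload(frame))
```

With the `replace`-based approach the same frame becomes `id: 7\n{...}`, which `json.loads` rejects, so the event would be skipped silently.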

@@ -1,30 +1,25 @@
import logging
import os
import platform
import uuid
import dotenv
from flask import Flask, Response, jsonify, redirect, request
from flask import Flask, jsonify, redirect, request
from jose import jwt
from application.auth import handle_auth
from application.core import log_context
from application.core.logging_config import setup_logging
setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.answer.routes.messages import messages_bp # noqa: E402
from application.api.events.routes import events # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.bootstrap import ensure_database_ready # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
@@ -37,28 +32,9 @@ if platform.system() == "Windows":
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
# repository touches the engine, so the first request can't race the
# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
# (default ON for dev; disable in prod if schema is managed out-of-band).
ensure_database_ready(
settings.POSTGRES_URI,
create_db=settings.AUTO_CREATE_DB,
migrate=settings.AUTO_MIGRATE,
logger=logging.getLogger("application.app"),
)
from application.agents.default_tools import ( # noqa: E402
validate_default_chat_tools,
)
validate_default_chat_tools()
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(events)
app.register_blueprint(messages_bp)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
@@ -123,38 +99,6 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
@app.before_request
def _bind_log_context():
"""Bind activity_id + endpoint for the duration of this request.
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
follow-up handler once the JWT has been decoded.
"""
if request.method == "OPTIONS":
return None
activity_id = str(uuid.uuid4())
request.activity_id = activity_id
token = log_context.bind(
activity_id=activity_id,
endpoint=request.endpoint,
)
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
return None
@app.teardown_request
def _reset_log_context(_exc):
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
# request inside ``copy_context().run(...)``, so this reset doesn't
# leak into the stream's view of the context.
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
if token is not None:
log_context.reset(token)
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
@@ -176,12 +120,6 @@ def enforce_stt_request_size_limits():
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
# OpenAI-compatible routes authenticate via opaque agent API keys in the
# Authorization header, which the JWT decoder below would reject. Defer
# auth to the route handlers (see application/api/v1/routes.py).
if request.path.startswith("/v1/"):
request.decoded_token = None
return None
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None
@@ -191,29 +129,13 @@ def authenticate_request():
request.decoded_token = decoded_token
@app.before_request
def _bind_user_id_to_log_context():
# Registered after ``authenticate_request`` (Flask runs before_request
# handlers in registration order), so ``request.decoded_token`` is
# populated by the time we read it. ``teardown_request`` unwinds the
# whole request-level bind, so no separate reset token is needed here.
if request.method == "OPTIONS":
return None
decoded_token = getattr(request, "decoded_token", None)
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
if user_id:
log_context.bind(user_id=user_id)
return None
@app.after_request
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = (
"Content-Type, Authorization, Idempotency-Key"
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
return response

@@ -1,38 +0,0 @@
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
from __future__ import annotations
from a2wsgi import WSGIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Mount
from application.app import app as flask_app
from application.mcp_server import mcp
_WSGI_THREADPOOL = 32
mcp_app = mcp.http_app(path="/")
asgi_app = Starlette(
routes=[
Mount("/mcp", app=mcp_app),
Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=[
"Content-Type",
"Authorization",
"Mcp-Session-Id",
"Idempotency-Key",
],
expose_headers=["Mcp-Session-Id"],
),
],
lifespan=mcp_app.lifespan,
)

@@ -1,4 +1,3 @@
import hashlib
import json
import logging
import time
@@ -11,14 +10,6 @@ from application.utils import get_hash
logger = logging.getLogger(__name__)
def _cache_default(value):
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
# hash so the cache key stays bounded in size and stable across identical content.
if isinstance(value, (bytes, bytearray, memoryview)):
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
return repr(value)
_redis_instance = None
_redis_creation_failed = False
_instance_lock = Lock()
@@ -29,17 +20,8 @@ def get_redis_instance():
with _instance_lock:
if _redis_instance is None and not _redis_creation_failed:
try:
# ``health_check_interval`` makes redis-py ping the
# connection every N seconds when otherwise idle.
# Without it, a half-open TCP (NAT silently dropped
# state, ELB idle-close) can hang the SSE generator
# in ``pubsub.get_message`` past its keepalive
# cadence — the kernel never surfaces the dead
# socket because no payload is in flight.
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL,
socket_connect_timeout=2,
health_check_interval=10,
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except ValueError as e:
logger.error(f"Invalid Redis URL: {e}")
@@ -54,7 +36,7 @@ def get_redis_instance():
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(messages, default=_cache_default)
messages_str = json.dumps(messages)
tools_str = json.dumps(str(tools)) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
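
Note on the `gen_cache_key` hunk above: plain `json.dumps(messages)` raises `TypeError` when a message carries raw bytes, while the `default=` variant hashes them. A minimal sketch of that idea follows. It assumes SHA-256 as a stand-in for the repository's `get_hash`, whose implementation is not shown in this diff; `cache_default` and `stable_key` are illustrative names.

```python
import hashlib
import json


def cache_default(value):
    """Fallback used by json.dumps for values it cannot encode natively."""
    if isinstance(value, (bytes, bytearray, memoryview)):
        # Hash binary payloads so the key stays short and content-stable.
        digest = hashlib.sha256(bytes(value)).hexdigest()
        return f"<bytes:sha256:{digest}>"
    return repr(value)


def stable_key(messages, model="docgpt"):
    """Build a deterministic key; SHA-256 here is an assumption, not get_hash."""
    serialized = json.dumps(messages, default=cache_default)
    return hashlib.sha256(f"{model}_{serialized}".encode("utf-8")).hexdigest()


with_image = [{"role": "user", "content": "describe", "image": b"\x89PNG..."}]
print(stable_key(with_image) == stable_key(with_image))
```

Identical byte content maps to the same key and different content to a different one, which is the property the cache relies on.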

@@ -1,20 +1,6 @@
import ctypes
import gc
import inspect
import logging
import sys
import threading
from celery import Celery
from application.core import log_context
from application.core.settings import settings
from celery.signals import (
setup_logging,
task_postrun,
task_prerun,
worker_process_init,
worker_ready,
)
from celery.signals import setup_logging, worker_process_init
def make_celery(app_name=__name__):
@@ -53,101 +39,5 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
dispose_engine()
# Most tasks in this repo accept ``user`` where the log context wants
# ``user_id``; map task parameter names to context keys explicitly.
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
"user": "user_id",
"user_id": "user_id",
"agent_id": "agent_id",
"conversation_id": "conversation_id",
}
_task_log_tokens: dict[str, object] = {}
@task_prerun.connect
def _bind_task_log_context(task_id, task, args, kwargs, **_):
# Resolve task args by parameter name — nearly every task in this repo
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
ctx = {"activity_id": task_id}
try:
sig = inspect.signature(task.run)
bound = sig.bind_partial(*args, **kwargs).arguments
except (TypeError, ValueError):
bound = dict(kwargs)
for param_name, value in bound.items():
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
if ctx_key and value:
ctx[ctx_key] = value
_task_log_tokens[task_id] = log_context.bind(**ctx)
@task_postrun.connect
def _unbind_task_log_context(task_id, **_):
# ``task_postrun`` fires on both success and failure. Required for
# Celery: unlike the Flask path, tasks aren't isolated in their own
# ``copy_context().run(...)``, so a missing reset would leak the
# bind onto the next task on the same worker.
token = _task_log_tokens.pop(task_id, None)
if token is None:
return
try:
log_context.reset(token)
except ValueError:
# task_prerun and task_postrun ran on different threads (non-default
# Celery pool); the token isn't valid in this context. Drop it.
logging.getLogger(__name__).debug(
"log_context reset skipped for task %s", task_id
)
def _trim_native_heap() -> None:
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
# docling/torch parsing makes large transient allocations; glibc keeps the
# freed pages in per-thread malloc arenas rather than returning them, so a
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
# back. The symbol is glibc-only — absent in macOS libc.
if not sys.platform.startswith("linux"):
return
try:
ctypes.CDLL("libc.so.6").malloc_trim(0)
except (OSError, AttributeError):
pass
@task_postrun.connect
def _reclaim_memory_after_task(*args, **kwargs):
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
gc.collect()
torch = sys.modules.get("torch")
if torch is not None:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
_trim_native_heap()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.
Runs in a daemon thread so a slow endpoint or bad DNS never holds
up the worker becoming ready for tasks. The check itself is
fail-silent (see ``application.updates.version_check.run_check``);
this handler's only job is to launch it and get out of the way.
Import is lazy so the symbol resolution never fires at module
import time — consistent with the ``_dispose_db_engine_on_fork``
pattern above.
"""
try:
from application.updates.version_check import run_check
except Exception:
return
threading.Thread(target=run_check, name="version-check", daemon=True).start()
celery = make_celery()
celery.config_from_object("application.celeryconfig")
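
Note on the `task_prerun` handler in the hunk above: it resolves positional task arguments to parameter names before mapping them to log-context keys. The same technique can be sketched without Celery. This is an illustration only; `context_from_call` and the `ingest` function are hypothetical, and the mapping table is a shortened copy of the one in the diff.

```python
import inspect

PARAM_TO_CTX_KEY = {
    "user": "user_id",
    "user_id": "user_id",
    "agent_id": "agent_id",
}


def context_from_call(func, args, kwargs, task_id):
    """Map a call's arguments to context keys by parameter name."""
    ctx = {"activity_id": task_id}
    try:
        # bind_partial tolerates missing arguments but still names positionals.
        bound = inspect.signature(func).bind_partial(*args, **kwargs).arguments
    except (TypeError, ValueError):
        # Arguments do not fit the signature: fall back to keyword args only.
        bound = dict(kwargs)
    for name, value in bound.items():
        key = PARAM_TO_CTX_KEY.get(name)
        if key and value:
            ctx[key] = value
    return ctx


def ingest(directory, formats, job_name, user):
    """Hypothetical task body that callers invoke positionally."""


print(context_from_call(ingest, ("inputs", [".pdf"], "job-1", "alice"), {}, "t-1"))
```

A plain `kwargs.get("user")` would find nothing for the positional call shown, which is the case the diff comment describes.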

@@ -1,10 +1,7 @@
from application.core.settings import settings
import os
# Pydantic loads .env into ``settings`` but does not inject values into
# ``os.environ`` — read directly from settings so beat startup (which
# imports this module before any explicit env load) sees a real URL.
broker_url = settings.CELERY_BROKER_URL
result_backend = settings.CELERY_RESULT_BACKEND
broker_url = os.getenv("CELERY_BROKER_URL")
result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
@@ -12,29 +9,3 @@ accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)
# Project-scoped queue so a stray sibling worker on the same broker
# (other repo, same default ``celery`` queue) can't grab DocsGPT tasks.
task_default_queue = "docsgpt"
task_default_exchange = "docsgpt"
task_default_routing_key = "docsgpt"
beat_scheduler = "redbeat.RedBeatScheduler"
redbeat_redis_url = broker_url
redbeat_key_prefix = "redbeat:docsgpt:"
redbeat_lock_timeout = 90
# Survive worker SIGKILL/OOM without silently dropping in-flight tasks.
task_acks_late = True
task_reject_on_worker_lost = True
worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
result_expires = 86400 * 7
task_track_started = True
# Recycle the prefork worker child to bound native-heap growth from
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD

@@ -1,57 +0,0 @@
"""Per-activity logging context backed by ``contextvars``.
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
they land as first-class attributes on the OTLP log export rather than being
buried inside formatted message bodies.
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
via the token returned by ``bind``.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Mapping
_CTX_KEYS: frozenset[str] = frozenset(
{
"activity_id",
"parent_activity_id",
"user_id",
"agent_id",
"conversation_id",
"endpoint",
"model",
}
)
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
def bind(**kwargs: object) -> Token:
"""Overlay the given keys onto the current context.
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
stamp a stray field name onto every record), as are ``None`` values
(a missing attribute is more useful than the literal string ``"None"``).
"""
overlay = {
k: str(v)
for k, v in kwargs.items()
if k in _CTX_KEYS and v is not None
}
new = {**_ctx.get(), **overlay}
return _ctx.set(new)
def reset(token: Token) -> None:
"""Restore the context to the snapshot captured by the matching ``bind``."""
_ctx.reset(token)
def snapshot() -> Mapping[str, str]:
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
return _ctx.get()

@@ -1,75 +1,11 @@
import logging
import os
from logging.config import dictConfig
from application.core.log_context import snapshot as _ctx_snapshot
# Loggers with ``propagate=False`` don't share root's handlers, so the
# context filter has to be installed on their handlers directly.
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
"uvicorn",
"uvicorn.access",
"uvicorn.error",
"celery.app.trace",
"celery.worker.strategy",
"gunicorn.error",
"gunicorn.access",
)
class _ContextFilter(logging.Filter):
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
Must be installed on **handlers**, not loggers: Python skips logger-level
filters when a child logger's record propagates up. The ``hasattr`` guard
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
"""
def filter(self, record: logging.LogRecord) -> bool:
for key, value in _ctx_snapshot().items():
if not hasattr(record, key):
setattr(record, key, value)
return True
def _otlp_logs_enabled() -> bool:
"""Return True when the user has opted in to OTLP log export.
Gated by the standard OTEL env vars so no project-specific knob is needed:
set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
false) to flip it on. When false, ``setup_logging`` keeps its original
console-only behavior.
"""
exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
return exporter == "otlp" and not disabled
def setup_logging() -> None:
"""Configure the root logger with a stdout console handler.
When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
``LoggingHandler`` to the root logger before this function runs. The
``dictConfig`` call below replaces ``root.handlers`` with the console
handler, which would silently drop the OTEL handler. To make OTLP log
export work without forcing every contributor to opt in, snapshot the
OTEL handlers up front and re-attach them after ``dictConfig``.
"""
preserved_handlers: list[logging.Handler] = []
if _otlp_logs_enabled():
preserved_handlers = [
h
for h in logging.getLogger().handlers
if h.__class__.__module__.startswith("opentelemetry")
]
def setup_logging():
dictConfig({
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
'version': 1,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
}
},
"handlers": {
@@ -79,34 +15,8 @@ def setup_logging() -> None:
"formatter": "default",
}
},
"root": {
"level": "INFO",
"handlers": ["console"],
'root': {
'level': 'INFO',
'handlers': ['console'],
},
})
if preserved_handlers:
root = logging.getLogger()
for handler in preserved_handlers:
if handler not in root.handlers:
root.addHandler(handler)
_install_context_filter()
def _install_context_filter() -> None:
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
the known non-propagating loggers. Skipping handlers that already carry
one keeps repeat ``setup_logging`` calls from stacking filters.
"""
def _has_ctx_filter(handler: logging.Handler) -> bool:
return any(isinstance(f, _ContextFilter) for f in handler.filters)
for handler in logging.getLogger().handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
for name in _NON_PROPAGATING_LOGGERS:
for handler in logging.getLogger(name).handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
})
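
Note on `_ContextFilter` in the hunk above: its docstring says the filter must sit on handlers because logger-level filters are skipped for records that propagate up from child loggers. A minimal sketch of the handler-level behaviour follows. `StampFilter` and the `demo` logger names are hypothetical, and the fixed field dict stands in for the `log_context` snapshot.

```python
import io
import logging


class StampFilter(logging.Filter):
    """Copy fixed fields onto each record unless the record already has them."""

    def __init__(self, fields):
        super().__init__()
        self.fields = fields

    def filter(self, record: logging.LogRecord) -> bool:
        for key, value in self.fields.items():
            if not hasattr(record, key):  # keep explicit extra={...} values
                setattr(record, key, value)
        return True


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(activity_id)s %(message)s"))
handler.addFilter(StampFilter({"activity_id": "req-1"}))

parent = logging.getLogger("demo")
parent.addHandler(handler)
parent.setLevel(logging.INFO)
parent.propagate = False  # keep the demo output away from the root logger

# Records from a child logger propagate to the parent's handler,
# where the handler-level filter still runs.
child = logging.getLogger("demo.child")
child.info("hello")
child.info("override", extra={"activity_id": "custom"})

print(stream.getvalue(), end="")
```

The second record keeps its explicit `activity_id`, matching the `hasattr` guard described in the diff.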

@@ -0,0 +1,266 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Qwen 3 coding model served through OpenRouter's free tier",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Gemma 3 27B instruction-tuned model served through OpenRouter's free tier",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)
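
As a rough illustration of how the factory above is meant to be used, the sketch below splits a comma-separated model list and builds one entry per name. `CustomModel` and `build_custom_models` are hypothetical stand-ins so the snippet runs on its own; they are not the project's `AvailableModel` class, and the URL is only an example of a local endpoint.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class CustomModel:
    """Hypothetical stand-in for AvailableModel, reduced to three fields."""

    id: str
    display_name: str
    base_url: str


def parse_model_names(llm_name: str) -> List[str]:
    """Split a comma-separated LLM_NAME value, dropping blank entries."""
    if not llm_name:
        return []
    return [name.strip() for name in llm_name.split(",") if name.strip()]


def build_custom_models(llm_name: str, base_url: str) -> Dict[str, CustomModel]:
    """Create one entry per parsed name, keyed by model id."""
    return {
        name: CustomModel(id=name, display_name=name, base_url=base_url)
        for name in parse_model_names(llm_name)
    }


# Assumed example values: two local models behind one OpenAI-compatible URL.
models = build_custom_models(
    "deepseek-r1:1.5b, gemma:2b,", "http://localhost:11434/v1"
)
```

Keying the result by model id mirrors how the registry stores entries, so a later entry with the same id would simply overwrite an earlier one.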


@@ -1,385 +0,0 @@
"""Layered model registry.
Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.
End-user BYOM (per-user model records in Postgres) is layered on top:
when a lookup arrives with a ``user_id``, the registry consults a
per-user cache first (loaded from the ``user_custom_models`` table on
miss) and falls through to the built-in catalog.
Cross-process invalidation: ``ModelRegistry`` is a per-process
singleton, so a CRUD write only evicts the cache in the process that
served it. Other gunicorn workers and Celery workers would otherwise
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
``invalidate_user`` therefore both drops the local layer *and* bumps a
Redis-side version counter; other processes notice the bump on their
next access (after the local TTL window) and reload from Postgres. If
Redis is unreachable the per-process TTL still bounds staleness — pure
TTL semantics, no regression.
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
BUILTIN_MODELS_DIR,
ProviderCatalog,
load_model_yamls,
)
logger = logging.getLogger(__name__)
_USER_CACHE_TTL_SECONDS = 60.0
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
class ModelRegistry:
"""Singleton registry of available models."""
_instance: Optional["ModelRegistry"] = None
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
# Per-user BYOM cache. Each entry is
# ``(layer, version_at_load, loaded_at_monotonic)``:
# * ``layer`` — {model_id: AvailableModel}
# * ``version_at_load`` — Redis-side counter snapshot at
# reload time, or ``None`` if Redis was unreachable
# * ``loaded_at_monotonic`` — for TTL bookkeeping
# Populated lazily, evicted by TTL + cross-process
# invalidation (see ``invalidate_user``).
self._user_models: Dict[
str,
Tuple[Dict[str, AvailableModel], Optional[int], float],
] = {}
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
@classmethod
def reset(cls) -> None:
"""Clear the singleton. Intended for test fixtures."""
cls._instance = None
cls._initialized = False
@classmethod
def invalidate_user(cls, user_id: str) -> None:
"""Drop the cached per-user model layer for ``user_id``.
Called by the BYOM REST routes after every create/update/delete.
Two effects:
* Local: pop the entry from this process's cache so the next
lookup re-reads from Postgres immediately.
* Cross-process: ``INCR`` a Redis-side version counter for this
user. Other gunicorn/Celery processes notice the counter
changed on their next TTL-driven recheck (see
``_user_models_for``) and reload. If Redis is unreachable we
log and continue — local invalidation still happened, and
peers fall back to TTL-only staleness bounds.
"""
if cls._instance is not None:
cls._instance._user_models.pop(user_id, None)
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is not None:
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
except Exception as e:
logger.warning(
"BYOM invalidate: failed to publish version bump for "
"user %s (Redis unreachable?): %s",
user_id,
e,
)
@classmethod
def _read_user_version(cls, user_id: str) -> Optional[int]:
"""Return the Redis-side invalidation counter for ``user_id``.
``0`` if the key has never been bumped; ``None`` if Redis is
unreachable or the read failed (callers fall back to TTL-only
staleness in that case).
"""
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is None:
return None
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
if raw is None:
return 0
return int(raw)
except Exception:
return None
def _load_models(self) -> None:
from pathlib import Path
from application.core.settings import settings
from application.llm.providers import ALL_PROVIDERS
directories = [BUILTIN_MODELS_DIR]
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
if operator_dir:
op_path = Path(operator_dir)
if not op_path.exists():
logger.warning(
"MODELS_CONFIG_DIR=%s does not exist; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
elif not op_path.is_dir():
logger.warning(
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
else:
directories.append(op_path)
catalogs = load_model_yamls(directories)
# Validate every catalog targets a known plugin before doing any
# registry work, so an unknown provider name in YAML aborts boot
# with a clear error.
plugin_names = {p.name for p in ALL_PROVIDERS}
for c in catalogs:
if c.provider not in plugin_names:
raise ValueError(
f"{c.source_path}: YAML declares unknown provider "
f"{c.provider!r}; no Provider plugin is registered "
f"under that name. Known: {sorted(plugin_names)}"
)
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
for c in catalogs:
catalogs_by_provider[c.provider].append(c)
self.models.clear()
for provider in ALL_PROVIDERS:
if not provider.is_enabled(settings):
continue
for model in provider.get_models(
settings, catalogs_by_provider.get(provider.name, [])
):
self.models[model.id] = model
self.default_model_id = self._resolve_default(settings)
logger.info(
"ModelRegistry loaded %d models, default: %s",
len(self.models),
self.default_model_id,
)
def _resolve_default(self, settings) -> Optional[str]:
if settings.LLM_NAME:
for name in self._parse_model_names(settings.LLM_NAME):
if name in self.models:
return name
if settings.LLM_NAME in self.models:
return settings.LLM_NAME
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
return model_id
if self.models:
return next(iter(self.models.keys()))
return None
@staticmethod
def _parse_model_names(llm_name: str) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
# Per-user (BYOM) layer
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
"""Return the user's BYOM models keyed by registry id (UUID).
Loaded lazily from Postgres on first access; cached subject to
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
backed version counter for cross-process invalidation. The TTL
bounds staleness even when Redis is unreachable, while the
version stamp lets peers refresh without a DB read on the
common case (no invalidation since last load). Decryption
failures and DB errors yield an empty layer (logged) — the
user simply doesn't see their custom models on this request,
never a 500.
"""
cached = self._user_models.get(user_id)
now = time.monotonic()
if cached is not None:
layer, cached_version, loaded_at = cached
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
return layer
# TTL elapsed: peek at the cross-process counter. If it
# matches what we saw at load time, no invalidation has
# happened — extend the TTL without touching Postgres. If
# Redis is unreachable (``current_version is None``) we
# fall through to a real reload, which keeps staleness
# bounded to the TTL.
current_version = self._read_user_version(user_id)
if (
current_version is not None
and cached_version is not None
and current_version == cached_version
):
self._user_models[user_id] = (layer, cached_version, now)
return layer
# Capture the counter *before* the DB read so a CRUD that lands
# mid-reload doesn't get masked: the next access will see a
# newer version and reload again.
version_before_read = self._read_user_version(user_id)
layer: Dict[str, AvailableModel] = {}
try:
from application.core.model_settings import (
ModelCapabilities,
ModelProvider,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
rows = repo.list_for_user(user_id)
for row in rows:
api_key = repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not api_key:
# SECURITY: do NOT register an unroutable BYOM
# record. If we did, LLMCreator would fall back
# to the caller-passed api_key (settings.API_KEY
# for openai_compatible) and POST it to the
# user-supplied base_url — leaking the instance
# credential to the user's chosen endpoint.
# Most likely cause is ENCRYPTION_SECRET_KEY
# having rotated; user must re-save the model.
logger.warning(
"user_custom_models: skipping model %s for "
"user %s — api_key could not be decrypted "
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
"the model to recover.",
row.get("id"),
user_id,
)
continue
caps_raw = row.get("capabilities") or {}
# Stored attachments may be aliases (``image``) or
# raw MIME types. Built-in YAML models expand at
# load time; mirror that here so downstream MIME-
# type comparisons (handlers/base.prepare_messages)
# match concrete types like ``image/png`` rather
# than the bare alias.
from application.core.model_yaml import (
expand_attachments_lenient,
)
raw_attachments = caps_raw.get("attachments", []) or []
expanded_attachments = expand_attachments_lenient(
raw_attachments,
f"user_custom_models[user={user_id}, model={row.get('id')}]",
)
caps = ModelCapabilities(
supports_tools=bool(caps_raw.get("supports_tools", False)),
supports_structured_output=bool(
caps_raw.get("supports_structured_output", False)
),
supports_streaming=bool(
caps_raw.get("supports_streaming", True)
),
supported_attachment_types=expanded_attachments,
context_window=int(
caps_raw.get("context_window") or 128000
),
)
model_id = str(row["id"])
layer[model_id] = AvailableModel(
id=model_id,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=row["display_name"],
description=row.get("description") or "",
capabilities=caps,
enabled=bool(row.get("enabled", True)),
base_url=row["base_url"],
upstream_model_id=row["upstream_model_id"],
source="user",
api_key=api_key,
)
except Exception as e:
logger.warning(
"user_custom_models: failed to load layer for user %s: %s",
user_id,
e,
)
layer = {}
self._user_models[user_id] = (layer, version_before_read, now)
return layer
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
# it, callers see only the built-in + operator catalog.
def get_model(
self, model_id: str, user_id: Optional[str] = None
) -> Optional[AvailableModel]:
if user_id:
user_layer = self._user_models_for(user_id)
if model_id in user_layer:
return user_layer[model_id]
return self.models.get(model_id)
def get_all_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = list(self.models.values())
if user_id:
out.extend(self._user_models_for(user_id).values())
return out
def get_enabled_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = [m for m in self.models.values() if m.enabled]
if user_id:
out.extend(
m for m in self._user_models_for(user_id).values() if m.enabled
)
return out
def model_exists(
self, model_id: str, user_id: Optional[str] = None
) -> bool:
if user_id and model_id in self._user_models_for(user_id):
return True
return model_id in self.models
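
The caching behaviour described in `_user_models_for` can be sketched in isolation. The class below is a simplified illustration, not the registry itself: `VersionedTTLCache`, its constructor arguments, and the injectable `clock` are hypothetical names, and the shared counter is passed in as a plain callable instead of a Redis client.

```python
import time
from typing import Callable, Dict, Optional, Tuple


class VersionedTTLCache:
    """Serve from memory within the TTL, then recheck a shared version counter."""

    def __init__(
        self,
        load: Callable[[str], dict],
        read_version: Callable[[str], Optional[int]],
        ttl_seconds: float = 60.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._load = load
        self._read_version = read_version
        self._ttl = ttl_seconds
        self._clock = clock
        # key -> (value, version_at_load, loaded_at)
        self._entries: Dict[str, Tuple[dict, Optional[int], float]] = {}

    def get(self, key: str) -> dict:
        now = self._clock()
        cached = self._entries.get(key)
        if cached is not None:
            value, cached_version, loaded_at = cached
            if now - loaded_at < self._ttl:
                return value  # still fresh: no external calls at all
            current = self._read_version(key)
            if current is not None and current == cached_version:
                # Nothing was invalidated since the load: extend the TTL.
                self._entries[key] = (value, cached_version, now)
                return value
        # Snapshot the counter before loading so a write that lands during
        # the load is picked up by the next recheck instead of being masked.
        version_before = self._read_version(key)
        value = self._load(key)
        self._entries[key] = (value, version_before, now)
        return value

    def invalidate_local(self, key: str) -> None:
        """Drop the local entry; bumping the shared counter is the caller's job."""
        self._entries.pop(key, None)
```

When `read_version` returns `None` (the counter store is unreachable), every expired entry is reloaded, so staleness stays bounded by the TTL alone.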


@@ -5,16 +5,9 @@ from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Re-exported here so existing call sites (and tests) that do
# ``from application.core.model_settings import ModelRegistry`` keep
# working. The implementation lives in ``application/core/model_registry.py``.
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
# ``model_yaml`` → ``model_settings`` (this file).
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENAI_COMPATIBLE = "openai_compatible"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
@@ -48,21 +41,11 @@ class AvailableModel:
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
# User-facing label distinct from dispatch provider (e.g. mistral
# routed through openai_compatible).
display_provider: Optional[str] = None
# Sent in the API call's ``model`` field; falls back to ``self.id``
# for built-ins where id IS the upstream name.
upstream_model_id: Optional[str] = None
# "builtin" for catalog YAMLs, "user" for BYOM records.
source: str = "builtin"
# Decrypted/resolved at registry-merge time. Never serialized.
api_key: Optional[str] = field(default=None, repr=False, compare=False)
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.display_provider or self.provider.value,
"provider": self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
@@ -71,21 +54,261 @@ class AvailableModel:
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
"source": self.source,
}
if self.base_url:
result["base_url"] = self.base_url
return result
def __getattr__(name):
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
class ModelRegistry:
_instance = None
_initialized = False
Done lazily to avoid an import cycle: ``model_registry`` imports
``model_yaml`` which imports the dataclasses from this file.
"""
if name == "ModelRegistry":
from application.core.model_registry import ModelRegistry as _MR
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
return _MR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage: add the built-in models when OPENAI_API_KEY is set
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
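
The default-model selection order used by the registry can be summarised in a few lines. This is a minimal sketch under simplifying assumptions: `resolve_default` is a hypothetical free function, and `models` maps a model id to its provider name instead of holding full model objects.

```python
from typing import Dict, Optional


def resolve_default(
    models: Dict[str, str],
    llm_name: Optional[str],
    llm_provider: Optional[str],
    api_key: Optional[str],
) -> Optional[str]:
    """Pick a default model id from an insertion-ordered id -> provider map."""
    if llm_name:
        # 1. First name in the comma-separated list that is registered.
        for name in (part.strip() for part in llm_name.split(",")):
            if name and name in models:
                return name
        # 2. Backward compatibility: the raw, unsplit value.
        if llm_name in models:
            return llm_name
    # 3. First model of the configured provider, if a generic key is set.
    if llm_provider and api_key:
        for model_id, provider in models.items():
            if provider == llm_provider:
                return model_id
    # 4. Otherwise the first registered model, or None when empty.
    return next(iter(models), None)
```

Because step 4 relies on dictionary insertion order, the order in which providers are registered decides the fallback default.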


@@ -1,59 +1,47 @@
from typing import Any, Dict, Optional
from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider.
Delegates to the provider plugin's ``get_api_key``. Falls back to the
generic ``settings.API_KEY`` for unknown providers.
"""
"""Get the appropriate API key for a provider"""
from application.core.settings import settings
from application.llm.providers import PROVIDERS_BY_NAME
plugin = PROVIDERS_BY_NAME.get(provider)
if plugin is not None:
key = plugin.get_api_key(settings)
if key:
return key
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
return settings.API_KEY
def get_all_available_models(
user_id: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response.
When ``user_id`` is supplied, the user's BYOM custom-model records
are merged into the result alongside the built-in catalog.
"""
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
registry = ModelRegistry.get_instance()
return {
model.id: model.to_dict()
for model in registry.get_enabled_models(user_id=user_id)
}
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
"""Check if a model ID exists in registry.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, only built-in catalog ids resolve.
"""
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id, user_id=user_id)
return registry.model_exists(model_id)
def get_model_capabilities(
model_id: str, user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model.
``user_id`` enables resolution of per-user BYOM records.
"""
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
model = registry.get_model(model_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -70,68 +58,36 @@ def get_default_model_id() -> str:
return registry.default_model_id
def get_provider_from_model_id(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the provider name for a given model_id.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, BYOM model ids return ``None`` and the caller falls
back to the deployment default.
"""
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
model = registry.get_model(model_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
"""Get context window (token limit) for a model.
Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
if not found. ``user_id`` enables resolution of per-user BYOM records.
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns the model's context_window, or settings.DEFAULT_LLM_TOKEN_LIMIT if the model is not found.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
model = registry.get_model(model_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the custom base_url for a specific model if configured.
Returns ``None`` if no custom base_url is set. ``user_id`` enables
resolution of per-user BYOM records.
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
model = registry.get_model(model_id)
if model:
return model.base_url
return None
def get_api_key_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Resolve the API key to use when invoking ``model_id``.
Priority:
1. The model record's own ``api_key`` (BYOM records and
``openai_compatible`` YAMLs populate this).
2. The provider plugin's settings-based key.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
if model is not None and model.api_key:
return model.api_key
if model is not None:
return get_api_key_for_provider(model.provider.value)
return None
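
The key-resolution priority in `get_api_key_for_model` can be illustrated with plain data. The sketch below is an assumption-laden simplification: `ModelRecord` and `resolve_api_key` are hypothetical names, and provider keys are passed in as a dictionary instead of being read from settings or provider plugins.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ModelRecord:
    """Hypothetical stand-in for a registry entry."""

    provider: str
    api_key: Optional[str] = None


def resolve_api_key(
    model: Optional[ModelRecord],
    provider_keys: Dict[str, Optional[str]],
    generic_key: Optional[str],
) -> Optional[str]:
    """Return the key to use for a model, or None if the model is unknown."""
    if model is None:
        return None
    # 1. A key stored on the record itself (e.g. a per-user custom model).
    if model.api_key:
        return model.api_key
    # 2. The provider-specific key, then the deployment-wide generic key.
    return provider_keys.get(model.provider) or generic_key
```

Returning `None` for an unknown model, instead of the generic key, keeps an unresolved id from silently picking up the deployment credential.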


@@ -1,358 +0,0 @@
"""YAML loader for model catalog files under ``application/core/models/``.
Each ``*.yaml`` file declares one provider's static model catalog. Files
are validated with Pydantic at load time; any parse, schema, or alias
error aborts startup with the offending file path in the message.
For most providers, one YAML maps to one catalog. The
``openai_compatible`` provider is special: each YAML file represents a
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
``api_key_env`` and ``base_url``. The loader returns a flat list so the
registry can distinguish multiple files with the same ``provider:`` value.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
logger = logging.getLogger(__name__)
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
DEFAULTS_FILENAME = "_defaults.yaml"


class _DefaultsFile(BaseModel):
    """Schema for ``_defaults.yaml``. Currently just attachment aliases."""

    model_config = ConfigDict(extra="forbid")

    attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)


class _CapabilityFields(BaseModel):
    """Capability fields shared between provider ``defaults:`` and per-model overrides.

    All fields are optional so a per-model override can selectively replace
    a single field from the provider-level defaults.
    """

    model_config = ConfigDict(extra="forbid")

    supports_tools: Optional[bool] = None
    supports_structured_output: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    attachments: Optional[List[str]] = None
    context_window: Optional[int] = None
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None


class _ModelEntry(_CapabilityFields):
    """Schema for one model row inside a YAML's ``models:`` list."""

    id: str
    display_name: Optional[str] = None
    description: str = ""
    enabled: bool = True
    base_url: Optional[str] = None
    aliases: List[str] = Field(default_factory=list)

    @field_validator("id")
    @classmethod
    def _id_nonempty(cls, v: str) -> str:
        if not v or not v.strip():
            raise ValueError("model id must be a non-empty string")
        return v


class _ProviderFile(BaseModel):
    """Schema for one ``<provider>.yaml`` catalog file."""

    model_config = ConfigDict(extra="forbid")

    provider: str
    defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
    models: List[_ModelEntry] = Field(default_factory=list)
    # openai_compatible metadata. Optional for other providers.
    display_provider: Optional[str] = None
    api_key_env: Optional[str] = None
    base_url: Optional[str] = None


class ProviderCatalog(BaseModel):
    """One YAML file's parsed contents, ready for the registry.

    For most providers, multiple catalogs with the same ``provider`` get
    merged later by the registry. The ``openai_compatible`` provider is
    the exception: each catalog is treated as a distinct endpoint, with
    its own ``api_key_env`` and ``base_url``.
    """

    provider: str
    models: List[AvailableModel]
    source_path: Optional[Path] = None
    display_provider: Optional[str] = None
    api_key_env: Optional[str] = None
    base_url: Optional[str] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ModelYAMLError(ValueError):
    """Raised when a model YAML fails parsing, schema, or alias validation."""


def _expand_attachments(
    attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
) -> List[str]:
    """Resolve attachment shorthands (``image``, ``pdf``) to MIME types.

    Raw MIME-typed entries (containing ``/``) pass through unchanged.
    Unknown aliases raise ``ModelYAMLError``.
    """
    expanded: List[str] = []
    seen: set = set()
    for entry in attachments:
        if "/" in entry:
            if entry not in seen:
                expanded.append(entry)
                seen.add(entry)
            continue
        if entry not in aliases:
            valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
            raise ModelYAMLError(
                f"{source}: unknown attachment alias '{entry}'. "
                f"Valid aliases: {valid}. "
                "(Or use a raw MIME type like 'image/png'.)"
            )
        for mime in aliases[entry]:
            if mime not in seen:
                expanded.append(mime)
                seen.add(mime)
    return expanded


def _load_defaults(directory: Path) -> Dict[str, List[str]]:
    """Load ``_defaults.yaml`` from ``directory`` if it exists."""
    path = directory / DEFAULTS_FILENAME
    if not path.exists():
        return {}
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
    try:
        parsed = _DefaultsFile.model_validate(raw)
    except Exception as e:
        raise ModelYAMLError(f"{path}: schema error: {e}") from e
    return parsed.attachment_aliases


def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
    try:
        return ModelProvider(name)
    except ValueError as e:
        valid = ", ".join(p.value for p in ModelProvider)
        raise ModelYAMLError(
            f"{source}: unknown provider '{name}'. Valid: {valid}"
        ) from e


def _build_model(
    entry: _ModelEntry,
    defaults: _CapabilityFields,
    provider: ModelProvider,
    aliases: Dict[str, List[str]],
    source: Path,
    display_provider: Optional[str] = None,
) -> AvailableModel:
    """Merge defaults + per-model overrides into a final ``AvailableModel``."""

    def pick(field_name: str, fallback):
        v = getattr(entry, field_name)
        if v is not None:
            return v
        d = getattr(defaults, field_name)
        if d is not None:
            return d
        return fallback

    raw_attachments = entry.attachments
    if raw_attachments is None:
        raw_attachments = defaults.attachments
    if raw_attachments is None:
        raw_attachments = []
    expanded = _expand_attachments(
        raw_attachments, aliases, f"{source} [model={entry.id}]"
    )

    caps = ModelCapabilities(
        supports_tools=pick("supports_tools", False),
        supports_structured_output=pick("supports_structured_output", False),
        supports_streaming=pick("supports_streaming", True),
        supported_attachment_types=expanded,
        context_window=pick("context_window", 128000),
        input_cost_per_token=pick("input_cost_per_token", None),
        output_cost_per_token=pick("output_cost_per_token", None),
    )

    return AvailableModel(
        id=entry.id,
        provider=provider,
        display_name=entry.display_name or entry.id,
        description=entry.description,
        capabilities=caps,
        enabled=entry.enabled,
        base_url=entry.base_url,
        display_provider=display_provider,
    )


def _load_one_yaml(
    path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
    try:
        parsed = _ProviderFile.model_validate(raw)
    except Exception as e:
        raise ModelYAMLError(f"{path}: schema error: {e}") from e

    provider_enum = _resolve_provider_enum(parsed.provider, path)
    models = [
        _build_model(
            entry,
            parsed.defaults,
            provider_enum,
            aliases,
            path,
            display_provider=parsed.display_provider,
        )
        for entry in parsed.models
    ]
    return ProviderCatalog(
        provider=parsed.provider,
        models=models,
        source_path=path,
        display_provider=parsed.display_provider,
        api_key_env=parsed.api_key_env,
        base_url=parsed.base_url,
    )


_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None


def builtin_attachment_aliases() -> Dict[str, List[str]]:
    """Return the built-in attachment alias map from ``_defaults.yaml``.

    Cached after first read so repeat calls are cheap.
    """
    global _BUILTIN_ALIASES_CACHE
    if _BUILTIN_ALIASES_CACHE is None:
        _BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
    return _BUILTIN_ALIASES_CACHE


def resolve_attachment_alias(alias: str) -> List[str]:
    """Resolve a single attachment alias (e.g. ``"image"``) to its
    canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
    """
    aliases = builtin_attachment_aliases()
    if alias not in aliases:
        valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
        raise ModelYAMLError(
            f"Unknown attachment alias '{alias}'. Valid: {valid}"
        )
    return list(aliases[alias])


def expand_attachments_lenient(
    attachments: Sequence[str], source: str
) -> List[str]:
    """Expand attachment aliases to MIME types, tolerating unknowns.

    Mirrors ``_expand_attachments`` but logs+skips unknown aliases
    rather than raising. Used for runtime call sites (BYOM registry
    load) where an operator-side alias-map edit must not drop the
    entire user's BYOM layer; the strict raise still happens at the
    API validation boundary.
    """
    aliases = builtin_attachment_aliases()
    expanded: List[str] = []
    seen: set = set()
    for entry in attachments:
        if "/" in entry:
            if entry not in seen:
                expanded.append(entry)
                seen.add(entry)
            continue
        mime_list = aliases.get(entry)
        if mime_list is None:
            logger.warning(
                "%s: skipping unknown attachment alias %r", source, entry,
            )
            continue
        for mime in mime_list:
            if mime not in seen:
                expanded.append(mime)
                seen.add(mime)
    return expanded


def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
    """Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
    directory in order and return a flat list of catalogs.

    Caller is responsible for merging multiple catalogs that target the
    same provider plugin. The flat-list shape lets ``openai_compatible``
    keep each file separate (one logical endpoint per file).

    When the same model ``id`` appears in more than one YAML across the
    directory list, a warning is logged. Order in the returned list
    preserves load order, so the registry's "later wins" merge gives the
    later directory's definition.
    """
    catalogs: List[ProviderCatalog] = []
    seen_ids: Dict[str, Path] = {}

    aliases: Dict[str, List[str]] = {}
    for d in directories:
        if not d or not d.exists():
            continue
        aliases.update(_load_defaults(d))

    for d in directories:
        if not d or not d.exists():
            continue
        for path in sorted(d.glob("*.yaml")):
            if path.name == DEFAULTS_FILENAME:
                continue
            catalog = _load_one_yaml(path, aliases)
            catalogs.append(catalog)
            for m in catalog.models:
                prior = seen_ids.get(m.id)
                if prior is not None and prior != path:
                    logger.warning(
                        "Model id %r redefined: %s overrides %s (later wins)",
                        m.id,
                        path,
                        prior,
                    )
                seen_ids[m.id] = path
    return catalogs

View File

@@ -1,213 +0,0 @@
# Model catalogs
Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.
To add or edit models, you almost always only touch a YAML here — no
Python code required.
## Add a model to an existing provider
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:
```yaml
models:
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
```
Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:
```yaml
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
    context_window: 500000
```
Restart the app. The new model appears in `/api/models`.
> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.
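
The precedence described above (per-model value, then the provider's `defaults:` block, then a built-in fallback) can be sketched in a few lines. This is a simplified illustration, not the loader itself: `resolve_capabilities` and `FALLBACKS` are hypothetical names, and the fallback values are copied from the "Schema reference" section below.

```python
from typing import Any, Dict, Optional

# Fallback values as listed in the "Schema reference" section of this README.
# The dict name and the function below are illustrative, not part of the app.
FALLBACKS: Dict[str, Any] = {
    "supports_tools": False,
    "supports_structured_output": False,
    "supports_streaming": True,
    "context_window": 128000,
}


def resolve_capabilities(
    model: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Per-model value wins, then the provider default, then the fallback."""
    defaults = defaults or {}
    resolved: Dict[str, Any] = {}
    for field, fallback in FALLBACKS.items():
        if model.get(field) is not None:
            resolved[field] = model[field]
        elif defaults.get(field) is not None:
            resolved[field] = defaults[field]
        else:
            resolved[field] = fallback
    return resolved


provider_defaults = {"supports_tools": True, "context_window": 200000}
override = {"id": "claude-3-7-sonnet", "context_window": 500000}
caps = resolve_capabilities(override, provider_defaults)
print(caps["context_window"], caps["supports_tools"])
```

Here the model row only overrides `context_window`; `supports_tools` still comes from the provider defaults, and the remaining fields fall through to the fallbacks.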
## Add an OpenAI-compatible provider (zero Python)
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:
```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral          # shown in /api/models response
api_key_env: MISTRAL_API_KEY       # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
  supports_tools: true
  context_window: 128000
models:
  - id: mistral-large-latest
    display_name: Mistral Large
  - id: mistral-small-latest
    display_name: Mistral Small
```
Set `MISTRAL_API_KEY=sk-...` and restart. The Mistral models then appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.
Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
isn't set, that catalog is skipped at boot with an INFO-level log line
rather than an error.
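
As a rough illustration of the skip-when-unset behaviour, the sketch below filters a list of catalog dicts by whether their `api_key_env` variable has a value. The function name `usable_catalogs`, the dict shape, `TOGETHER_API_KEY`, and the `example.invalid` URL are assumptions made for the example; the real registry code is not shown in this README.

```python
import logging
import os
from typing import Dict, List, Mapping

logger = logging.getLogger("catalog_sketch")


def usable_catalogs(
    catalogs: List[Dict[str, str]], env: Mapping[str, str] = os.environ
) -> List[Dict[str, str]]:
    """Keep only catalogs whose api_key_env variable is set and non-empty."""
    kept: List[Dict[str, str]] = []
    for catalog in catalogs:
        key_name = catalog["api_key_env"]
        if not env.get(key_name):
            # Not an error: the endpoint is simply left out of the registry.
            logger.info(
                "skipping %s: %s is not set", catalog["display_provider"], key_name
            )
            continue
        kept.append(catalog)
    return kept


catalogs = [
    {
        "display_provider": "mistral",
        "api_key_env": "MISTRAL_API_KEY",
        "base_url": "https://api.mistral.ai/v1",
    },
    {
        # Hypothetical second endpoint; env var name and URL are placeholders.
        "display_provider": "together",
        "api_key_env": "TOGETHER_API_KEY",
        "base_url": "https://example.invalid/v1",
    },
]
active = usable_catalogs(catalogs, env={"MISTRAL_API_KEY": "sk-placeholder"})
print([c["display_provider"] for c in active])
```

Passing `env` explicitly keeps the sketch testable; by default it reads the process environment.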
Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.
## Add a provider with its own SDK
For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:
```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM


class MyProvider(Provider):
    name = "my_provider"
    llm_class = MyLLM

    def get_api_key(self, settings):
        return settings.MY_PROVIDER_API_KEY
```
Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
## Schema reference
```yaml
provider: <string, required>            # matches the Provider plugin's `name`

# openai_compatible only: required for that provider, ignored for others
display_provider: <string>              # label shown in /api/models response
api_key_env: <string>                   # name of the env var carrying the key
base_url: <string>                      # endpoint URL

defaults:                               # optional, applied to every model below
  supports_tools: bool                  # default false
  supports_structured_output: bool      # default false
  supports_streaming: bool              # default true
  attachments: [<alias-or-mime>, ...]   # default []
  context_window: int                   # default 128000
  input_cost_per_token: float           # default null
  output_cost_per_token: float          # default null

models:                                 # required
  - id: <string, required>              # the value persisted in agent records
    display_name: <string>              # default: id
    description: <string>               # default: ""
    enabled: bool                       # default true; false hides from /api/models
    base_url: <string>                  # optional custom endpoint for this model
    # All `defaults:` fields above can be overridden here per-model.
```
### Attachment aliases
The `attachments:` list can mix human-readable aliases with raw MIME
types. Aliases are defined in `_defaults.yaml`:
| Alias | Expands to |
|---|---|
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
| `pdf` | `application/pdf` |
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
Use raw MIME types when you need surgical control:
```yaml
attachments: [image/png, image/webp] # only these two
```
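
A minimal sketch of the expansion rule, assuming the alias table above: anything containing `/` is passed through as a raw MIME type, aliases expand to their list, and duplicates are dropped while preserving order. `ALIASES` and `expand` are illustrative names; unlike the real loader, this version raises a plain `KeyError` on an unknown alias.

```python
from typing import Dict, List, Sequence

# Alias table copied from the README table above.
ALIASES: Dict[str, List[str]] = {
    "image": ["image/png", "image/jpeg", "image/jpg", "image/webp", "image/gif"],
    "pdf": ["application/pdf"],
    "audio": ["audio/mpeg", "audio/wav", "audio/ogg"],
}


def expand(attachments: Sequence[str]) -> List[str]:
    """Expand aliases to MIME types, keeping first-seen order without duplicates."""
    expanded: List[str] = []
    for entry in attachments:
        # Entries containing "/" are treated as raw MIME types.
        mimes = [entry] if "/" in entry else ALIASES[entry]
        for mime in mimes:
            if mime not in expanded:
                expanded.append(mime)
    return expanded


mixed = expand(["pdf", "image/png", "image"])
print(mixed)
```

Mixing `image/png` with `image` does not produce a duplicate: the raw entry is recorded first and the alias only contributes the MIME types not already present.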
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
path. Every `*.yaml` in that directory is loaded **after** the built-in
catalog under `application/core/models/`. Operators use this to:
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
models under `provider: anthropic` and they show up alongside the
built-ins.
- Override a built-in model's capabilities — declare the same `id`
with different fields (e.g. a higher `context_window`). Later wins;
the override is logged as a `WARNING` so you can audit it.
Things you cannot do via `MODELS_CONFIG_DIR`:
- Add a brand-new non-OpenAI provider — that needs a Python plugin
under `application/llm/providers/` (see "Add a provider with its own
SDK" above). Operator YAMLs may only target a `provider:` value that
already has a registered plugin.
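
The "later wins" override can be pictured as a dictionary keyed by model `id` that is filled in load order. The sketch below is an assumption about the shape of that merge, not the registry's actual code: `merge_later_wins` is a hypothetical helper, it replaces the whole model entry rather than merging field by field, and the `300000` override value is made up.

```python
import logging
from typing import Dict, List, Tuple

logger = logging.getLogger("merge_sketch")


def merge_later_wins(layers: List[Tuple[str, List[dict]]]) -> Dict[str, dict]:
    """Merge (source, models) layers in order; a repeated id replaces the earlier one."""
    merged: Dict[str, dict] = {}
    origin: Dict[str, str] = {}
    for source, models in layers:
        for model in models:
            model_id = model["id"]
            if model_id in origin and origin[model_id] != source:
                # Overrides are allowed but logged so operators can audit them.
                logger.warning(
                    "Model id %r redefined: %s overrides %s (later wins)",
                    model_id,
                    source,
                    origin[model_id],
                )
            merged[model_id] = model
            origin[model_id] = source
    return merged


builtin = ("builtin/anthropic.yaml", [
    {"id": "claude-haiku-4-5", "context_window": 200000},
    {"id": "claude-opus-4-7", "context_window": 1000000},
])
operator = ("operator/anthropic.yaml", [
    {"id": "claude-haiku-4-5", "context_window": 300000},  # hypothetical override
])
merged = merge_later_wins([builtin, operator])
print(merged["claude-haiku-4-5"]["context_window"])
```

Because the built-in directory is loaded first and `MODELS_CONFIG_DIR` second, the operator's entry ends up in the merged result.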
### Example: Docker
Mount your model YAMLs into the container and point the env var at the
mount path:
```yaml
# docker-compose.yml
services:
  app:
    image: arc53/docsgpt
    environment:
      MODELS_CONFIG_DIR: /etc/docsgpt/models
      MISTRAL_API_KEY: ${MISTRAL_API_KEY}
    volumes:
      - ./my-models:/etc/docsgpt/models:ro
```
Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.
### Example: Kubernetes
Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start, so a stale or
mistyped path never takes down the service, but the warning is
loud enough to surface in any reasonable log aggregator.
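
One way to express this warn-and-continue behaviour is to build the directory list defensively before handing it to the loader. The helper below is a sketch under that assumption; `catalog_directories` is a hypothetical name and the temporary directories stand in for the real built-in and operator paths.

```python
import logging
import tempfile
from pathlib import Path
from typing import List, Optional

logger = logging.getLogger("models_dir_sketch")


def catalog_directories(builtin: Path, configured: Optional[str]) -> List[Path]:
    """Return the built-in directory, plus the configured one if it is usable."""
    directories = [builtin]
    if not configured:
        return directories
    extra = Path(configured)
    if extra.is_dir():
        directories.append(extra)
    else:
        # Misconfiguration is reported but does not abort startup.
        logger.warning(
            "MODELS_CONFIG_DIR=%s is not a directory; using built-in catalog only",
            configured,
        )
    return directories


with tempfile.TemporaryDirectory() as tmp:
    builtin = Path(tmp) / "builtin"
    operator = Path(tmp) / "operator"
    builtin.mkdir()
    operator.mkdir()
    dirs_ok = catalog_directories(builtin, str(operator))
    dirs_missing = catalog_directories(builtin, str(Path(tmp) / "does-not-exist"))

print(len(dirs_ok), len(dirs_missing))
```

The built-in directory always comes first in the list, which matches the load order described in the previous section.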
## Validation
YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:
- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin
This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.
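
To make the failure modes concrete, here is a dependency-free approximation of two of the checks (unknown top-level key, missing model `id`). The real loader relies on Pydantic models with `extra="forbid"`; `validation_errors` and `ALLOWED_TOP_LEVEL` are hypothetical names used only for this sketch, and the allowed keys are those listed in the schema reference above.

```python
from typing import Any, Dict, List

# Top-level keys accepted by a provider YAML, per the schema reference above.
ALLOWED_TOP_LEVEL = {
    "provider",
    "defaults",
    "models",
    "display_provider",
    "api_key_env",
    "base_url",
}


def validation_errors(raw: Dict[str, Any], source: str) -> List[str]:
    """Collect human-readable problems; an empty list means the file passes."""
    errors: List[str] = []
    for key in sorted(set(raw) - ALLOWED_TOP_LEVEL):
        errors.append(f"{source}: unknown top-level key '{key}'")
    if "provider" not in raw:
        errors.append(f"{source}: missing required key 'provider'")
    for index, model in enumerate(raw.get("models") or []):
        model_id = model.get("id") or ""
        if not str(model_id).strip():
            errors.append(f"{source}: models[{index}] is missing 'id'")
    return errors


good = {"provider": "anthropic", "models": [{"id": "claude-haiku-4-5"}]}
bad = {
    "provider": "anthropic",
    "modles": [],  # typo on purpose: rejected as an unknown key
    "models": [{"display_name": "No id here"}],
}
print(validation_errors(good, "good.yaml"))
for line in validation_errors(bad, "bad.yaml"):
    print(line)
```

A startup path built on a check like this would raise as soon as the list is non-empty, including the file name in the message so the operator knows which YAML to fix.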
## Reserved fields (not yet implemented)
- `aliases:` on a model — old IDs that resolve to this model. Reserved
for future renames; the schema accepts the field but it is not yet
acted on.

View File

@@ -1,18 +0,0 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".
attachment_aliases:
  image:
    - image/png
    - image/jpeg
    - image/jpg
    - image/webp
    - image/gif
  pdf:
    - application/pdf
  audio:
    - audio/mpeg
    - audio/wav
    - audio/ogg

View File

@@ -1,23 +0,0 @@
provider: anthropic
defaults:
  supports_tools: true
  attachments: [image]
  context_window: 200000
models:
  - id: claude-opus-4-7
    display_name: Claude Opus 4.7
    description: Most capable Claude model for complex reasoning and agentic coding
    context_window: 1000000
    supports_structured_output: true
  - id: claude-sonnet-4-6
    display_name: Claude Sonnet 4.6
    description: Best balance of speed and intelligence with extended thinking
    context_window: 1000000
    supports_structured_output: true
  - id: claude-haiku-4-5
    display_name: Claude Haiku 4.5
    description: Fastest Claude model with near-frontier intelligence
    supports_structured_output: true

View File

@@ -1,31 +0,0 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [image]
  context_window: 400000
models:
  - id: azure-gpt-5.5
    display_name: Azure OpenAI GPT-5.5
    description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
    context_window: 1050000
  - id: azure-gpt-5.4-mini
    display_name: Azure OpenAI GPT-5.4 Mini
    description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
  - id: azure-gpt-5.4-nano
    display_name: Azure OpenAI GPT-5.4 Nano
    description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -1,7 +0,0 @@
provider: docsgpt
models:
  - id: docsgpt-local
    display_name: DocsGPT Model
    description: Local model
    supports_tools: false

View File

@@ -1,31 +0,0 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).
provider: openai_compatible
display_provider: mistral              # shown in /api/models response
api_key_env: MISTRAL_API_KEY           # env var the plugin reads
base_url: https://api.mistral.ai/v1    # OpenAI-compatible endpoint
defaults:
  supports_tools: true
  context_window: 128000
models:
  - id: mistral-large-latest
    display_name: Mistral Large
    description: Top-tier reasoning model
  - id: mistral-small-latest
    display_name: Mistral Small
    description: Fast, cost-efficient
  - id: codestral-latest
    display_name: Codestral
    description: Code-specialized model

View File

@@ -1,17 +0,0 @@
provider: google
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [pdf, image]
  context_window: 1048576
models:
  - id: gemini-3.1-pro-preview
    display_name: Gemini 3.1 Pro
    description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
  - id: gemini-3-flash-preview
    display_name: Gemini 3 Flash
    description: Frontier-class performance for low-latency, high-volume tasks (preview)
  - id: gemini-3.1-flash-lite-preview
    display_name: Gemini 3.1 Flash-Lite
    description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)

View File

@@ -1,16 +0,0 @@
provider: groq
defaults:
  supports_tools: true
  context_window: 131072
models:
  - id: openai/gpt-oss-120b
    display_name: GPT-OSS 120B
    description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
    supports_structured_output: true
  - id: llama-3.3-70b-versatile
    display_name: Llama 3.3 70B Versatile
    description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
  - id: llama-3.1-8b-instant
    display_name: Llama 3.1 8B Instant
    description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use

View File

@@ -1,7 +0,0 @@
provider: huggingface
models:
  - id: huggingface-local
    display_name: Hugging Face Model
    description: Local Hugging Face model
    supports_tools: false

View File

@@ -1,21 +0,0 @@
provider: novita
defaults:
  supports_tools: true
  supports_structured_output: true
models:
  - id: deepseek/deepseek-v4-pro
    display_name: DeepSeek V4 Pro
    description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
    context_window: 1048576
  - id: moonshotai/kimi-k2.6
    display_name: Kimi K2.6
    description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
    attachments: [image]
    context_window: 262144
  - id: zai-org/glm-5
    display_name: GLM-5
    description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
    context_window: 202800

Some files were not shown because too many files have changed in this diff