Compare commits

..

1 Commits

Author SHA1 Message Date
Alex
e0a8cc178b feat: BYOM 2026-04-27 21:50:45 +01:00
263 changed files with 2405 additions and 30550 deletions

View File

@@ -47,13 +47,11 @@
</ul>
## Roadmap
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
- [x] SharePoint & Confluence connectors ( March April 2026 )
- [x] Research mode ( March 2026 )
- [x] Postgres migration for user data ( April 2026 )
- [x] OpenTelemetry observability ( April 2026 )
- [x] Bring Your Own Model (BYOM) ( April 2026 )
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full api tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!

View File

@@ -1,107 +1,18 @@
import logging
import uuid
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.security.encryption import decrypt_credentials
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.db.session import db_readonly
logger = logging.getLogger(__name__)
def _record_proposed(
call_id: str,
tool_name: str,
action_name: str,
arguments: Any,
*,
tool_id: Optional[str] = None,
) -> bool:
"""Insert a ``proposed`` row; swallow infra failures so tool calls
still run when the journal is unreachable. Returns True iff the row
is now journaled (newly created or already present).
"""
try:
with db_session() as conn:
inserted = ToolCallAttemptsRepository(conn).record_proposed(
call_id,
tool_name,
action_name,
arguments,
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
)
if not inserted:
logger.warning(
"tool_call_attempts duplicate call_id=%s; existing row left in place",
call_id,
extra={"alert": "tool_call_id_collision", "call_id": call_id},
)
return True
except Exception:
logger.exception("tool_call_attempts proposed write failed for %s", call_id)
return False
def _mark_executed(
call_id: str,
result: Any,
*,
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
proposed_ok: bool = True,
tool_name: Optional[str] = None,
action_name: Optional[str] = None,
arguments: Any = None,
tool_id: Optional[str] = None,
) -> None:
"""Flip the row to ``executed``. If ``proposed_ok`` is False (the
proposed write failed earlier), upsert a fresh row in ``executed`` so
the reconciler can still see the attempt — without this, the side
effect would be invisible to the journal.
"""
try:
with db_session() as conn:
repo = ToolCallAttemptsRepository(conn)
if proposed_ok:
updated = repo.mark_executed(
call_id,
result,
message_id=message_id,
artifact_id=artifact_id,
)
if updated:
return
# Fallback synthesizes the row so the journal isn't lost.
repo.upsert_executed(
call_id,
tool_name=tool_name or "unknown",
action_name=action_name or "",
arguments=arguments if arguments is not None else {},
result=result,
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
message_id=message_id,
artifact_id=artifact_id,
)
except Exception:
logger.exception("tool_call_attempts executed write failed for %s", call_id)
def _mark_failed(call_id: str, error: str) -> None:
try:
with db_session() as conn:
ToolCallAttemptsRepository(conn).mark_failed(call_id, error)
except Exception:
logger.exception("tool_call_attempts failed-write failed for %s", call_id)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
@@ -120,7 +31,6 @@ class ToolExecutor:
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.message_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
@@ -364,14 +274,7 @@ class ToolExecutor:
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(
"tool_call_parse_failed",
extra={
"llm_class_name": llm_class_name,
"llm_tool_name": llm_name,
"call_id": call_id,
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -386,15 +289,7 @@ class ToolExecutor:
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(
"tool_id_not_found",
extra={
"tool_id": tool_id,
"llm_tool_name": llm_name,
"call_id": call_id,
"available_tool_count": len(tools_dict),
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -413,36 +308,9 @@ class ToolExecutor:
"action_name": llm_name,
"arguments": call_args,
}
tool_data = tools_dict[tool_id]
# Journal first so the reconciler sees malformed calls and any
# subsequent ``_mark_failed`` actually updates a real row.
proposed_ok = _record_proposed(
call_id,
tool_data["name"],
action_name,
call_args if isinstance(call_args, dict) else {},
tool_id=tool_data.get("id"),
)
# Defensive guard: a non-dict ``call_args`` (e.g. malformed
# JSON on the resume path) would crash the param walk below
# with AttributeError on ``.items()``. Surface a clean error
# event and flip the journal row to ``failed`` instead of
# killing the stream.
if not isinstance(call_args, dict):
error_message = (
f"Tool call arguments must be a JSON object, got "
f"{type(call_args).__name__}."
)
tool_call_data["result"] = error_message
tool_call_data["arguments"] = {}
_mark_failed(call_id, error_message)
yield {
"type": "tool_call",
"data": {**tool_call_data, "status": "error"},
}
self.tool_calls.append(tool_call_data)
return error_message, call_id
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
@@ -488,17 +356,8 @@ class ToolExecutor:
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(
"tool_load_failed",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
"call_id": call_id,
},
)
logger.error(error_message)
tool_call_data["result"] = error_message
_mark_failed(call_id, error_message)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return error_message, call_id
@@ -508,18 +367,14 @@ class ToolExecutor:
if tool_data["name"] == "api_tool"
else parameters
)
try:
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
except Exception as exc:
_mark_failed(call_id, str(exc))
raise
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
@@ -548,22 +403,6 @@ class ToolExecutor:
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
# Tool side effect has run; flip the journal row so the
# message-finalize path can later confirm it. If the proposed
# write failed (DB outage), upsert a fresh row in ``executed`` so
# the reconciler still sees the side effect.
_mark_executed(
call_id,
result_full,
message_id=self.message_id,
artifact_id=artifact_id or None,
proposed_ok=proposed_ok,
tool_name=tool_data["name"],
action_name=action_name,
arguments=call_args,
tool_id=tool_data.get("id"),
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
@@ -612,12 +451,10 @@ class ToolExecutor:
row_id = tool_data.get("id")
if not row_id:
logger.error(
"tool_missing_row_id",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
},
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
"skipping load to avoid binding a non-UUID downstream.",
tool_data.get("name"),
tool_id,
)
return None
tool_config["tool_id"] = str(row_id)

View File

@@ -20,11 +20,10 @@ from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_task
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.events.keys import stream_key
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -77,12 +76,6 @@ class MCPTool(Tool):
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
# Pulled out of ``config`` (rather than left in ``self.config``)
# because it is a callable supplied by the OAuth worker — not
# something the rest of the tool plumbing should marshal or
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
# so the SSE envelope can carry ``authorization_url``.
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
self.available_tools = []
self._cache_key = self._generate_cache_key()
@@ -174,7 +167,6 @@ class MCPTool(Tool):
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
user_id=self.user_id,
redirect_publish=self.oauth_redirect_publish,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
@@ -687,17 +679,12 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
redirect_publish=None,
):
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
# Worker-supplied callback. Invoked from ``redirect_handler``
# once the authorization URL is known so the SSE envelope can
# carry it. ``None`` for any non-worker entrypoint.
self.redirect_publish = redirect_publish
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -757,19 +744,17 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
if self.redirect_publish is not None:
# Best-effort: a publish failure must not abort the OAuth
# handshake — the user can still authorize via the popup
# opened from the legacy polling fallback if the SSE
# envelope is lost.
try:
self.redirect_publish(auth_url)
except Exception:
logger.warning(
"redirect_publish callback raised for task_id=%s",
self.task_id,
exc_info=True,
)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
@@ -779,6 +764,17 @@ class DocsGPTOAuth(OAuthClientProvider):
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
@@ -793,6 +789,14 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
@@ -1034,73 +1038,8 @@ class MCPOAuthManager:
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
Mirrors the legacy polling contract: ``status`` derived from the
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
(e.g. ``tools``/``tools_count`` on ``completed``).
"""
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
if not user_id:
return {"status": "not_found", "message": "User not provided"}
if self.redis_client is None:
return {"status": "not_found", "message": "Redis unavailable"}
try:
# OAuth flows are short-lived but a concurrent source
# ingest can flood the user channel between the OAuth
# popup completing and the user clicking Save, pushing the
# completion envelope outside the read window. Bound the
# scan by the configured stream cap so we cover the full
# journal — XADD MAXLEN keeps that bounded too.
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
entries = self.redis_client.xrevrange(
stream_key(user_id), count=scan_count
)
except Exception:
logger.exception(
"xrevrange failed for oauth status: user_id=%s task_id=%s",
user_id,
task_id,
)
return {"status": "not_found", "message": "Status unavailable"}
for _entry_id, fields in entries:
if not isinstance(fields, dict):
continue
# decode_responses=False ⇒ bytes keys; the string-key fallback
# covers a future flip of that default without a forced refactor.
event_raw = fields.get(b"event")
if event_raw is None:
event_raw = fields.get("event")
if event_raw is None:
continue
if isinstance(event_raw, bytes):
try:
event_raw = event_raw.decode("utf-8")
except Exception:
continue
try:
envelope = json.loads(event_raw)
except Exception:
continue
if not isinstance(envelope, dict):
continue
event_type = envelope.get("type", "")
if not isinstance(event_type, str) or not event_type.startswith(
"mcp.oauth."
):
continue
scope = envelope.get("scope") or {}
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
continue
payload = envelope.get("payload") or {}
return {
"status": event_type[len("mcp.oauth."):],
"task_id": task_id,
**payload,
}
return {"status": "not_found", "message": "Status not found"}
return mcp_oauth_status_task(task_id)

View File

@@ -177,4 +177,3 @@ class PostgresTool(Tool):
"order": 1,
},
}

View File

@@ -57,29 +57,6 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
# Gemini's SDK natively returns ``args`` as a dict, but the
# resume path (``gen_continuation``) stringifies it for the
# assistant message. Coerce a JSON string back into a dict;
# fall back to an empty dict on malformed input so downstream
# ``call_args.items()`` doesn't crash the stream.
if isinstance(call_args, str):
try:
call_args = json.loads(call_args)
except (json.JSONDecodeError, TypeError):
logger.warning(
"Google call.arguments was not valid JSON; "
"falling back to empty args for %s",
getattr(call, "name", "<unknown>"),
)
call_args = {}
if not isinstance(call_args, dict):
logger.warning(
"Google call.arguments has unexpected type %s; "
"falling back to empty args for %s",
type(call_args).__name__,
getattr(call, "name", "<unknown>"),
)
call_args = {}
resolved = self._resolve_via_mapping(call.name)
if resolved:

View File

@@ -1,4 +1,4 @@
"""0001 initial schema — consolidated baseline for user-data tables.
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:

View File

@@ -1,217 +0,0 @@
"""0004 durability foundation — idempotency, tool-call log, ingest checkpoint.
Adds ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts``,
``ingest_chunk_progress``, and per-row status flags on
``conversation_messages`` and ``pending_tool_state``. Also adds
``token_usage.source`` and ``token_usage.request_id`` so per-channel
cost attribution (``agent_stream`` / ``title`` / ``compression`` /
``rag_condense`` / ``fallback``) is queryable and multi-call agent runs
can be DISTINCT-collapsed into a single user request for rate limiting.
Revision ID: 0004_durability_foundation
Revises: 0003_user_custom_models
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0004_durability_foundation"
down_revision: Union[str, None] = "0003_user_custom_models"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ------------------------------------------------------------------
# New tables
# ------------------------------------------------------------------
# ``attempt_count`` bounds the per-Celery-task idempotency wrapper's
# retry loop so a poison message can't run forever; default 0 means
# existing rows behave as if no attempts have run yet.
op.execute(
"""
CREATE TABLE task_dedup (
idempotency_key TEXT PRIMARY KEY,
task_name TEXT NOT NULL,
task_id TEXT NOT NULL,
result_json JSONB,
status TEXT NOT NULL
CHECK (status IN ('pending', 'completed', 'failed')),
attempt_count INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE webhook_dedup (
idempotency_key TEXT PRIMARY KEY,
agent_id UUID NOT NULL,
task_id TEXT NOT NULL,
response_json JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
# FK on ``message_id`` uses ``ON DELETE SET NULL`` so the journal row
# survives parent-message deletion (compliance / cost-attribution).
op.execute(
"""
CREATE TABLE tool_call_attempts (
call_id TEXT PRIMARY KEY,
message_id UUID
REFERENCES conversation_messages (id)
ON DELETE SET NULL,
tool_id UUID,
tool_name TEXT NOT NULL,
action_name TEXT NOT NULL,
arguments JSONB NOT NULL,
result JSONB,
error TEXT,
status TEXT NOT NULL
CHECK (status IN (
'proposed', 'executed', 'confirmed',
'compensated', 'failed'
)),
attempted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE ingest_chunk_progress (
source_id UUID PRIMARY KEY,
total_chunks INT NOT NULL,
embedded_chunks INT NOT NULL DEFAULT 0,
last_index INT NOT NULL DEFAULT -1,
last_updated TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
# ------------------------------------------------------------------
# Column additions on existing tables
# ------------------------------------------------------------------
# DEFAULT 'complete' backfills existing rows — they're already done.
op.execute(
"""
ALTER TABLE conversation_messages
ADD COLUMN status TEXT NOT NULL DEFAULT 'complete'
CHECK (status IN ('pending', 'streaming', 'complete', 'failed')),
ADD COLUMN request_id TEXT;
"""
)
op.execute(
"""
ALTER TABLE pending_tool_state
ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'
CHECK (status IN ('pending', 'resuming')),
ADD COLUMN resumed_at TIMESTAMPTZ;
"""
)
# Default ``agent_stream`` backfills historical rows under the
# assumption they were written from the primary path — pre-fix the
# only path that wrote was the error branch reading agent.llm.
# ``request_id`` is the stream-scoped UUID stamped by the route on
# ``agent.llm`` so multi-tool agent runs (which produce N rows)
# collapse to one request via DISTINCT in ``count_in_range``.
# Side-channel sources (``title`` / ``compression`` / ``rag_condense``
# / ``fallback``) leave it NULL and are excluded from the request
# count by source filter.
op.execute(
"""
ALTER TABLE token_usage
ADD COLUMN source TEXT NOT NULL DEFAULT 'agent_stream',
ADD COLUMN request_id TEXT;
"""
)
# ------------------------------------------------------------------
# Indexes — partial where the predicate selects only non-terminal rows
# ------------------------------------------------------------------
op.execute(
"CREATE INDEX conversation_messages_pending_ts_idx "
"ON conversation_messages (timestamp) "
"WHERE status IN ('pending', 'streaming');"
)
op.execute(
"CREATE INDEX tool_call_attempts_pending_ts_idx "
"ON tool_call_attempts (attempted_at) "
"WHERE status IN ('proposed', 'executed');"
)
op.execute(
"CREATE INDEX tool_call_attempts_message_idx "
"ON tool_call_attempts (message_id) "
"WHERE message_id IS NOT NULL;"
)
op.execute(
"CREATE INDEX pending_tool_state_resuming_ts_idx "
"ON pending_tool_state (resumed_at) "
"WHERE status = 'resuming';"
)
op.execute(
"CREATE INDEX webhook_dedup_agent_idx "
"ON webhook_dedup (agent_id);"
)
op.execute(
"CREATE INDEX task_dedup_pending_attempts_idx "
"ON task_dedup (attempt_count) WHERE status = 'pending';"
)
# Cost-attribution dashboards filter ``token_usage`` by
# ``(timestamp, source)``; index the same shape so they stay cheap.
op.execute(
"CREATE INDEX token_usage_source_ts_idx "
"ON token_usage (source, timestamp);"
)
# Partial index — only rows with a stamped request_id participate
# in the DISTINCT count. NULL rows fall through to the COUNT(*)
# branch in the repository query.
op.execute(
"CREATE INDEX token_usage_request_id_idx "
"ON token_usage (request_id) "
"WHERE request_id IS NOT NULL;"
)
op.execute(
"CREATE TRIGGER tool_call_attempts_set_updated_at "
"BEFORE UPDATE ON tool_call_attempts "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
# CASCADE so the downgrade stays safe if later migrations FK into these.
for table in (
"ingest_chunk_progress",
"tool_call_attempts",
"webhook_dedup",
"task_dedup",
):
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
op.execute(
"ALTER TABLE conversation_messages "
"DROP COLUMN IF EXISTS request_id, "
"DROP COLUMN IF EXISTS status;"
)
op.execute(
"ALTER TABLE pending_tool_state "
"DROP COLUMN IF EXISTS resumed_at, "
"DROP COLUMN IF EXISTS status;"
)
op.execute("DROP INDEX IF EXISTS token_usage_request_id_idx;")
op.execute("DROP INDEX IF EXISTS token_usage_source_ts_idx;")
op.execute(
"ALTER TABLE token_usage "
"DROP COLUMN IF EXISTS request_id, "
"DROP COLUMN IF EXISTS source;"
)

View File

@@ -1,44 +0,0 @@
"""0005 ingest_chunk_progress.attempt_id — per-attempt resume scoping.
Without this column, a completed checkpoint row poisoned every later
embed call on the same ``source_id``: a sync after an upload finished
read the upload's terminal ``last_index`` and either embedded zero
chunks (if new ``total_docs <= last_index + 1``) or stacked new chunks
on top of the old vectors (if ``total_docs > last_index + 1``).
``attempt_id`` is stamped from ``self.request.id`` (Celery's stable
task id, which survives ``acks_late`` retries of the same task but
differs across separate task invocations). The repository's
``init_progress`` upsert resets ``last_index`` / ``embedded_chunks``
when the incoming ``attempt_id`` differs from the stored one — so a
fresh sync starts from chunk 0 while a retry of the same task resumes
from the last checkpointed chunk.
Revision ID: 0005_ingest_attempt_id
Revises: 0004_durability_foundation
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0005_ingest_attempt_id"
down_revision: Union[str, None] = "0004_durability_foundation"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE ingest_chunk_progress
ADD COLUMN attempt_id TEXT;
"""
)
def downgrade() -> None:
op.execute(
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS attempt_id;"
)

View File

@@ -1,57 +0,0 @@
"""0006 task_dedup lease columns — running-lease for in-flight tasks.
Without these, ``with_idempotency`` only short-circuits *completed*
rows. A late-ack redelivery (Redis ``visibility_timeout`` exceeded by a
long ingest, or a hung-but-alive worker) hands the same message to a
second worker; ``_claim_or_bump`` only bumped the attempt counter and
both workers ran the task body in parallel — duplicate vector writes,
duplicate token spend, duplicate webhook side effects.
``lease_owner_id`` + ``lease_expires_at`` turn that into an atomic
compare-and-swap. The wrapper claims a lease at entry, refreshes it via
a 30 s heartbeat thread, and finalises (which makes the lease moot via
``status='completed'``). A second worker hitting the same key sees a
fresh lease and ``self.retry(countdown=LEASE_TTL)``s instead of running.
A crashed worker's lease expires after ``LEASE_TTL`` seconds and the
next retry can claim it.
Revision ID: 0006_idempotency_lease
Revises: 0005_ingest_attempt_id
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0006_idempotency_lease"
down_revision: Union[str, None] = "0005_ingest_attempt_id"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE task_dedup
ADD COLUMN lease_owner_id TEXT,
ADD COLUMN lease_expires_at TIMESTAMPTZ;
"""
)
# Reconciler's stuck-pending sweep filters by
# ``(status='pending', lease_expires_at < now() - 60s, attempt_count >= 5)``.
# Partial index keeps the scan small even under heavy task throughput.
op.execute(
"CREATE INDEX task_dedup_pending_lease_idx "
"ON task_dedup (lease_expires_at) "
"WHERE status = 'pending';"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS task_dedup_pending_lease_idx;")
op.execute(
"ALTER TABLE task_dedup "
"DROP COLUMN IF EXISTS lease_expires_at, "
"DROP COLUMN IF EXISTS lease_owner_id;"
)

View File

@@ -1,40 +0,0 @@
"""0007 message_events — durable journal of chat-stream events.
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
``(message_id, sequence_no)``, ``created_at`` indexed for retention
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
Revision ID: 0007_message_events
Revises: 0006_idempotency_lease
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0007_message_events"
down_revision: Union[str, None] = "0006_idempotency_lease"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE message_events (
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
sequence_no INTEGER NOT NULL,
event_type TEXT NOT NULL,
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (message_id, sequence_no)
);
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
op.execute("DROP TABLE IF EXISTS message_events;")

View File

@@ -102,8 +102,6 @@ class AnswerResource(Resource, BaseAnswerResource):
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
"reserved_message_id": processor.reserved_message_id,
"request_id": processor.request_id,
},
)
else:

View File

@@ -1,18 +1,13 @@
import datetime
import json
import logging
import time
import uuid
from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import (
ConversationService,
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
@@ -23,16 +18,9 @@ from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import MessageUpdateOutcome
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.events.publisher import publish_user_event
from application.streaming.event_replay import format_sse_event
from application.streaming.message_journal import (
BatchedJournalWriter,
record_event,
)
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -215,188 +203,13 @@ class BaseAnswerResource:
Yields:
Server-sent event strings
"""
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
query_metadata: Dict[str, Any] = {}
paused = False
# One id shared across the WAL row, primary LLM (token_usage
# attribution), the SSE event, and resumed continuations.
request_id = (
_continuation.get("request_id") if _continuation else None
) or str(uuid.uuid4())
# Reserve the placeholder row before the LLM call so a crash
# mid-stream still leaves the question queryable. Continuations
# reuse the original placeholder.
reserved_message_id: Optional[str] = None
wal_eligible = should_save_conversation and not _continuation
if wal_eligible:
try:
reservation = self.conversation_service.save_user_question(
conversation_id=conversation_id,
question=question,
decoded_token=decoded_token,
attachment_ids=attachment_ids,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
model_id=model_id or self.default_model_id,
request_id=request_id,
index=index,
)
conversation_id = reservation["conversation_id"]
reserved_message_id = reservation["message_id"]
except Exception as e:
logger.error(
f"Failed to reserve message row before stream: {e}",
exc_info=True,
)
elif _continuation and _continuation.get("reserved_message_id"):
reserved_message_id = _continuation["reserved_message_id"]
primary_llm = getattr(agent, "llm", None)
if primary_llm is not None:
primary_llm._request_id = request_id
# Flipped to ``streaming`` on first chunk; reconciler uses this
# to tell "never started" from "in flight".
streaming_marked = False
# Heartbeat goes into ``metadata.last_heartbeat_at`` (not
# ``updated_at``, which reconciler-side writes share) and uses
# ``time.monotonic`` so a blocked event loop can't fake fresh.
STREAM_HEARTBEAT_INTERVAL = 60
last_heartbeat_at = time.monotonic()
def _mark_streaming_once() -> None:
nonlocal streaming_marked, last_heartbeat_at
if streaming_marked or not reserved_message_id:
return
try:
self.conversation_service.update_message_status(
reserved_message_id, "streaming",
)
except Exception:
logger.exception(
"update_message_status streaming failed for %s",
reserved_message_id,
)
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"initial heartbeat seed failed for %s",
reserved_message_id,
)
streaming_marked = True
last_heartbeat_at = time.monotonic()
def _heartbeat_streaming() -> None:
nonlocal last_heartbeat_at
if not reserved_message_id or not streaming_marked:
return
now_mono = time.monotonic()
if now_mono - last_heartbeat_at < STREAM_HEARTBEAT_INTERVAL:
return
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"stream heartbeat update failed for %s",
reserved_message_id,
)
last_heartbeat_at = now_mono
# Correlates tool_call_attempts rows with this message.
if reserved_message_id and getattr(agent, "tool_executor", None):
try:
agent.tool_executor.message_id = reserved_message_id
except Exception:
logger.debug(
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
reserved_message_id,
)
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
# threaded through both the wire format (``id: <seq>\\n``) and
# the journal write so a reconnecting client can ``Last-Event-
# ID`` past anything they already saw. Continuations resume
# against the original ``reserved_message_id`` — seed the
# allocator from the journal's high-water mark so we don't
# collide on the duplicate-PK and silently lose every emit
# past the resume point.
sequence_no = -1
if _continuation and reserved_message_id:
try:
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
reserved_message_id
)
if latest is not None:
sequence_no = latest
except Exception:
logger.exception(
"Continuation seq seed lookup failed for message_id=%s; "
"falling back to seq=-1 (duplicate-PK collisions will "
"be swallowed)",
reserved_message_id,
)
# One batched journal writer per stream.
journal_writer: Optional[BatchedJournalWriter] = (
BatchedJournalWriter(reserved_message_id)
if reserved_message_id
else None
)
def _emit(payload: dict) -> str:
"""Format-and-journal one SSE event.
With a reserved ``message_id``, buffers into the journal and
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
legacy ``data: ...\\n\\n`` framing.
"""
nonlocal sequence_no
if not reserved_message_id or journal_writer is None:
return f"data: {json.dumps(payload)}\n\n"
sequence_no += 1
seq = sequence_no
event_type = (
payload.get("type", "data")
if isinstance(payload, dict)
else "data"
)
normalised = payload if isinstance(payload, dict) else {"value": payload}
journal_writer.record(seq, event_type, normalised)
return format_sse_event(normalised, seq)
try:
# Surface the placeholder id before any LLM tokens so a
# mid-handshake disconnect still has a row to tail-poll.
if reserved_message_id:
yield _emit(
{
"type": "message_id",
"message_id": reserved_message_id,
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"request_id": request_id,
}
)
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
if _continuation:
gen_iter = agent.gen_continuation(
@@ -409,24 +222,18 @@ class BaseAnswerResource:
gen_iter = agent.gen(query=question)
for line in gen_iter:
# Cheap closure check that only hits the DB when the
# heartbeat interval has elapsed.
_heartbeat_streaming()
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
_mark_streaming_once()
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
yield _emit(
{"type": "answer", "answer": line["answer"]}
)
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
_mark_streaming_once()
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
@@ -437,48 +244,54 @@ class BaseAnswerResource:
)
truncated_sources.append(truncated_source)
if truncated_sources:
yield _emit(
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
yield _emit({"type": "thought", "thought": line["thought"]})
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
yield _emit(
{
"type": "error",
"error": sanitize_api_error(
line.get("error", "An error occurred")
),
}
)
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
yield _emit(
{
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
)
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# First-turn pause needs a conversation row to attach to.
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
# Use model-owner scope so shared-agent
# owner-BYOM resolves to its registered plugin.
provider = (
get_provider_from_model_id(
model_id,
@@ -527,7 +340,6 @@ class BaseAnswerResource:
exc_info=True,
)
state_saved = False
if conversation_id:
try:
cont_service = ContinuationService()
@@ -540,8 +352,8 @@ class BaseAnswerResource:
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# BYOM scope; without it resume falls
# back to caller's layer.
# Persist BYOM scope so resume doesn't
# fall back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
@@ -551,81 +363,30 @@ class BaseAnswerResource:
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
# Reused on resume so the same WAL row
# is finalised and request_id stays
# consistent across token_usage rows.
"reserved_message_id": reserved_message_id,
"request_id": request_id,
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
state_saved = True
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
# Notify the user out-of-band so they can navigate
# back to the conversation and decide on the
# pending tool calls. Gated on ``state_saved``: a
# missing pending_tool_state row would 404 the
# resume endpoint, so an unfulfillable notification
# is worse than no notification.
user_id_for_event = (
decoded_token.get("sub") if decoded_token else None
)
if state_saved and user_id_for_event and conversation_id:
pending_calls = continuation.get(
"pending_tool_calls", []
) if continuation else []
# Trim each pending tool call to its identifying
# metadata so a tool with a multi-MB argument
# doesn't blow out the per-event payload size
# cap. The resume page fetches full args from
# ``pending_tool_state`` regardless.
pending_summaries = [
{
k: tc.get(k)
for k in (
"call_id",
"tool_name",
"action_name",
"name",
)
if isinstance(tc, dict) and tc.get(k) is not None
}
for tc in (pending_calls or [])
if isinstance(tc, dict)
]
publish_user_event(
user_id_for_event,
"tool.approval.required",
{
"conversation_id": str(conversation_id),
"message_id": reserved_message_id,
"pending_tool_calls": pending_summaries,
},
scope={
"kind": "conversation",
"id": str(conversation_id),
},
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
yield _emit({"type": "id", "id": str(conversation_id)})
yield _emit({"type": "end"})
# Drain the terminal ``end`` so a reconnecting client
# sees it on snapshot — same reason as the main exit.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Model-owner scope so title-gen uses owner's BYOM key.
# Run under model-owner scope so title-gen LLM inside
# save_conversation uses the owner's BYOM provider/key.
provider = (
get_provider_from_model_id(
model_id,
@@ -646,49 +407,26 @@ class BaseAnswerResource:
agent_id=agent_id,
model_user_id=model_user_id,
)
# Title-gen only; agent stream tokens live on ``agent.llm``.
llm._token_usage_source = "title"
if should_save_conversation:
if reserved_message_id is not None:
self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="complete",
title_inputs={
"llm": llm,
"question": question,
"response": response_full,
"model_id": model_id or self.default_model_id,
"fallback_name": (
question[:50] if question else "New Conversation"
),
},
)
else:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
@@ -711,22 +449,9 @@ class BaseAnswerResource:
)
else:
conversation_id = None
# Resume finished cleanly; drop the continuation row.
# Crash-paths leave it ``resuming`` for the janitor to revert.
if _continuation and conversation_id:
try:
cont_service = ContinuationService()
cont_service.delete_state(
str(conversation_id),
decoded_token.get("sub", "local"),
)
except Exception as e:
logger.error(
f"Failed to delete continuation state on resume "
f"completion: {e}",
exc_info=True,
)
yield _emit({"type": "id", "id": str(conversation_id)})
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
@@ -767,40 +492,21 @@ class BaseAnswerResource:
exc_info=True,
)
yield _emit({"type": "end"})
# Drain the journal buffer so the terminal ``end`` event is
# visible to any reconnecting client. Without this the
# client could snapshot up to the last flush boundary and
# then live-tail waiting for an ``end`` that's still
# sitting in memory.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Drain any buffered events before the terminal one-shot
# ``record_event`` below — keeps the journal's seq order
# contiguous (buffered events ... terminal event). ``close``
# is idempotent; pairing it with ``flush`` matches the
# normal-exit and error branches so any future ``record()``
# past this point would log instead of silently buffering.
if journal_writer is not None:
journal_writer.flush()
journal_writer.close()
# Save partial response
# Whether the DB row was flipped to ``complete`` during this
# abort handler. Drives the choice of terminal journal event
# below: journal ``end`` only when the row actually matches,
# else journal ``error`` so a reconnecting client sees a
# failed terminal state instead of a blank "success".
finalized_complete = False
if should_save_conversation and response_full:
try:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Resolve under model-owner scope so shared-agent
# title-gen uses owner BYOM, not deployment default.
# Mirror the normal-path provider resolution so the
# partial-save title LLM uses the model-owner's BYOM
# registration (shared-agent dispatch) rather than
# the deployment default with the instance api key.
provider = (
get_provider_from_model_id(
model_id,
@@ -826,58 +532,24 @@ class BaseAnswerResource:
agent_id=agent_id,
model_user_id=model_user_id,
)
llm._token_usage_source = "title"
if reserved_message_id is not None:
outcome = self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="complete",
title_inputs={
"llm": llm,
"question": question,
"response": response_full,
"model_id": model_id or self.default_model_id,
"fallback_name": (
question[:50] if question else "New Conversation"
),
},
)
# ``ALREADY_COMPLETE`` means the normal-path
# finalize at line 632 won the race: the DB row
# is already at ``complete`` and the reconnect
# journal should reflect that with ``end``,
# not a spurious ``error``.
finalized_complete = outcome in (
MessageUpdateOutcome.UPDATED,
MessageUpdateOutcome.ALREADY_COMPLETE,
)
else:
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# No journal row to gate, but flag the save as
# successful for symmetry with the WAL path.
finalized_complete = True
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
@@ -901,94 +573,16 @@ class BaseAnswerResource:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
# Journal a terminal event so reconnecting clients stop tailing;
# ``end`` only when the row is ``complete``, else ``error``.
if reserved_message_id is not None:
try:
sequence_no += 1
if finalized_complete:
# Match the wire shape ``_emit({"type": "end"})``
# uses on the normal path — the replay terminal
# check at ``event_replay._payload_is_terminal``
# reads ``payload.type``, and the frontend parses
# the same key off ``data:``.
record_event(
reserved_message_id,
sequence_no,
"end",
{"type": "end"},
)
else:
# Nothing was persisted under the complete status
# — mark the row failed so the reconciler doesn't
# need to sweep it, and journal an ``error`` so a
# reconnecting client surfaces the same failure
# the UI would show on a live error.
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=ConnectionError(
"client disconnected before response was persisted"
),
)
except Exception as fin_err:
logger.error(
f"Failed to mark aborted message failed: {fin_err}",
exc_info=True,
)
record_event(
reserved_message_id,
sequence_no,
"error",
{
"type": "error",
"error": "Stream aborted before any response was produced.",
"code": "client_disconnect",
},
)
except Exception as journal_err:
logger.error(
f"Failed to journal terminal event on abort: {journal_err}",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
if reserved_message_id is not None:
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=e,
)
except Exception as fin_err:
logger.error(
f"Failed to finalize errored message: {fin_err}",
exc_info=True,
)
yield _emit(
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
# Drain the terminal ``error`` event we just yielded so a
# reconnecting client sees it on snapshot.
if journal_writer is not None:
journal_writer.close()
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
@@ -1010,22 +604,8 @@ class BaseAnswerResource:
for line in stream:
try:
# Each chunk may carry an ``id: <seq>`` header before
# the ``data:`` line. Pull just the ``data:`` body so
# the JSON decode doesn't choke on the SSE framing.
event_data = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_data = raw[len("data:") :].lstrip()
break
if not event_data:
continue
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
# The ``message_id`` event is informational for the
# streaming consumer and has no synchronous-API field;
# skip it so the type-switch below doesn't KeyError.
if event.get("type") == "message_id":
continue
if event["type"] == "id":
conversation_id = event["id"]

View File

@@ -1,135 +0,0 @@
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
Authenticates the caller, verifies ``message_id`` belongs to the user,
then hands off to ``build_message_event_stream`` for snapshot+tail.
"""
from __future__ import annotations
import logging
import re
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from sqlalchemy import text
from application.core.settings import settings
from application.storage.db.session import db_readonly
from application.streaming.event_replay import (
DEFAULT_KEEPALIVE_SECONDS,
DEFAULT_POLL_TIMEOUT_SECONDS,
build_message_event_stream,
)
logger = logging.getLogger(__name__)
messages_bp = Blueprint("message_stream", __name__)
# A message_id is the canonical UUID hex format. Reject anything else
# before the SQL layer so a malformed cookie can't surface as a 500.
_MESSAGE_ID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# ``sequence_no`` is a non-negative decimal integer. Anything else is
# corrupt client state — fall through to a fresh-replay cursor and let
# the snapshot reader catch the client up.
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
if raw is None:
return None
raw = raw.strip()
if not raw or not _SEQUENCE_NO_RE.match(raw):
return None
return int(raw)
def _user_owns_message(message_id: str, user_id: str) -> bool:
"""Return True iff ``message_id`` belongs to ``user_id``."""
try:
with db_readonly() as conn:
row = conn.execute(
text(
"""
SELECT 1 FROM conversation_messages
WHERE id = CAST(:id AS uuid)
AND user_id = :u
LIMIT 1
"""
),
{"id": message_id, "u": user_id},
).first()
return row is not None
except Exception:
logger.exception(
"Ownership lookup failed for message_id=%s user_id=%s",
message_id,
user_id,
)
return False
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
if not _MESSAGE_ID_RE.match(message_id):
return make_response(
jsonify({"success": False, "message": "Invalid message id"}),
400,
)
if not _user_owns_message(message_id, user_id):
# Don't disclose whether the row exists — a malicious caller
# gets the same 404 whether the id is bogus, taken by another
# user, or simply gone.
return make_response(
jsonify({"success": False, "message": "Not found"}),
404,
)
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalise_last_event_id(raw_cursor)
keepalive_seconds = float(
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
)
@stream_with_context
def generate() -> Iterator[str]:
try:
yield from build_message_event_stream(
message_id,
last_event_id=last_event_id,
keepalive_seconds=keepalive_seconds,
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
)
except GeneratorExit:
return
except Exception:
logger.exception(
"Reconnect stream crashed for message_id=%s user_id=%s",
message_id,
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
message_id,
user_id,
last_event_id if last_event_id is not None else "-",
)
return response

View File

@@ -115,8 +115,6 @@ class StreamResource(Resource, BaseAnswerResource):
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
"reserved_message_id": processor.reserved_message_id,
"request_id": processor.request_id,
},
),
mimetype="text/event-stream",

View File

@@ -160,9 +160,6 @@ class CompressionOrchestrator:
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Side-channel LLM tag — distinguishes compression rows
# from primary stream rows for cost-attribution dashboards.
compression_llm._token_usage_source = "compression"
# Create compression service with DB update capability
compression_service = CompressionService(

View File

@@ -12,12 +12,6 @@ logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
# Per-image token estimate. Provider tokenizers vary widely
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
# depends on resolution/detail we can't see here. Errs slightly high
# so the threshold check stays conservative.
_IMAGE_PART_TOKEN_ESTIMATE = 1500
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
@@ -35,36 +29,12 @@ class TokenCounter:
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, image parts, etc.)
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += TokenCounter._count_content_part(item)
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def _count_content_part(item: Dict) -> int:
# Image/file attachments are billed by the provider per image,
# not proportional to the inline bytes/base64 string.
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
# which trips spurious compression and overflows downstream
# input limits.
item_type = item.get("type")
if "files" in item:
files = item.get("files")
count = len(files) if isinstance(files, list) and files else 1
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
if "image_url" in item or item_type in {
"image",
"image_url",
"input_image",
"file",
}:
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
return num_tokens_from_string(str(item))
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True

View File

@@ -7,13 +7,13 @@ resume later by sending tool_actions.
import logging
from typing import Any, Dict, List, Optional
from uuid import UUID
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
from application.storage.db.serialization import coerce_pg_native as _make_serializable
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
@@ -21,9 +21,23 @@ logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
# Re-export so the existing tests at tests/api/answer/services/test_continuation_service_pg.py
# can keep importing ``_make_serializable`` from here.
__all__ = ["_make_serializable", "ContinuationService", "PENDING_STATE_TTL_SECONDS"]
def _make_serializable(obj: Any) -> Any:
"""Recursively coerce non-JSON values into JSON-safe forms.
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
our writers produce them anymore.
"""
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
@@ -141,23 +155,3 @@ class ContinuationService:
f"Deleted continuation state for conversation {conversation_id}"
)
return deleted
def mark_resuming(self, conversation_id: str, user: str) -> bool:
"""Flip the pending row to ``resuming`` so a crashed resume can be retried."""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
return False
flipped = PendingToolStateRepository(conn).mark_resuming(
pg_conv_id, user
)
if flipped:
logger.info(
f"Marked continuation state as resuming for conversation "
f"{conversation_id}"
)
return flipped

View File

@@ -6,7 +6,6 @@ than held for the duration of a stream.
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@@ -15,22 +14,13 @@ from sqlalchemy import text as sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import (
ConversationsRepository,
MessageUpdateOutcome,
)
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# Shown to the user if the worker dies mid-stream and the response is never finalised.
TERMINATED_RESPONSE_PLACEHOLDER = (
"Response was terminated prior to completion, try regenerating."
)
class ConversationService:
def get_conversation(
self, conversation_id: str, user_id: str
@@ -189,243 +179,6 @@ class ConversationService:
repo.append_message(conv_pg_id, append_payload)
return conv_pg_id
def save_user_question(
self,
conversation_id: Optional[str],
question: str,
decoded_token: Dict[str, Any],
*,
attachment_ids: Optional[List[str]] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
request_id: Optional[str] = None,
status: str = "pending",
index: Optional[int] = None,
) -> Dict[str, str]:
"""Reserve the placeholder message row before the LLM call.
``index`` triggers regenerate semantics: messages at
``position >= index`` are truncated so the new placeholder
lands at ``position = index`` rather than appending.
Returns ``{"conversation_id", "message_id", "request_id"}``.
"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
request_id = request_id or str(uuid.uuid4())
resolved_api_key: Optional[str] = None
resolved_agent_id: Optional[str] = None
if api_key and not conversation_id:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if agent:
resolved_api_key = agent.get("key")
if agent_id:
resolved_agent_id = agent_id
with db_session() as conn:
repo = ConversationsRepository(conn)
if conversation_id:
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
# Regenerate / edit-prior-question: drop the message at
# ``index`` and everything after it so the new
# ``reserve_message`` lands at ``position=index`` rather
# than appending at the end of the conversation.
if isinstance(index, int) and index >= 0:
repo.truncate_after(conv_pg_id, keep_up_to=index - 1)
else:
fallback_name = (question[:50] if question else "New Conversation")
conv = repo.create(
user_id,
fallback_name,
agent_id=resolved_agent_id,
api_key=resolved_api_key,
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
shared_token=(
shared_token
if (resolved_agent_id and is_shared_usage)
else None
),
)
conv_pg_id = str(conv["id"])
row = repo.reserve_message(
conv_pg_id,
prompt=question,
placeholder_response=TERMINATED_RESPONSE_PLACEHOLDER,
request_id=request_id,
status=status,
attachments=attachment_ids,
model_id=model_id,
)
message_id = str(row["id"])
return {
"conversation_id": conv_pg_id,
"message_id": message_id,
"request_id": request_id,
}
def update_message_status(self, message_id: str, status: str) -> bool:
"""Cheap status-only transition (e.g. ``pending → streaming``)."""
if not message_id:
return False
with db_session() as conn:
return ConversationsRepository(conn).update_message_status(
message_id, status,
)
def heartbeat_message(self, message_id: str) -> bool:
"""Bump ``message_metadata.last_heartbeat_at`` so the reconciler's
staleness sweep counts the row as alive. No-ops on terminal rows.
"""
if not message_id:
return False
with db_session() as conn:
return ConversationsRepository(conn).heartbeat_message(message_id)
def finalize_message(
self,
message_id: str,
response: str,
*,
thought: str = "",
sources: Optional[List[Dict[str, Any]]] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
model_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
status: str = "complete",
error: Optional[BaseException] = None,
title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
"""Commit the response and tool_call confirms in one transaction.
The outcome propagates directly from ``update_message_by_id`` so
callers (notably the SSE abort handler) can tell a fresh
finalize from "the row was already terminal" — the latter must
still be treated as success when the prior state was
``complete``.
"""
if not message_id:
return MessageUpdateOutcome.INVALID
sources = sources or []
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
merged_metadata: Dict[str, Any] = dict(metadata or {})
if status == "failed" and error is not None:
merged_metadata.setdefault(
"error", f"{type(error).__name__}: {str(error)}"
)
update_fields: Dict[str, Any] = {
"response": response,
"status": status,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls or [],
"metadata": merged_metadata,
}
if model_id is not None:
update_fields["model_id"] = model_id
# Atomic message update + tool_call_attempts confirm; the
# ``only_if_non_terminal`` guard prevents a late stream from
# retracting a row the reconciler already escalated.
with db_session() as conn:
repo = ConversationsRepository(conn)
outcome = repo.update_message_by_id(
message_id, update_fields,
only_if_non_terminal=True,
)
if outcome is not MessageUpdateOutcome.UPDATED:
logger.warning(
f"finalize_message: no row updated for message_id={message_id} "
f"(outcome={outcome.value} — possibly already terminal)"
)
return outcome
repo.confirm_executed_tool_calls(message_id)
# Outside the txn — title-gen is a multi-second LLM round trip.
if title_inputs and status == "complete":
try:
with db_session() as conn:
self._maybe_generate_title(conn, message_id, title_inputs)
except Exception as e:
logger.error(
f"finalize_message title generation failed: {e}",
exc_info=True,
)
return MessageUpdateOutcome.UPDATED
def _maybe_generate_title(
self,
conn,
message_id: str,
title_inputs: Dict[str, Any],
) -> None:
"""Generate an LLM-summarised conversation name if one isn't set yet."""
llm = title_inputs.get("llm")
question = title_inputs.get("question") or ""
response = title_inputs.get("response") or ""
fallback_name = title_inputs.get("fallback_name") or question[:50]
if llm is None:
return
row = conn.execute(
sql_text(
"SELECT c.id, c.name FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
"WHERE m.id = CAST(:mid AS uuid)"
),
{"mid": message_id},
).fetchone()
if row is None:
return
conv_id, current_name = str(row[0]), row[1]
if current_name and current_name != fallback_name:
return
messages_summary = [
{
"role": "system",
"content": "You are a helpful assistant that creates concise conversation titles. "
"Summarize conversations in 3 words or less using the same language as the user.",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(
model=getattr(llm, "model_id", None) or title_inputs.get("model_id"),
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():
completion = fallback_name or "New Conversation"
conn.execute(
sql_text(
"UPDATE conversations SET name = :name, updated_at = now() "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conv_id, "name": completion.strip()},
)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:

View File

@@ -123,10 +123,6 @@ class StreamProcessor:
self.model_id: Optional[str] = None
# BYOM-resolution scope, set by _validate_and_set_model.
self.model_user_id: Optional[str] = None
# WAL placeholder id pulled from continuation state on resume.
self.reserved_message_id: Optional[str] = None
# Carried through resumes so multi-pause runs keep one request_id.
self.request_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
@@ -932,20 +928,6 @@ class StreamProcessor:
if not state:
raise ValueError("No pending tool state found for this conversation")
# Claim the resume up-front. ``mark_resuming`` only flips ``pending``
# → ``resuming``; if it returns False, another resume already
# claimed this row (status='resuming') — bail before any further
# LLM/tool work to avoid double-execution. The cleanup janitor
# reverts a stale ``resuming`` claim back to ``pending`` after the
# 10-minute grace window so the user can retry.
if not cont_service.mark_resuming(
conversation_id, self.initial_user_id,
):
raise ValueError(
"Resume already in progress for this conversation; "
"retry after the grace window if it stalls."
)
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
@@ -1040,10 +1022,9 @@ class StreamProcessor:
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Reused on resume so the same WAL row gets finalised and
# request_id stays consistent across token_usage rows.
self.reserved_message_id = agent_config.get("reserved_message_id")
self.request_id = agent_config.get("request_id")
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions

View File

@@ -1,504 +0,0 @@
"""GET /api/events — user-scoped Server-Sent Events endpoint.
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import (
connection_counter_key,
replay_budget_key,
stream_id_compare,
stream_key,
topic_name,
)
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
events = Blueprint("event_stream", __name__)
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
# Anything else is a corrupted client cookie / IndexedDB residue and must
# not be passed to XRANGE — Redis would reject it and our truncation gate
# would silently fail.
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
# Only emitted at most once per process so a misconfigured deployment
# doesn't drown the logs.
_local_user_warned = False
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
"""Encode a payload as one SSE message terminated by a blank line.
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
so a stray CR in upstream content can't smuggle a premature line
boundary into the wire format.
"""
lines: list[str] = []
if event_id:
lines.append(f"id: {event_id}")
for line in _SSE_LINE_SPLIT.split(data):
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
def _decode(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return None
return str(value)
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
"""Return the id of the oldest entry still in the stream, or ``None``.
Used to detect ``Last-Event-ID`` having slid off the back of the
MAXLEN'd window.
"""
try:
info = redis_client.xinfo_stream(stream_key(user_id))
except Exception:
return None
if not isinstance(info, dict):
return None
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
# defence in depth in case ``decode_responses`` is ever flipped.
first_entry = info.get("first-entry") or info.get(b"first-entry")
if not first_entry:
return None
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
try:
return _decode(first_entry[0])
except Exception:
return None
def _allow_replay(
redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
"""Per-user sliding-window snapshot-replay budget.
Fails open on Redis errors or when the budget is disabled. Empty-backlog
no-cursor connects skip INCR so dev double-mounts don't trip 429.
"""
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
if budget <= 0:
return True
if redis_client is None:
return True
# Cheap pre-check: only INCR when we might actually replay. XLEN
# is one Redis op; the alternative (INCR every connect) is two
# ops AND wrongly counts no-op probes. The check is conservative:
# if ``last_event_id`` is set we always INCR, even if the cursor
# has already overtaken the latest entry — that case is rare and
# short-lived, and probing further would mean a redundant XRANGE.
if last_event_id is None:
try:
if int(redis_client.xlen(stream_key(user_id))) == 0:
return True
except Exception:
# XLEN probe failed; fall through to the INCR path so a
# transient Redis hiccup can't bypass the budget.
logger.debug(
"XLEN probe failed for replay budget check user=%s; "
"proceeding to INCR",
user_id,
)
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
key = replay_budget_key(user_id)
try:
used = int(redis_client.incr(key))
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
# the counter forever if INCR succeeds but EXPIRE raises on
# the seeding call. EXPIRE on an existing key resets the TTL
# to ``window`` — within ±1s of the per-window budget semantic.
redis_client.expire(key, window)
except Exception:
logger.debug(
"replay budget probe failed for user=%s; failing open",
user_id,
)
return True
return used <= budget
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
"""Validate the ``Last-Event-ID`` header / query param.
Returns the value unchanged when it parses as a Redis Streams id,
otherwise ``None`` — callers treat ``None`` as "client has nothing"
and replay from the start of the retained window. Invalid ids would
otherwise pass straight to XRANGE and surface as a quiet replay
failure plus broken truncation detection.
"""
if raw is None:
return None
raw = raw.strip()
if not raw or not _STREAM_ID_RE.match(raw):
return None
return raw
def _replay_backlog(
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
Capped at ``max_count`` rows; clients catch up across reconnects.
Parse failures are skipped; the Streams id is injected into the
envelope so replay matches live-tail shape.
"""
# Exclusive start: '(<id>' skips the already-delivered entry.
start = f"({last_event_id}" if last_event_id else "-"
try:
entries = redis_client.xrange(
stream_key(user_id), min=start, max="+", count=max_count
)
except Exception as exc:
logger.warning(
"xrange replay failed for user=%s last_id=%s err=%s",
user_id,
last_event_id or "-",
exc,
)
return
for entry_id, fields in entries:
entry_id_str = _decode(entry_id)
if not entry_id_str:
continue
# decode_responses=False on the cache client ⇒ field keys/values
# are bytes. The string-key fallback covers a future flip of that
# default without a forced refactor here.
raw_event = None
if isinstance(fields, dict):
raw_event = fields.get(b"event")
if raw_event is None:
raw_event = fields.get("event")
event_str = _decode(raw_event)
if not event_str:
continue
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
envelope["id"] = entry_id_str
event_str = json.dumps(envelope)
except Exception:
logger.debug(
"Replay envelope parse failed for entry %s; passing through raw",
entry_id_str,
)
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
def _truncation_notice_line(oldest_id: str) -> str:
"""SSE event the frontend can react to with a full-state refetch."""
return _format_sse(
json.dumps(
{
"type": "backlog.truncated",
"payload": {"oldest_retained_id": oldest_id},
}
)
)
@events.route("/api/events", methods=["GET"])
def stream_events() -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
# In dev deployments without AUTH_TYPE configured, every request
# resolves to user_id="local" and shares one stream. Surface this so
# an accidentally-multi-user dev box doesn't silently cross-stream.
global _local_user_warned
if user_id == "local" and not _local_user_warned:
logger.warning(
"SSE serving user_id='local' (AUTH_TYPE not set). "
"All clients on this deployment will share one event stream."
)
_local_user_warned = True
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalize_last_event_id(raw_last_event_id)
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
push_enabled = settings.ENABLE_SSE_PUSH
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
redis_client = get_redis_instance()
counter_key = connection_counter_key(user_id)
counted = False
if push_enabled and redis_client is not None and cap > 0:
try:
current = int(redis_client.incr(counter_key))
counted = True
except Exception:
current = 0
logger.debug(
"SSE connection counter INCR failed for user=%s", user_id
)
if counted:
# 1h safety TTL — orphaned counts from hard crashes self-heal.
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
try:
redis_client.expire(counter_key, 3600)
except Exception:
logger.debug(
"SSE connection counter EXPIRE failed for user=%s", user_id
)
if current > cap:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Too many concurrent SSE connections",
}
),
429,
)
# Replay budget is checked here, before the generator opens the
# stream, so a denial can surface as HTTP 429 instead of a silent
# snapshot skip. The earlier in-generator skip lost events between
# the client's cursor and the first live-tailed entry: the live
# tail still carried ``id:`` headers, the frontend advanced
# ``lastEventId`` to one of those ids, and the events in between
# were never reachable on the next reconnect. 429 keeps the
# cursor pinned and lets the frontend back off until the window
# slides (eventStreamClient.ts treats 429 as escalated backoff).
if push_enabled and redis_client is not None and not _allow_replay(
redis_client, user_id, last_event_id
):
if counted:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Replay budget exhausted",
}
),
429,
)
@stream_with_context
def generate() -> Iterator[str]:
connect_ts = time.monotonic()
replayed_count = 0
try:
# First frame primes intermediaries (Cloudflare, nginx) so they
# don't sit on a buffer waiting for body bytes.
yield ": connected\n\n"
if not push_enabled:
yield ": push_disabled\n\n"
return
replay_lines: list[str] = []
max_replayed_id: Optional[str] = None
replay_done = False
# If the client sent a malformed Last-Event-ID, surface the
# truncation notice synchronously *before* the subscribe
# loop. Buffering it into ``replay_lines`` would lose it
# when ``Topic.subscribe`` returns immediately (Redis down)
# — the loop body never runs, and the flush at line ~335
# never fires.
if last_event_id_invalid:
yield _truncation_notice_line("")
replayed_count += 1
def _on_subscribe_callback() -> None:
# Runs synchronously inside Topic.subscribe after the
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
# XADD captured by XRANGE *and* its PUBLISH buffered at
# the connection layer until we read it — closing the
# replay/subscribe race the design doc warns about.
#
# Truncation contract: ``backlog.truncated`` is emitted
# ONLY when the client's ``Last-Event-ID`` has slid off
# the MAXLEN'd window — that's the case where the
# journal is genuinely gone past the cursor and the
# frontend should clear its slice cursor and refetch
# state. Cap-hit skips the snapshot silently: the
# cursor advances via the per-entry ``id:`` headers
# and the frontend's slice keeps the latest id so the
# next reconnect resumes from there. Budget-exhausted
# never reaches this callback — the route 429s before
# opening the stream, keeping the cursor pinned.
# Conflating these with stale-cursor truncation would
# tell the client to clear its cursor and re-receive
# the same oldest-N entries on every reconnect —
# locking the user out of entries past N.
nonlocal max_replayed_id, replay_done
try:
if redis_client is None:
return
oldest = _oldest_retained_id(redis_client, user_id)
if (
last_event_id
and oldest
and stream_id_compare(last_event_id, oldest) < 0
):
# The Last-Event-ID has slid off the MAXLEN window.
# Tell the client so it can fetch full state.
replay_lines.append(_truncation_notice_line(oldest))
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
for entry_id, sse_line in _replay_backlog(
redis_client, user_id, last_event_id, replay_cap
):
replay_lines.append(sse_line)
max_replayed_id = entry_id
finally:
# Always flip the flag — even on partial-replay failure
# the outer loop must reach the flush step so we don't
# silently strand whatever entries did land.
replay_done = True
topic = Topic(topic_name(user_id))
last_keepalive = time.monotonic()
for payload in topic.subscribe(
on_subscribe=_on_subscribe_callback,
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
):
# Flush snapshot on the first iteration after the SUBSCRIBE
# callback ran. This runs at most once per connection.
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
now = time.monotonic()
if payload is None:
if now - last_keepalive >= keepalive_seconds:
yield ": keepalive\n\n"
last_keepalive = now
continue
event_str = _decode(payload) or ""
event_id: Optional[str] = None
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
candidate = envelope.get("id")
# Only trust ids that look like real Redis Streams
# ids (``ms`` or ``ms-seq``). A malformed or
# adversarial publisher could otherwise pin
# dedupe forever — a lex-greater bogus id would
# make every legitimate later id compare ``<=``
# and get dropped silently.
if isinstance(candidate, str) and _STREAM_ID_RE.match(
candidate
):
event_id = candidate
except Exception:
pass
# Dedupe: if this id was already covered by replay, drop it.
if (
event_id is not None
and max_replayed_id is not None
and stream_id_compare(event_id, max_replayed_id) <= 0
):
continue
yield _format_sse(event_str, event_id=event_id)
last_keepalive = now
# Topic.subscribe exited before the first yield (transient
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
# an immediate Redis-down return). The callback may already
# have populated the snapshot — flush it so the client gets
# the backlog instead of a silent drop. Safe no-op when the
# in-loop flush ran (it clear()'d the buffer) and when the
# callback never fired (replay_done stays False).
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
except GeneratorExit:
return
except Exception:
logger.exception(
"SSE event-stream generator crashed for user=%s", user_id
)
finally:
duration_s = time.monotonic() - connect_ts
logger.info(
"event.disconnect user=%s duration_s=%.1f replayed=%d",
user_id,
duration_s,
replayed_count,
)
if counted and redis_client is not None:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s on disconnect",
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"event.connect user=%s last_event_id=%s%s",
user_id,
last_event_id or "-",
" (rejected_invalid)" if last_event_id_invalid else "",
)
return response

View File

@@ -46,9 +46,7 @@ AGENT_TYPE_SCHEMAS = {
"prompt_id",
],
"required_draft": ["name"],
# ``prompt_id`` intentionally omitted — the "default" sentinel
# is acceptable and maps to NULL downstream.
"validate_published": ["name", "description"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
@@ -1011,16 +1009,12 @@ class UpdateAgent(Resource):
400,
)
else:
# ``prompt_id`` is intentionally omitted: the
# frontend's "default" choice maps to NULL here
# (see the prompt_id branch above), and NULL
# means "use the built-in default prompt" which
# is a valid published-agent state.
missing_published_fields = []
for req_field, field_label in (
("name", "Agent name"),
("description", "Agent description"),
("chunks", "Chunks count"),
("prompt_id", "Prompt"),
("agent_type", "Agent type"),
):
final_value = update_fields.get(
@@ -1034,23 +1028,8 @@ class UpdateAgent(Resource):
extra_final = update_fields.get(
"extra_source_ids", existing_agent.get("extra_source_ids") or [],
)
# ``retriever`` carries the runtime identity for
# agents that publish against the synthetic
# "Default" source (frontend's auto-selected
# ``{name: "Default", retriever: "classic"}``
# entry has no ``id``, so ``source_id`` ends up
# NULL even though the user picked something).
# Without this fallback the most common new-agent
# publish flow gets a 400.
retriever_final = update_fields.get(
"retriever", existing_agent.get("retriever"),
)
if (
not source_final
and not extra_final
and not retriever_final
):
missing_published_fields.append("Source or retriever")
if not source_final and not extra_final:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(

View File

@@ -1,19 +1,15 @@
"""Agent management webhook handlers."""
import secrets
import uuid
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.base import require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.session import db_readonly, db_session
@@ -22,37 +18,6 @@ agents_webhooks_ns = Namespace(
)
_IDEMPOTENCY_KEY_MAX_LEN = 256
def _read_idempotency_key():
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
key = request.headers.get("Idempotency-Key")
if not key:
return None, None
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
return None, make_response(
jsonify(
{
"success": False,
"message": (
f"Idempotency-Key exceeds maximum length of "
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
),
}
),
400,
)
return key, None
def _scoped_idempotency_key(idempotency_key, scope):
"""``{scope}:{key}`` so different agents can't collide on the same key."""
if not idempotency_key or not scope:
return None
return f"{scope}:{idempotency_key}"
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
@api.doc(
@@ -103,7 +68,7 @@ class AgentWebhook(Resource):
class AgentWebhookListener(Resource):
method_decorators = [require_agent]
def _enqueue_webhook_task(self, agent_id_str, payload, source_method, agent=None):
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
if not payload:
current_app.logger.warning(
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
@@ -112,94 +77,26 @@ class AgentWebhookListener(Resource):
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
)
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
# Resolve to PG UUID first so dedup writes don't crash on legacy ids.
agent_uuid = None
if agent is not None:
candidate = str(agent.get("id") or "")
if looks_like_uuid(candidate):
agent_uuid = candidate
if idempotency_key and agent_uuid is None:
current_app.logger.warning(
"Skipping webhook idempotency dedup: agent %s has non-UUID id",
agent_id_str,
)
idempotency_key = None
# Agent-scoped (webhooks have no user_id).
scoped_key = _scoped_idempotency_key(idempotency_key, agent_uuid)
# Claim before enqueue; the loser returns the winner's task_id.
predetermined_task_id = None
if scoped_key:
predetermined_task_id = str(uuid.uuid4())
with db_session() as conn:
claimed = IdempotencyRepository(conn).record_webhook(
key=scoped_key,
agent_id=agent_uuid,
task_id=predetermined_task_id,
response_json={
"success": True, "task_id": predetermined_task_id,
},
)
if claimed is None:
with db_readonly() as conn:
cached = IdempotencyRepository(conn).get_webhook(scoped_key)
if cached is not None:
return make_response(jsonify(cached["response_json"]), 200)
return make_response(
jsonify({"success": True, "task_id": "deduplicated"}), 200
)
try:
apply_kwargs = dict(
kwargs={
"agent_id": agent_id_str,
"payload": payload,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
},
task = process_agent_webhook.delay(
agent_id=agent_id_str,
payload=payload,
)
if predetermined_task_id is not None:
apply_kwargs["task_id"] = predetermined_task_id
task = process_agent_webhook.apply_async(**apply_kwargs)
current_app.logger.info(
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
)
response_payload = {"success": True, "task_id": task.id}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
except Exception as err:
current_app.logger.error(
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
exc_info=True,
)
if scoped_key:
# Roll back the claim so a retry can succeed.
try:
with db_session() as conn:
conn.execute(
sql_text(
"DELETE FROM webhook_dedup "
"WHERE idempotency_key = :k"
),
{"k": scoped_key},
)
except Exception:
current_app.logger.exception(
"Failed to release webhook_dedup claim for key=%s",
scoped_key,
)
return make_response(
jsonify({"success": False, "message": "Error processing webhook"}), 500
)
@api.doc(
description=(
"Webhook listener for agent events (POST). Expects JSON payload, which "
"is used to trigger processing. Honors an optional ``Idempotency-Key`` "
"header: a repeat request with the same key within 24h returns the "
"original cached response and does not re-enqueue the task."
),
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
)
def post(self, webhook_token, agent, agent_id_str):
payload = request.get_json()
@@ -213,20 +110,11 @@ class AgentWebhookListener(Resource):
),
400,
)
return self._enqueue_webhook_task(
agent_id_str, payload, source_method="POST", agent=agent,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
@api.doc(
description=(
"Webhook listener for agent events (GET). Uses URL query parameters as "
"payload to trigger processing. Honors an optional ``Idempotency-Key`` "
"header: a repeat request with the same key within 24h returns the "
"original cached response and does not re-enqueue the task."
),
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
)
def get(self, webhook_token, agent, agent_id_str):
payload = request.args.to_dict(flat=True)
return self._enqueue_webhook_task(
agent_id_str, payload, source_method="GET", agent=agent,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")

View File

@@ -214,10 +214,6 @@ class StoreAttachment(Resource):
{
"success": True,
"task_id": tasks[0]["task_id"],
# Surface the attachment_id so the frontend
# can correlate ``attachment.*`` SSE events
# to this row and skip the polling fallback.
"attachment_id": tasks[0]["attachment_id"],
"message": "File uploaded successfully. Processing started.",
}
),

View File

@@ -4,16 +4,10 @@ import datetime
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.answer.services.conversation_service import (
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.message_events import MessageEventsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
@@ -139,7 +133,6 @@ class GetSingleConversation(Resource):
attachments_repo = AttachmentsRepository(conn)
queries = []
for msg in messages:
metadata = msg.get("metadata") or {}
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
@@ -148,15 +141,9 @@ class GetSingleConversation(Resource):
"tool_calls": msg.get("tool_calls") or [],
"timestamp": msg.get("timestamp"),
"model_id": msg.get("model_id"),
# Lets the client distinguish placeholder rows from
# finalised answers and tail-poll in-flight ones.
"message_id": str(msg["id"]) if msg.get("id") else None,
"status": msg.get("status"),
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
}
if metadata:
query["metadata"] = metadata
if msg.get("metadata"):
query["metadata"] = msg["metadata"]
# Feedback on conversation_messages is a JSONB blob with
# shape {"text": <str>, "timestamp": <iso>}. The legacy
# frontend consumed a flat scalar feedback string, so
@@ -314,80 +301,3 @@ class SubmitFeedback(Resource):
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/messages/<string:message_id>/tail")
class GetMessageTail(Resource):
@api.doc(
description=(
"Current state of one conversation_messages row, scoped to the "
"authenticated user. Used to reconnect to an in-flight stream "
"after a refresh."
),
params={"message_id": "Message UUID"},
)
def get(self, message_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if not looks_like_uuid(message_id):
return make_response(
jsonify({"success": False, "message": "Invalid message id"}), 400
)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
# Owner-or-shared, matching ``ConversationsRepository.get``.
row = conn.execute(
sql_text(
"SELECT m.* FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
"WHERE m.id = CAST(:mid AS uuid) "
"AND (c.user_id = :uid OR :uid = ANY(c.shared_with))"
),
{"mid": message_id, "uid": user_id},
).fetchone()
if row is None:
return make_response(jsonify({"status": "not found"}), 404)
msg = row_to_dict(row)
# Mid-stream the row's response is the placeholder; rebuild
# the live partial from the journal so /tail mirrors SSE.
status = msg.get("status")
response = msg.get("response")
thought = msg.get("thought")
sources = msg.get("sources") or []
tool_calls = msg.get("tool_calls") or []
if status in ("pending", "streaming") and (
response == TERMINATED_RESPONSE_PLACEHOLDER
):
partial = MessageEventsRepository(conn).reconstruct_partial(
message_id
)
response = partial["response"]
thought = partial["thought"] or thought
if partial["sources"]:
sources = partial["sources"]
if partial["tool_calls"]:
tool_calls = partial["tool_calls"]
except Exception as err:
current_app.logger.error(
f"Error tailing message {message_id}: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
metadata = msg.get("message_metadata") or {}
return make_response(
jsonify(
{
"message_id": str(msg["id"]),
"status": status,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
"error": metadata.get("error"),
}
),
200,
)

View File

@@ -1,237 +0,0 @@
"""Per-Celery-task idempotency wrapper backed by ``task_dedup``."""
from __future__ import annotations
import functools
import logging
import threading
import uuid
from typing import Any, Callable, Optional
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# Poison-loop cap; transient-failure headroom without infinite retry.
MAX_TASK_ATTEMPTS = 5
# 30s heartbeat / 60s TTL → ~2 missed ticks of slack before reclaim.
LEASE_TTL_SECONDS = 60
LEASE_HEARTBEAT_INTERVAL = 30
# 10 × 60s ≈ 5 min of deferral before giving up on a held lease.
LEASE_RETRY_MAX = 10
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Short-circuit on completed key; gate concurrent runs via a lease.
Entry short-circuits:
- completed row → return cached result
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
Success writes ``completed``; exceptions leave ``pending`` for
autoretry until the poison-loop guard trips.
"""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
if key is None:
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
cached = _lookup_completed(key)
if cached is not None:
logger.info(
"idempotency hit for task=%s key=%s — returning cached result",
task_name, key,
)
return cached
owner_id = str(uuid.uuid4())
attempt = _try_claim_lease(
key, task_name, _safe_task_id(self), owner_id,
)
if attempt is None:
# Live lease held by another worker. Re-queue and bail
# quickly — by the time the retry fires (LEASE_TTL
# seconds), Worker 1 has either finalised (we'll hit
# ``_lookup_completed`` and return cached) or its lease
# has expired and we can claim.
logger.info(
"idempotency: live lease held; deferring task=%s key=%s",
task_name, key,
)
raise self.retry(
countdown=LEASE_TTL_SECONDS,
max_retries=LEASE_RETRY_MAX,
)
if attempt > MAX_TASK_ATTEMPTS:
logger.error(
"idempotency poison-loop guard: task=%s key=%s attempts=%s",
task_name, key, attempt,
extra={
"alert": "idempotency_poison_loop",
"task_name": task_name,
"idempotency_key": key,
"attempts": attempt,
},
)
poisoned = {
"success": False,
"error": "idempotency poison-loop guard tripped",
"attempts": attempt,
}
_finalize(key, poisoned, status="failed")
return poisoned
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
key, owner_id,
)
try:
result = fn(self, *args, idempotency_key=idempotency_key, **kwargs)
_finalize(key, result, status="completed")
return result
except Exception:
# Drop the lease so the next retry doesn't wait LEASE_TTL.
_release_lease(key, owner_id)
raise
finally:
_stop_lease_heartbeat(heartbeat_thread, heartbeat_stop)
return wrapper
return decorator
def _lookup_completed(key: str) -> Any:
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
with db_readonly() as conn:
row = IdempotencyRepository(conn).get_task(key)
if row is None:
return None
if row.get("status") != "completed":
return None
return row.get("result_json")
def _try_claim_lease(
key: str, task_name: str, task_id: str, owner_id: str,
) -> Optional[int]:
"""Atomic CAS; returns ``attempt_count`` or ``None`` when held.
DB outage → treated as ``attempt=1`` so transient failures don't
block all task execution; reconciler repairs the lease columns.
"""
try:
with db_session() as conn:
return IdempotencyRepository(conn).try_claim_lease(
key=key,
task_name=task_name,
task_id=task_id,
owner_id=owner_id,
ttl_seconds=LEASE_TTL_SECONDS,
)
except Exception:
logger.exception(
"idempotency lease-claim failed for key=%s task=%s", key, task_name,
)
return 1
def _finalize(key: str, result_json: Any, *, status: str) -> None:
"""Best-effort terminal write. Never let DB outage fail the task."""
try:
with db_session() as conn:
IdempotencyRepository(conn).finalize_task(
key=key, result_json=result_json, status=status,
)
except Exception:
logger.exception(
"idempotency finalize failed for key=%s status=%s", key, status,
)
def _release_lease(key: str, owner_id: str) -> None:
"""Best-effort lease release on the wrapper's exception path."""
try:
with db_session() as conn:
IdempotencyRepository(conn).release_lease(key, owner_id)
except Exception:
logger.exception("idempotency release-lease failed for key=%s", key)
def _start_lease_heartbeat(
key: str, owner_id: str,
) -> tuple[threading.Thread, threading.Event]:
"""Spawn a daemon thread that bumps ``lease_expires_at`` every
:data:`LEASE_HEARTBEAT_INTERVAL` seconds until ``stop_event`` fires.
Mirrors ``application.worker._start_ingest_heartbeat`` so the two
durability heartbeats share shape and cadence.
"""
stop_event = threading.Event()
thread = threading.Thread(
target=_lease_heartbeat_loop,
args=(key, owner_id, stop_event, LEASE_HEARTBEAT_INTERVAL),
daemon=True,
name=f"idempotency-lease-heartbeat:{key[:32]}",
)
thread.start()
return thread, stop_event
def _stop_lease_heartbeat(
thread: threading.Thread, stop_event: threading.Event,
) -> None:
"""Signal the heartbeat thread to exit and join with a short timeout."""
stop_event.set()
thread.join(timeout=10)
def _lease_heartbeat_loop(
key: str,
owner_id: str,
stop_event: threading.Event,
interval: int,
) -> None:
"""Refresh the lease until ``stop_event`` is set or ownership is lost.
A failed refresh (rowcount 0) means another worker stole the lease
after expiry — at that point the damage is already possible, so we
log and keep ticking. Don't escalate to thread death; the main task
body needs to keep running so its outcome is at least *recorded*.
"""
while not stop_event.wait(interval):
try:
with db_session() as conn:
still_owned = IdempotencyRepository(conn).refresh_lease(
key=key, owner_id=owner_id, ttl_seconds=LEASE_TTL_SECONDS,
)
if not still_owned:
logger.warning(
"idempotency lease lost mid-task for key=%s "
"(another worker may have taken over)",
key,
)
except Exception:
logger.exception(
"idempotency lease-heartbeat tick failed for key=%s", key,
)
def _safe_task_id(task_self: Any) -> str:
"""Best-effort extraction of ``self.request.id`` from a Celery task."""
try:
request = getattr(task_self, "request", None)
task_id: Optional[str] = (
getattr(request, "id", None) if request is not None else None
)
except Exception:
task_id = None
return task_id or "unknown"

View File

@@ -1,196 +0,0 @@
"""Reconciler tick: sweep stuck rows and escalate to terminal status + alert."""
from __future__ import annotations
import logging
import uuid
from typing import Any, Dict, Optional
from sqlalchemy import Connection
from application.api.user.idempotency import MAX_TASK_ATTEMPTS
from application.core.settings import settings
from application.storage.db.engine import get_engine
from application.storage.db.repositories.reconciliation import (
ReconciliationRepository,
)
from application.storage.db.repositories.stack_logs import StackLogsRepository
logger = logging.getLogger(__name__)
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3
def run_reconciliation() -> Dict[str, Any]:
"""Single tick of the reconciler. Five sweeps, FOR UPDATE SKIP LOCKED.
Stuck ``executed`` tool calls always flip to ``failed`` — operators
handle cleanup manually via the structured alert. The side effect is
assumed to have committed; no automated rollback is attempted.
Stuck ``task_dedup`` rows (lease expired AND attempts >= max)
promote to ``failed`` so a same-key retry can re-claim instead of
sitting in ``pending`` until 24 h TTL.
"""
if not settings.POSTGRES_URI:
return {
"messages_failed": 0,
"tool_calls_failed": 0,
"skipped": "POSTGRES_URI not set",
}
engine = get_engine()
summary = {
"messages_failed": 0,
"tool_calls_failed": 0,
"ingests_stalled": 0,
"idempotency_pending_failed": 0,
}
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for msg in repo.find_and_lock_stuck_messages():
new_count = repo.increment_message_reconcile_attempts(msg["id"])
if new_count >= MAX_MESSAGE_RECONCILE_ATTEMPTS:
repo.mark_message_failed(
msg["id"],
error=(
"reconciler: stuck in pending/streaming for >5 min "
f"after {new_count} attempts"
),
)
summary["messages_failed"] += 1
_emit_alert(
conn,
name="reconciler_message_failed",
user_id=msg.get("user_id"),
detail={
"message_id": str(msg["id"]),
"attempts": new_count,
},
)
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_proposed_tool_calls():
repo.mark_tool_call_failed(
row["call_id"],
error=(
"reconciler: stuck in 'proposed' for >5 min; "
"side effect status unknown"
),
)
summary["tool_calls_failed"] += 1
_emit_alert(
conn,
name="reconciler_tool_call_failed_proposed",
user_id=None,
detail={
"call_id": row["call_id"],
"tool_name": row.get("tool_name"),
},
)
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_executed_tool_calls():
repo.mark_tool_call_failed(
row["call_id"],
error=(
"reconciler: executed-not-confirmed; side effect "
"assumed committed, manual cleanup required"
),
)
summary["tool_calls_failed"] += 1
_emit_alert(
conn,
name="reconciler_tool_call_failed_executed",
user_id=None,
detail={
"call_id": row["call_id"],
"tool_name": row.get("tool_name"),
"action_name": row.get("action_name"),
},
)
# Q4: ingest checkpoints whose heartbeat has gone silent. The
# reconciler only escalates (alerts) — it doesn't kill the worker
# or roll back the partial embed. The next dispatch resumes from
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
# observability sweep, not a recovery action.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_stalled_ingests():
summary["ingests_stalled"] += 1
_emit_alert(
conn,
name="reconciler_ingest_stalled",
user_id=None,
detail={
"source_id": str(row.get("source_id")),
"embedded_chunks": row.get("embedded_chunks"),
"total_chunks": row.get("total_chunks"),
"last_updated": str(row.get("last_updated")),
},
)
# Bump the heartbeat so we don't re-alert every tick.
repo.touch_ingest_progress(str(row["source_id"]))
# Q5: idempotency rows whose lease expired with attempts exhausted.
# The wrapper's poison-loop guard normally finalises these, but if
# the wrapper itself died mid-task (worker SIGKILL, OOM during
# heartbeat) the row sits in ``pending`` blocking same-key retries
# via ``_lookup_completed`` returning None for the whole 24 h TTL.
# Promote to ``failed`` so a retry can re-claim and either resume
# or fail loudly.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_stuck_idempotency_pending(
max_attempts=MAX_TASK_ATTEMPTS,
):
error_msg = (
"reconciler: idempotency lease expired with attempts "
f"({row['attempt_count']}) >= {MAX_TASK_ATTEMPTS}; "
"task abandoned"
)
repo.mark_idempotency_pending_failed(
row["idempotency_key"], error=error_msg,
)
summary["idempotency_pending_failed"] += 1
_emit_alert(
conn,
name="reconciler_idempotency_pending_failed",
user_id=None,
detail={
"idempotency_key": row["idempotency_key"],
"task_name": row.get("task_name"),
"task_id": row.get("task_id"),
"attempts": row.get("attempt_count"),
},
)
return summary
def _emit_alert(
conn: Connection,
*,
name: str,
user_id: Optional[str],
detail: Dict[str, Any],
) -> None:
"""Structured ``logger.error`` plus a ``stack_logs`` row for operators."""
extra = {"alert": name, **detail}
logger.error("reconciler alert: %s", name, extra=extra)
try:
StackLogsRepository(conn).insert(
activity_id=str(uuid.uuid4()),
endpoint="reconciliation_worker",
level="ERROR",
user_id=user_id,
query=name,
stacks=[extra],
)
except Exception:
logger.exception("reconciler: failed to write stack_logs row for %s", name)

View File

@@ -3,20 +3,16 @@
import json
import os
import tempfile
import uuid
import zipfile
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.storage.db.source_ids import derive_source_id as _derive_source_id
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.db.repositories.idempotency import IdempotencyRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
@@ -34,91 +30,6 @@ sources_upload_ns = Namespace(
)
_IDEMPOTENCY_KEY_MAX_LEN = 256
def _read_idempotency_key():
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
key = request.headers.get("Idempotency-Key")
if not key:
return None, None
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
return None, make_response(
jsonify(
{
"success": False,
"message": (
f"Idempotency-Key exceeds maximum length of "
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
),
}
),
400,
)
return key, None
def _scoped_idempotency_key(idempotency_key, scope):
"""``{scope}:{key}`` so different users can't collide on the same key."""
if not idempotency_key or not scope:
return None
return f"{scope}:{idempotency_key}"
def _claim_task_or_get_cached(key, task_name):
"""Claim ``key`` for this request OR return the winner's cached payload.
Pre-generates the celery task_id so a losing writer sees the same
id immediately. Returns ``(task_id, cached_response)``; non-None
cached means the caller should return without enqueuing. The
cached payload mirrors the fresh-request response shape (including
``source_id``) so the frontend can correlate SSE ingest events to
the cached upload task without an extra round-trip — but only when
the cached row actually exists; the "deduplicated" sentinel
deliberately omits ``source_id`` so the frontend doesn't bind to a
phantom source.
"""
predetermined_id = str(uuid.uuid4())
with db_session() as conn:
claimed = IdempotencyRepository(conn).claim_task(
key=key, task_name=task_name, task_id=predetermined_id,
)
if claimed is not None:
return claimed["task_id"], None
with db_readonly() as conn:
existing = IdempotencyRepository(conn).get_task(key)
cached_id = existing.get("task_id") if existing else None
payload: dict = {
"success": True,
"task_id": cached_id or "deduplicated",
}
# Only surface ``source_id`` when there's a real winner whose worker
# is publishing SSE events tagged with that id. The "deduplicated"
# branch means the lock row vanished — we have nothing to correlate.
if cached_id is not None:
payload["source_id"] = str(_derive_source_id(key))
return None, payload
def _release_claim(key):
"""Drop a pending claim so a client retry can re-claim it."""
try:
with db_session() as conn:
conn.execute(
sql_text(
"DELETE FROM task_dedup WHERE idempotency_key = :k "
"AND status = 'pending'"
),
{"k": key},
)
except Exception:
current_app.logger.exception(
"Failed to release task_dedup claim for key=%s", key,
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
@@ -138,38 +49,17 @@ class UploadFile(Resource):
)
)
@api.doc(
description=(
"Uploads a file to be vectorized and indexed. Honors an optional "
"``Idempotency-Key`` header: a repeat request with the same key "
"within 24h returns the original cached response without re-enqueuing."
),
description="Uploads a file to be vectorized and indexed",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
# User-scoped to avoid cross-user collisions; also feeds
# ``_derive_source_id`` so uuid5 stays user-disjoint.
scoped_key = _scoped_idempotency_key(idempotency_key, user)
# Claim before enqueue; the loser returns the winner's task_id.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "ingest",
)
if cached is not None:
return make_response(jsonify(cached), 200)
data = request.form
files = request.files.getlist("file")
required_fields = ["user", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields or not files or all(file.filename == "" for file in files):
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -179,6 +69,7 @@ class UploadFile(Resource):
),
400,
)
user = decoded_token.get("sub")
job_name = request.form["name"]
# Create safe versions for filesystem operations
@@ -249,37 +140,16 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. With an idempotency
# key we reuse the deterministic uuid5 (retried task lands on
# the same source row); without a key we fall back to uuid4.
# The worker is told to use this id verbatim — see
# ``ingest_worker(source_id=...)``.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
task = ingest.delay(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
ingest_kwargs = dict(
args=(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
),
kwargs={
"file_path": base_path,
"filename": dir_name,
"file_name_map": file_name_map,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
)
if predetermined_task_id is not None:
ingest_kwargs["task_id"] = predetermined_task_id
task = ingest.apply_async(**ingest_kwargs)
except AudioFileTooLargeError:
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -291,21 +161,8 @@ class UploadFile(Resource):
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
if scoped_key:
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
# Predetermined id matches the dedup-claim row; loser GET sees same.
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the worker as
# ``source_id``; the worker uses it verbatim for every SSE event,
# so the frontend can correlate inbound ``source.ingest.*`` to
# this upload regardless of whether an idempotency key was set.
response_payload: dict = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/remote")
@@ -325,50 +182,17 @@ class UploadRemote(Resource):
)
)
@api.doc(
description=(
"Uploads remote source for vectorization. Honors an optional "
"``Idempotency-Key`` header: a repeat request with the same key "
"within 24h returns the original cached response without re-enqueuing."
),
description="Uploads remote source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
scoped_key = _scoped_idempotency_key(idempotency_key, user)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
task_name_for_dedup = (
"ingest_connector_task"
if data.get("source") in ConnectorCreator.get_supported_connectors()
else "ingest_remote"
)
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, task_name_for_dedup,
)
if cached is not None:
return make_response(jsonify(cached), 200)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. Same pattern as
# ``UploadFile.post``: with an idempotency key we reuse the
# deterministic uuid5 (retried task lands on the same source
# row); without a key we fall back to uuid4. The worker is told
# to use this id verbatim — see ``remote_worker`` and
# ``ingest_connector``. Without this the no-key path would mint
# a random uuid4 inside the worker that the frontend has no way
# to correlate SSE events to.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
try:
config = json.loads(data["data"])
source_data = None
@@ -384,8 +208,6 @@ class UploadRemote(Resource):
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{
@@ -414,62 +236,31 @@ class UploadRemote(Resource):
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
connector_kwargs = {
"kwargs": {
"job_name": data["name"],
"user": user,
"source_type": data["source"],
"session_token": session_token,
"file_ids": file_ids,
"folder_ids": folder_ids,
"recursive": config.get("recursive", False),
"retriever": config.get("retriever", "classic"),
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
connector_kwargs["task_id"] = predetermined_task_id
task = ingest_connector_task.apply_async(**connector_kwargs)
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the
# worker as ``source_id``; the worker uses it verbatim
# for every SSE event, so the frontend can correlate
# inbound ``source.ingest.*`` regardless of whether an
# idempotency key was set.
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
remote_kwargs = {
"kwargs": {
"source_data": source_data,
"job_name": data["name"],
"user": user,
"loader": data["source"],
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
remote_kwargs["task_id"] = predetermined_task_id
task = ingest_remote.apply_async(**remote_kwargs)
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
)
return make_response(
jsonify({"success": True, "task_id": task.id}), 200
)
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
)
except Exception as err:
current_app.logger.error(
f"Error uploading remote source: {err}", exc_info=True
)
if scoped_key:
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
response_task_id = predetermined_task_id or task.id
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
return make_response(jsonify(response_payload), 200)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/manage_source_files")
@@ -514,10 +305,6 @@ class ManageSourceFiles(Resource):
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
idempotency_key, key_error = _read_idempotency_key()
if key_error is not None:
return key_error
scoped_key = _scoped_idempotency_key(idempotency_key, user)
source_id = request.form.get("source_id")
operation = request.form.get("operation")
@@ -560,12 +347,6 @@ class ManageSourceFiles(Resource):
jsonify({"success": False, "message": "Database error"}), 500
)
resolved_source_id = str(source["id"])
# Flips to True after each branch's ``apply_async`` returns
# successfully — at that point the worker owns the predetermined
# task_id. The outer ``except`` only releases the claim while
# this is False, so a post-``apply_async`` failure (jsonify,
# make_response, etc.) doesn't double-enqueue on the next retry.
claim_transferred = False
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -598,34 +379,6 @@ class ManageSourceFiles(Resource):
),
400,
)
# Claim before any storage mutation so a duplicate request
# short-circuits without touching the filesystem. Mirrors
# the pattern in ``UploadFile.post`` / ``UploadRemote.post``
# — without it ``.delay()`` would enqueue twice for two
# racing same-key POSTs (the worker decorator only
# deduplicates *after* completion).
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
# Frontend keys reingest polling on
# ``reingest_task_id``; the shared cache helper
# writes ``task_id``. Alias here so a dedup
# response doesn't silently break FileTree's
# poller. Override ``source_id`` too — the
# helper derives it from the scoped key, which
# is correct for upload but wrong for reingest
# (the worker publishes events scoped to the
# actual source row id).
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
added_files = []
map_updated = False
@@ -661,15 +414,9 @@ class ManageSourceFiles(Resource):
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
claim_transferred = True
return make_response(
jsonify(
@@ -679,12 +426,6 @@ class ManageSourceFiles(Resource):
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
# ``source_id`` lets the frontend correlate
# inbound ``source.ingest.*`` SSE events
# (emitted by ``reingest_source_worker``)
# back to the reingest task — matches the
# upload route's source-id contract.
"source_id": resolved_source_id,
}
),
200,
@@ -714,8 +455,10 @@ class ManageSourceFiles(Resource):
),
400,
)
# Path-traversal guard runs *before* the claim so a 400
# for an invalid path doesn't leave a pending dedup row.
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
@@ -727,31 +470,6 @@ class ManageSourceFiles(Resource):
),
400,
)
# Claim before any storage mutation. See ``add`` branch
# comment for rationale.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Override the helper's synthetic source_id (uuid5
# of the scoped key) with the real source row id
# — the reingest worker publishes SSE events
# scoped to ``resolved_source_id`` and FileTree
# correlates on it.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
@@ -773,15 +491,9 @@ class ManageSourceFiles(Resource):
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
claim_transferred = True
return make_response(
jsonify(
@@ -790,7 +502,6 @@ class ManageSourceFiles(Resource):
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
@@ -841,24 +552,6 @@ class ManageSourceFiles(Resource):
),
404,
)
# Claim before mutation. See ``add`` branch for rationale.
predetermined_task_id = None
if scoped_key:
predetermined_task_id, cached = _claim_task_or_get_cached(
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Same source_id override as the ``remove`` /
# ``add`` cached branches — the helper's synthetic
# id doesn't match what reingest_source_worker
# tags its SSE events with.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
success = storage.remove_directory(full_directory_path)
if not success:
@@ -867,11 +560,6 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
# Release so a client retry can reclaim — otherwise
# the next request would silently 200-cache to the
# task_id that never enqueued.
if scoped_key:
_release_claim(scoped_key)
return make_response(
jsonify(
{"success": False, "message": "Failed to remove directory"}
@@ -903,15 +591,9 @@ class ManageSourceFiles(Resource):
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.apply_async(
kwargs={
"source_id": resolved_source_id,
"user": user,
"idempotency_key": scoped_key or idempotency_key,
},
task_id=predetermined_task_id,
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
claim_transferred = True
return make_response(
jsonify(
@@ -920,20 +602,11 @@ class ManageSourceFiles(Resource):
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
)
except Exception as err:
# Release the dedup claim only if it wasn't transferred to
# a worker. Without this, a same-key retry within the 24h
# TTL would 200-cache to a predetermined task_id whose
# ``apply_async`` never ran (or ran but the response builder
# blew up afterward — only the first case matters in
# practice; the flag protects both).
if scoped_key and not claim_transferred:
_release_claim(scoped_key)
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
directory_path = request.form.get("directory_path", "")

View File

@@ -1,45 +1,21 @@
from datetime import timedelta
from application.api.user.idempotency import with_idempotency
from application.celery_init import celery
from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
# Shared decorator config for long-running, side-effecting tasks. ``acks_late``
# is also the celeryconfig default but stays explicit here so each task's
# durability story is grep-able next to the body. Combined with
# ``autoretry_for=(Exception,)`` and a bounded ``max_retries`` so a poison
# message can't loop forever.
DURABLE_TASK = dict(
bind=True,
acks_late=True,
autoretry_for=(Exception,),
retry_kwargs={"max_retries": 3, "countdown": 60},
retry_backoff=True,
)
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest")
@celery.task(bind=True)
def ingest(
self,
directory,
formats,
job_name,
user,
file_path,
filename,
file_name_map=None,
idempotency_key=None,
source_id=None,
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
@@ -50,40 +26,25 @@ def ingest(
filename,
user,
file_name_map=file_name_map,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote")
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
):
resp = remote_worker(
self, source_data, job_name, user, loader,
idempotency_key=idempotency_key,
source_id=source_id,
)
@celery.task(bind=True)
def ingest_remote(self, source_data, job_name, user, loader):
resp = remote_worker(self, source_data, job_name, user, loader)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="reingest_source_task")
def reingest_source_task(self, source_id, user, idempotency_key=None):
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
# Beat-driven dispatch tasks default to ``acks_late=False``: a SIGKILL
# of a beat tick is harmless to redeliver only if the dispatch itself is
# idempotent. We keep these early-ACK so the broker doesn't replay a
# dispatch that already enqueued downstream work.
@celery.task(bind=True, acks_late=False)
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
return resp
@@ -113,22 +74,19 @@ def sync_source(
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="store_attachment")
def store_attachment(self, file_info, user, idempotency_key=None):
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="process_agent_webhook")
def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
@celery.task(bind=True)
def process_agent_webhook(self, agent_id, payload):
resp = agent_webhook_worker(self, agent_id, payload)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_connector_task")
@celery.task(bind=True)
def ingest_connector_task(
self,
job_name,
@@ -142,8 +100,6 @@ def ingest_connector_task(
operation_mode="upload",
doc_id=None,
sync_frequency="never",
idempotency_key=None,
source_id=None,
):
from application.worker import ingest_connector
@@ -160,8 +116,6 @@ def ingest_connector_task(
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@@ -186,33 +140,11 @@ def setup_periodic_tasks(sender, **kwargs):
cleanup_pending_tool_state.s(),
name="cleanup-pending-tool-state",
)
# Pure housekeeping for ``task_dedup`` / ``webhook_dedup`` — the
# upsert paths already handle stale rows, so cadence only bounds
# table size. Hourly is plenty for typical traffic.
sender.add_periodic_task(
timedelta(hours=1),
cleanup_idempotency_dedup.s(),
name="cleanup-idempotency-dedup",
)
sender.add_periodic_task(
timedelta(seconds=30),
reconciliation_task.s(),
name="reconciliation",
)
sender.add_periodic_task(
timedelta(hours=7),
version_check_task.s(),
name="version-check",
)
# Bound ``message_events`` growth — every streamed SSE chunk writes
# one row, so retained chats accumulate hundreds of rows per
# message. Reconnect-replay is only meaningful for streams the user
# could plausibly still be waiting on, so 14 days is generous.
sender.add_periodic_task(
timedelta(hours=24),
cleanup_message_events.s(),
name="cleanup-message-events",
)
@celery.task(bind=True)
@@ -221,12 +153,24 @@ def mcp_oauth_task(self, config, user):
return resp
@celery.task(bind=True, acks_late=False)
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
@celery.task(bind=True)
def cleanup_pending_tool_state(self):
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
"""Delete pending_tool_state rows past their TTL.
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
no native TTL, so this task runs every 60 seconds to keep
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
configured (keeps the task runnable in Mongo-only environments).
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "reverted": 0, "skipped": "POSTGRES_URI not set"}
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.pending_tool_state import (
@@ -235,73 +179,11 @@ def cleanup_pending_tool_state(self):
engine = get_engine()
with engine.begin() as conn:
repo = PendingToolStateRepository(conn)
reverted = repo.revert_stale_resuming(grace_seconds=600)
deleted = repo.cleanup_expired()
return {"deleted": deleted, "reverted": reverted}
deleted = PendingToolStateRepository(conn).cleanup_expired()
return {"deleted": deleted}
@celery.task(bind=True, acks_late=False)
def cleanup_idempotency_dedup(self):
"""Delete TTL-expired rows from ``task_dedup`` and ``webhook_dedup``.
Pure housekeeping — the upsert paths already ignore stale rows
(TTL-aware ``ON CONFLICT DO UPDATE``), so this only bounds table
growth and keeps SELECT planning tight on large deployments.
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {
"task_dedup_deleted": 0,
"webhook_dedup_deleted": 0,
"skipped": "POSTGRES_URI not set",
}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.idempotency import (
IdempotencyRepository,
)
engine = get_engine()
with engine.begin() as conn:
return IdempotencyRepository(conn).cleanup_expired()
@celery.task(bind=True, acks_late=False)
def reconciliation_task(self):
"""Sweep stuck durability rows and escalate them to terminal status + alert."""
from application.api.user.reconciliation import run_reconciliation
return run_reconciliation()
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
"""Delete ``message_events`` rows older than the retention window.
Streamed answer responses write one journal row per SSE yield,
so unbounded growth would dominate Postgres for any retained-
conversations deployment. The reconnect-replay path only needs
rows for in-flight streams; 14 days covers paused/tool-action
flows comfortably.
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
engine = get_engine()
with engine.begin() as conn:
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
return {"deleted": deleted, "ttl_days": ttl_days}
@celery.task(bind=True, acks_late=False)
@celery.task(bind=True)
def version_check_task(self):
"""Periodic anonymous version check.

View File

@@ -1,5 +1,6 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from flask import current_app, jsonify, make_response, redirect, request
@@ -225,9 +226,7 @@ class MCPServerSave(Resource):
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(
config["oauth_task_id"], user
)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
@@ -439,6 +438,56 @@ class MCPOAuthCallback(Resource):
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(

View File

@@ -9,7 +9,6 @@ import json
import logging
import time
import traceback
from datetime import datetime
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
@@ -222,26 +221,13 @@ def _stream_response(
for line in internal_stream:
if not line.strip():
continue
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
# before the ``data:`` line. Extract just the data line so JSON
# decode doesn't choke on the SSE framing.
event_str = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_str = raw[len("data:") :].lstrip()
break
if not event_str:
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Skip the informational ``message_id`` event — it has no v1 /
# OpenAI-compatible analog.
if event_data.get("type") == "message_id":
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")
@@ -320,16 +306,7 @@ def list_models():
401,
)
# Repository rows now go through ``coerce_pg_native`` at SELECT
# time, so timestamps arrive as ISO 8601 strings. Parse before
# taking ``.timestamp()``; fall back to ``time.time()`` only when
# the value is genuinely missing or unparseable.
created = agent.get("created_at") or agent.get("createdAt")
if isinstance(created, str):
try:
created = datetime.fromisoformat(created)
except (ValueError, TypeError):
created = None
created_ts = (
int(created.timestamp()) if hasattr(created, "timestamp")
else int(time.time())

View File

@@ -9,15 +9,12 @@ from jose import jwt
from application.auth import handle_auth
from application.core import log_context
from application.core.logging_config import setup_logging
setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.answer.routes.messages import messages_bp # noqa: E402
from application.api.events.routes import events # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
@@ -51,8 +48,6 @@ ensure_database_ready(
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(events)
app.register_blueprint(messages_bp)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
@@ -117,38 +112,6 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
@app.before_request
def _bind_log_context():
"""Bind activity_id + endpoint for the duration of this request.
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
follow-up handler once the JWT has been decoded.
"""
if request.method == "OPTIONS":
return None
activity_id = str(uuid.uuid4())
request.activity_id = activity_id
token = log_context.bind(
activity_id=activity_id,
endpoint=request.endpoint,
)
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
return None
@app.teardown_request
def _reset_log_context(_exc):
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
# request inside ``copy_context().run(...)``, so this reset doesn't
# leak into the stream's view of the context.
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
if token is not None:
log_context.reset(token)
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
@@ -185,28 +148,11 @@ def authenticate_request():
request.decoded_token = decoded_token
@app.before_request
def _bind_user_id_to_log_context():
# Registered after ``authenticate_request`` (Flask runs before_request
# handlers in registration order), so ``request.decoded_token`` is
# populated by the time we read it. ``teardown_request`` unwinds the
# whole request-level bind, so no separate reset token is needed here.
if request.method == "OPTIONS":
return None
decoded_token = getattr(request, "decoded_token", None)
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
if user_id:
log_context.bind(user_id=user_id)
return None
@app.after_request
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = (
"Content-Type, Authorization, Idempotency-Key"
)
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
return response

View File

@@ -25,12 +25,7 @@ asgi_app = Starlette(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=[
"Content-Type",
"Authorization",
"Mcp-Session-Id",
"Idempotency-Key",
],
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
expose_headers=["Mcp-Session-Id"],
),
],

View File

@@ -1,4 +1,3 @@
import hashlib
import json
import logging
import time
@@ -11,14 +10,6 @@ from application.utils import get_hash
logger = logging.getLogger(__name__)
def _cache_default(value):
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
# hash so the cache key stays bounded in size and stable across identical content.
if isinstance(value, (bytes, bytearray, memoryview)):
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
return repr(value)
_redis_instance = None
_redis_creation_failed = False
_instance_lock = Lock()
@@ -29,17 +20,8 @@ def get_redis_instance():
with _instance_lock:
if _redis_instance is None and not _redis_creation_failed:
try:
# ``health_check_interval`` makes redis-py ping the
# connection every N seconds when otherwise idle.
# Without it, a half-open TCP (NAT silently dropped
# state, ELB idle-close) can hang the SSE generator
# in ``pubsub.get_message`` past its keepalive
# cadence — the kernel never surfaces the dead
# socket because no payload is in flight.
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL,
socket_connect_timeout=2,
health_check_interval=10,
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except ValueError as e:
logger.error(f"Invalid Redis URL: {e}")
@@ -54,7 +36,7 @@ def get_redis_instance():
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(messages, default=_cache_default)
messages_str = json.dumps(messages)
tools_str = json.dumps(str(tools)) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)

View File

@@ -1,17 +1,8 @@
import inspect
import logging
import threading
from celery import Celery
from application.core import log_context
from application.core.settings import settings
from celery.signals import (
setup_logging,
task_postrun,
task_prerun,
worker_process_init,
worker_ready,
)
from celery.signals import setup_logging, worker_process_init, worker_ready
def make_celery(app_name=__name__):
@@ -50,54 +41,6 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
dispose_engine()
# Most tasks in this repo accept ``user`` where the log context wants
# ``user_id``; map task parameter names to context keys explicitly.
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
"user": "user_id",
"user_id": "user_id",
"agent_id": "agent_id",
"conversation_id": "conversation_id",
}
_task_log_tokens: dict[str, object] = {}
@task_prerun.connect
def _bind_task_log_context(task_id, task, args, kwargs, **_):
# Resolve task args by parameter name — nearly every task in this repo
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
ctx = {"activity_id": task_id}
try:
sig = inspect.signature(task.run)
bound = sig.bind_partial(*args, **kwargs).arguments
except (TypeError, ValueError):
bound = dict(kwargs)
for param_name, value in bound.items():
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
if ctx_key and value:
ctx[ctx_key] = value
_task_log_tokens[task_id] = log_context.bind(**ctx)
@task_postrun.connect
def _unbind_task_log_context(task_id, **_):
# ``task_postrun`` fires on both success and failure. Required for
# Celery: unlike the Flask path, tasks aren't isolated in their own
# ``copy_context().run(...)``, so a missing reset would leak the
# bind onto the next task on the same worker.
token = _task_log_tokens.pop(task_id, None)
if token is None:
return
try:
log_context.reset(token)
except ValueError:
# task_prerun and task_postrun ran on different threads (non-default
# Celery pool); the token isn't valid in this context. Drop it.
logging.getLogger(__name__).debug(
"log_context reset skipped for task %s", task_id
)
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.

View File

@@ -1,10 +1,7 @@
from application.core.settings import settings
import os
# Pydantic loads .env into ``settings`` but does not inject values into
# ``os.environ`` — read directly from settings so beat startup (which
# imports this module before any explicit env load) sees a real URL.
broker_url = settings.CELERY_BROKER_URL
result_backend = settings.CELERY_RESULT_BACKEND
broker_url = os.getenv("CELERY_BROKER_URL")
result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
@@ -13,21 +10,7 @@ accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)
# Project-scoped queue so a stray sibling worker on the same broker
# (other repo, same default ``celery`` queue) can't grab DocsGPT tasks.
task_default_queue = "docsgpt"
task_default_exchange = "docsgpt"
task_default_routing_key = "docsgpt"
beat_scheduler = "redbeat.RedBeatScheduler"
redbeat_redis_url = broker_url
redbeat_key_prefix = "redbeat:docsgpt:"
redbeat_lock_timeout = 90
# Survive worker SIGKILL/OOM without silently dropping in-flight tasks.
task_acks_late = True
task_reject_on_worker_lost = True
worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
result_expires = 86400 * 7
task_track_started = True

View File

@@ -1,57 +0,0 @@
"""Per-activity logging context backed by ``contextvars``.
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
they land as first-class attributes on the OTLP log export rather than being
buried inside formatted message bodies.
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
via the token returned by ``bind``.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Mapping
_CTX_KEYS: frozenset[str] = frozenset(
{
"activity_id",
"parent_activity_id",
"user_id",
"agent_id",
"conversation_id",
"endpoint",
"model",
}
)
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
def bind(**kwargs: object) -> Token:
"""Overlay the given keys onto the current context.
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
stamp a stray field name onto every record), as are ``None`` values
(a missing attribute is more useful than the literal string ``"None"``).
"""
overlay = {
k: str(v)
for k, v in kwargs.items()
if k in _CTX_KEYS and v is not None
}
new = {**_ctx.get(), **overlay}
return _ctx.set(new)
def reset(token: Token) -> None:
"""Restore the context to the snapshot captured by the matching ``bind``."""
_ctx.reset(token)
def snapshot() -> Mapping[str, str]:
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
return _ctx.get()

View File

@@ -2,36 +2,6 @@ import logging
import os
from logging.config import dictConfig
from application.core.log_context import snapshot as _ctx_snapshot
# Loggers with ``propagate=False`` don't share root's handlers, so the
# context filter has to be installed on their handlers directly.
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
"uvicorn",
"uvicorn.access",
"uvicorn.error",
"celery.app.trace",
"celery.worker.strategy",
"gunicorn.error",
"gunicorn.access",
)
class _ContextFilter(logging.Filter):
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
Must be installed on **handlers**, not loggers: Python skips logger-level
filters when a child logger's record propagates up. The ``hasattr`` guard
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
"""
def filter(self, record: logging.LogRecord) -> bool:
for key, value in _ctx_snapshot().items():
if not hasattr(record, key):
setattr(record, key, value)
return True
def _otlp_logs_enabled() -> bool:
"""Return True when the user has opted in to OTLP log export.
@@ -90,23 +60,3 @@ def setup_logging() -> None:
for handler in preserved_handlers:
if handler not in root.handlers:
root.addHandler(handler)
_install_context_filter()
def _install_context_filter() -> None:
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
the known non-propagating loggers. Skipping handlers that already carry
one keeps repeat ``setup_logging`` calls from stacking filters.
"""
def _has_ctx_filter(handler: logging.Handler) -> bool:
return any(isinstance(f, _ContextFilter) for f in handler.filters)
for handler in logging.getLogger().handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
for name in _NON_PROPAGATING_LOGGERS:
for handler in logging.getLogger(name).handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())

View File

@@ -30,12 +30,6 @@ class Settings(BaseSettings):
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
# Prefetch=1 caps SIGKILL loss to one task. Visibility timeout must exceed
# the longest legitimate task runtime (ingest, agent webhook) but stay
# short enough that SIGKILLed tasks redeliver promptly. 1h matches Onyx
# and Dify defaults; long ingests can override via env.
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
CELERY_VISIBILITY_TIMEOUT: int = 3600
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
MONGO_URI: Optional[str] = None
# User-data Postgres DB.
@@ -188,42 +182,6 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
# Internal SSE push channel (notifications + durable replay journal)
# Master switch — when False, /api/events emits a "push_disabled" comment
# and returns; clients fall back to polling. Publisher becomes a no-op.
ENABLE_SSE_PUSH: bool = True
# Per-user durable backlog cap (~entries). At typical event rates this
# gives ~24h of replay; tune up for verbose feeds, down for memory.
EVENTS_STREAM_MAXLEN: int = 1000
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
# close and iOS Safari's ~60s — 15s gives generous headroom.
SSE_KEEPALIVE_SECONDS: int = 15
# Cap on simultaneous SSE connections per user. Each connection holds
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
# connection. 8 covers normal multi-tab use without letting one user
# starve the pool. Set to 0 to disable the cap.
SSE_MAX_CONCURRENT_PER_USER: int = 8
# Per-request cap on the number of backlog entries XRANGE returns
# for ``/api/events`` snapshots. Bounds the bytes a single replay
# can move from Redis to the wire — a malicious client looping
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
# many entries per round-trip. Combined with the per-user
# connection cap above and the windowed budget below, total
# enumeration throughput is bounded.
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
# Sliding-window cap on snapshot replays per user. Once the budget
# is exhausted the route returns HTTP 429 with the cursor pinned;
# the client backs off and retries after the window rolls over.
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
# beat task deletes rows older than this. Reconnect-replay only
# needs the journal for streams a client could still be tailing,
# so 14 days is a generous default that covers paused/tool-action
# flows without unbounded table growth.
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):

View File

@@ -1,52 +0,0 @@
"""Stream/topic key derivations shared by publisher and SSE consumer.
Single source of truth for the per-user Redis Streams key and pub/sub
topic name. Both must agree exactly — a typo here splits the
publisher's writes from the consumer's reads.
"""
from __future__ import annotations
def stream_key(user_id: str) -> str:
"""Redis Streams key holding the durable backlog for ``user_id``."""
return f"user:{user_id}:stream"
def topic_name(user_id: str) -> str:
"""Redis pub/sub channel used for live fan-out to ``user_id``."""
return f"user:{user_id}"
def connection_counter_key(user_id: str) -> str:
"""Redis counter tracking active SSE connections for ``user_id``."""
return f"user:{user_id}:sse_count"
def replay_budget_key(user_id: str) -> str:
"""Redis counter tracking snapshot replays for ``user_id`` in the
rolling rate-limit window."""
return f"user:{user_id}:replay_count"
def stream_id_compare(a: str, b: str) -> int:
"""Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.
Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
once ``ms`` straddles digit-count boundaries. We parse and compare
as ``(int, int)`` tuples.
Raises ``ValueError`` on malformed input. Callers must pre-validate
against ``_STREAM_ID_RE`` (or equivalent) — a lex fallback here let
a malformed id compare lex-greater than a real one and silently pin
dedup forever.
"""
a_ms, _, a_seq = a.partition("-")
b_ms, _, b_seq = b.partition("-")
a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
if a_tuple < b_tuple:
return -1
if a_tuple > b_tuple:
return 1
return 0

View File

@@ -1,144 +0,0 @@
"""User-scoped event publisher: durable backlog + live fan-out.
Each ``publish_user_event`` call writes twice:
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
rates) so the per-user footprint stays predictable.
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
currently connected SSE generator for the user, across instances.
Together they give a snapshot-plus-tail story: a reconnecting client
reads ``XRANGE`` from its last seen id and then transitions onto the
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
the canonical, monotonically increasing event id and is what
``Last-Event-ID`` carries.
Failures are logged and swallowed: the caller is typically a Celery
task whose primary work has already succeeded, and a notification
delivery miss should not surface as a task failure.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import stream_key, topic_name
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
def _iso_now() -> str:
"""ISO 8601 UTC with millisecond precision and Z suffix."""
return (
datetime.now(timezone.utc)
.isoformat(timespec="milliseconds")
.replace("+00:00", "Z")
)
def publish_user_event(
user_id: str,
event_type: str,
payload: dict[str, Any],
*,
scope: Optional[dict[str, Any]] = None,
) -> Optional[str]:
"""Publish a user-scoped event; return the Redis Streams id or ``None``.
Fire-and-forget: never raises. ``None`` means the event reached
neither the journal nor live subscribers (see runbook for causes).
"""
if not user_id or not event_type:
logger.warning(
"publish_user_event called without user_id or event_type "
"(user_id=%r, event_type=%r)",
user_id,
event_type,
)
return None
if not settings.ENABLE_SSE_PUSH:
return None
envelope_partial: dict[str, Any] = {
"type": event_type,
"ts": _iso_now(),
"user_id": user_id,
"topic": topic_name(user_id),
"scope": scope or {},
"payload": payload,
}
try:
envelope_partial_json = json.dumps(envelope_partial)
except (TypeError, ValueError) as exc:
logger.warning(
"publish_user_event payload not JSON-serializable: "
"user=%s type=%s err=%s",
user_id,
event_type,
exc,
)
return None
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; skipping publish_user_event")
return None
maxlen = settings.EVENTS_STREAM_MAXLEN
stream_id: Optional[str] = None
try:
# Auto-id ('*') gives a monotonic ms-seq id that doubles as the
# SSE event id. ``approximate=True`` lets Redis trim in chunks
# for performance; the cap is treated as ~MAXLEN, never <.
result = redis.xadd(
stream_key(user_id),
{"event": envelope_partial_json},
maxlen=maxlen,
approximate=True,
)
stream_id = (
result.decode("utf-8")
if isinstance(result, (bytes, bytearray))
else str(result)
)
except Exception:
logger.exception(
"xadd failed for user=%s event_type=%s", user_id, event_type
)
# If the durable journal write failed there is no canonical id to
# ship — publishing the envelope live would put an id-less record
# on the wire that bypasses the SSE route's dedup floor and breaks
# ``Last-Event-ID`` semantics for any reconnect. Best-effort
# delivery means dropping consistently, not delivering inconsistent
# state.
if stream_id is None:
return None
envelope = dict(envelope_partial)
envelope["id"] = stream_id
try:
Topic(topic_name(user_id)).publish(json.dumps(envelope))
except Exception:
logger.exception(
"publish failed for user=%s event_type=%s", user_id, event_type
)
logger.debug(
"event.published topic=%s type=%s id=%s",
topic_name(user_id),
event_type,
stream_id,
)
return stream_id

View File

@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
provider_name = "anthropic"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

View File

@@ -1,6 +1,5 @@
import logging
from abc import ABC, abstractmethod
from typing import ClassVar
from application.cache import gen_cache, stream_cache
@@ -11,10 +10,6 @@ logger = logging.getLogger(__name__)
class BaseLLM(ABC):
# Stamped onto the ``llm_stream_start`` event so dashboards can group
# calls by vendor. Subclasses override.
provider_name: ClassVar[str] = "unknown"
def __init__(
self,
decoded_token=None,
@@ -80,14 +75,6 @@ class BaseLLM(ABC):
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
# Tag the fallback LLM so its rows land as
# ``source='fallback'`` in cost-attribution dashboards.
# Propagate the parent's ``_request_id`` so a user
# request that ran fallback is still grouped under one id.
self._fallback_llm._token_usage_source = "fallback"
self._fallback_llm._request_id = getattr(
self, "_request_id", None,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
f"{provider}/{backup_model_id}"
@@ -114,11 +101,6 @@ class BaseLLM(ABC):
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
# Same rationale as the agent-backup branch.
self._fallback_llm._token_usage_source = "fallback"
self._fallback_llm._request_id = getattr(
self, "_request_id", None,
)
logger.info(
f"Fallback LLM initialized from global settings: "
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
@@ -136,26 +118,6 @@ class BaseLLM(ABC):
return args_dict
return {k: v for k, v in args_dict.items() if v is not None}
@staticmethod
def _is_non_retriable_client_error(exc: BaseException) -> bool:
"""4xx errors mean the request itself is malformed — retrying with
a different model fails identically and doubles the work. Only
transient/5xx/connection errors should trigger fallback."""
try:
from google.genai.errors import ClientError as _GenaiClientError
if isinstance(exc, _GenaiClientError):
return True
except ImportError:
pass
for attr in ("status_code", "code", "http_status"):
v = getattr(exc, attr, None)
if isinstance(v, int) and 400 <= v < 500:
return True
resp = getattr(exc, "response", None)
v = getattr(resp, "status_code", None)
return isinstance(v, int) and 400 <= v < 500
def _execute_with_fallback(
self, method_name: str, decorators: list, *args, **kwargs
):
@@ -179,18 +141,12 @@ class BaseLLM(ABC):
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, decorators, *args, **kwargs
decorated_method, method_name, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed with non-retriable client error; "
f"skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
@@ -200,27 +156,14 @@ class BaseLLM(ABC):
f"{fallback.model_id}. Error: {str(e)}"
)
# Apply decorators to fallback's raw method directly — calling
# fallback.gen() would re-enter the orchestrator and recurse via
# fallback.fallback_llm.
fallback_method = getattr(fallback, method_name)
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
try:
return fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed with non-retriable client "
f"error; giving up: {str(e2)}"
)
else:
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
raise
return fallback_method(*args, **fallback_kwargs)
def _stream_with_fallback(
self, decorated_method, method_name, decorators, *args, **kwargs
self, decorated_method, method_name, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
@@ -233,12 +176,6 @@ class BaseLLM(ABC):
try:
yield from decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed mid-stream with non-retriable client "
f"error; skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
@@ -249,37 +186,11 @@ class BaseLLM(ABC):
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
# Apply decorators to fallback's raw stream method directly —
# calling fallback.gen_stream() would re-enter the orchestrator
# and recurse via fallback.fallback_llm. Emit the stream-start
# event manually so dashboards still see the fallback's
# provider/model when the response actually comes from it.
fallback._emit_stream_start_log(
fallback.model_id,
kwargs.get("messages"),
kwargs.get("tools"),
bool(
kwargs.get("_usage_attachments")
or kwargs.get("attachments")
),
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_method = getattr(fallback, method_name)
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
try:
yield from fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed mid-stream with non-retriable "
f"client error; giving up: {str(e2)}"
)
else:
logger.error(
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
)
raise
yield from fallback_method(*args, **fallback_kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
@@ -294,58 +205,7 @@ class BaseLLM(ABC):
**kwargs,
)
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
# Stamped with ``self.provider_name`` so dashboards can group calls
# by vendor; the fallback path emits its own copy on the fallback
# instance so the actual responding provider is recorded.
logging.info(
"llm_stream_start",
extra={
"model": model,
"provider": self.provider_name,
"message_count": len(messages) if messages is not None else 0,
"has_attachments": bool(has_attachments),
"has_tools": bool(tools),
},
)
def _emit_stream_finished_log(
self,
model,
*,
prompt_tokens,
completion_tokens,
latency_ms,
cached_tokens=None,
error=None,
):
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
# by user/agent/provider. Token counts are client-side estimates
# from ``stream_token_usage``; vendor-reported counts (incl.
# ``cached_tokens`` for prompt caching) require per-provider
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
extra = {
"model": model,
"provider": self.provider_name,
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
"latency_ms": int(latency_ms),
"status": "error" if error is not None else "ok",
}
if cached_tokens is not None:
extra["cached_tokens"] = int(cached_tokens)
if error is not None:
extra["error_class"] = type(error).__name__
logging.info("llm_stream_finished", extra=extra)
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
# the ``stream_token_usage`` decorator pops that key, but the log
# fires before the decorator runs so it's still in ``kwargs`` here.
has_attachments = bool(
kwargs.get("_usage_attachments") or kwargs.get("attachments")
)
self._emit_stream_start_log(model, messages, tools, has_attachments)
decorators = [stream_cache, stream_token_usage]
return self._execute_with_fallback(
"_raw_gen_stream",

View File

@@ -6,8 +6,6 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(OpenAILLM):
provider_name = "docsgpt"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,

View File

@@ -6,13 +6,10 @@ from google.genai import types
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.handlers.google import _decode_thought_signature
from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
provider_name = "google"
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
@@ -82,39 +79,24 @@ class GoogleLLM(BaseLLM):
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type not in self.get_supported_attachment_types():
continue
try:
# Images go inline as bytes per Google's guidance for
# requests under 20MB; the Files API can return before
# the upload reaches ACTIVE state and yield an empty URI.
if mime_type.startswith("image/"):
file_bytes = self._read_attachment_bytes(attachment)
files.append(
{"file_bytes": file_bytes, "mime_type": mime_type}
)
else:
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
if not file_uri:
raise ValueError(
f"Google Files API returned empty URI for "
f"{attachment.get('path', 'unknown')}"
)
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
except Exception as e:
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
@@ -130,9 +112,7 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
# Truthy check, not membership: a poisoned cache row of "" or
# None must be treated as a miss and trigger a fresh upload.
if attachment.get("google_file_uri"):
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
@@ -146,10 +126,6 @@ class GoogleLLM(BaseLLM):
file=local_path
).uri,
)
if not file_uri:
raise ValueError(
f"Google Files API upload returned empty URI for {file_path}"
)
# Cache the Google file URI on the attachment row so we don't
# re-upload on the next LLM call. Accept either a PG UUID
@@ -183,26 +159,6 @@ class GoogleLLM(BaseLLM):
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _read_attachment_bytes(self, attachment):
"""
Read attachment bytes from storage for inline transmission.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
bytes: Raw file bytes.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return self.storage.process_file(
file_path,
lambda local_path, **kwargs: open(local_path, "rb").read(),
)
def _clean_messages_google(self, messages):
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
@@ -259,7 +215,7 @@ class GoogleLLM(BaseLLM):
except (_json.JSONDecodeError, TypeError):
args = {}
cleaned_args = self._remove_null_values(args)
thought_sig = _decode_thought_signature(tc.get("thought_signature"))
thought_sig = tc.get("thought_signature")
if thought_sig:
parts.append(
types.Part(
@@ -323,9 +279,7 @@ class GoogleLLM(BaseLLM):
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=_decode_thought_signature(
item["thought_signature"]
),
thoughtSignature=item["thought_signature"],
)
)
else:
@@ -344,24 +298,12 @@ class GoogleLLM(BaseLLM):
)
elif "files" in item:
for file_data in item["files"]:
if "file_bytes" in file_data:
parts.append(
types.Part.from_bytes(
data=file_data["file_bytes"],
mime_type=file_data["mime_type"],
)
)
elif file_data.get("file_uri"):
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
logging.warning(
"GoogleLLM: dropping file part with empty URI and no bytes"
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
@@ -599,6 +541,22 @@ class GoogleLLM(BaseLLM):
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
model=model,
contents=messages,

View File

@@ -5,8 +5,6 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(OpenAILLM):
provider_name = "groq"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,

View File

@@ -10,18 +10,6 @@ from application.logging import build_stack_data
logger = logging.getLogger(__name__)
# Cap the agent tool-call loop. Without this an LLM that keeps
# requesting more tool calls (preview models, sparse tool results,
# under-specified prompts) can chain searches indefinitely and the
# stream never finalises. 25 mirrors Dify's default.
MAX_TOOL_ITERATIONS = 25
_FINALIZE_INSTRUCTION = (
f"You have made {MAX_TOOL_ITERATIONS} tool calls. Provide a final "
"response to the user based on what you have, without making any "
"additional tool calls."
)
@dataclass
class ToolCall:
"""Represents a tool/function call from the LLM."""
@@ -292,26 +280,7 @@ class LLMHandler(ABC):
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
# Image attachments arrive with raw bytes / base64
# inline (see GoogleLLM.prepare_messages_with_attachments).
# ``str(item)`` would dump the whole byte/base64
# blob into the compression prompt and bust the
# compression LLM's input limit.
files = item.get("files") or []
descriptors = []
if isinstance(files, list):
for f in files:
if isinstance(f, dict):
descriptors.append(
f.get("mime_type") or "file"
)
elif isinstance(f, str):
descriptors.append(f)
if not descriptors:
descriptors = ["file"]
parts_text.append(
f"[attachment: {', '.join(descriptors)}]"
)
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
@@ -636,10 +605,6 @@ class LLMHandler(ABC):
agent_id=getattr(agent, "agent_id", None),
model_user_id=compression_user_id,
)
# Side-channel LLM tag — see ``orchestrator.py`` for rationale.
compression_llm._token_usage_source = "compression"
compression_llm._request_id = getattr(agent, "_request_id", None) \
or getattr(getattr(agent, "llm", None), "_request_id", None)
# Create service without DB persistence capability
compression_service = CompressionService(
@@ -950,9 +915,7 @@ class LLMHandler(ABC):
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
iteration = 0
while parsed.requires_tool_call:
iteration += 1
tool_handler_gen = self.handle_tool_calls(
agent, parsed.tool_calls, tools_dict, messages
)
@@ -976,25 +939,6 @@ class LLMHandler(ABC):
}
return ""
# Cap reached: force one final tool-less call so the stream
# always ends with content rather than cutting off.
if iteration >= MAX_TOOL_ITERATIONS:
logger.warning(
"agent tool loop hit cap (%d); forcing finalize",
MAX_TOOL_ITERATIONS,
)
messages.append(
{"role": "system", "content": _FINALIZE_INSTRUCTION},
)
response = agent.llm.gen(
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=None,
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
break
# ``agent.model_id`` is the registry id (a UUID for BYOM
# records). Use the LLM's own model_id, which LLMCreator
# already resolved to the upstream model name. Built-ins:
@@ -1010,12 +954,7 @@ class LLMHandler(ABC):
return parsed.content
def handle_streaming(
self,
agent,
response: Any,
tools_dict: Dict,
messages: List[Dict],
_iteration: int = 0,
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
) -> Generator:
"""
Handle streaming response flow.
@@ -1084,9 +1023,6 @@ class LLMHandler(ABC):
}
return
next_iteration = _iteration + 1
cap_reached = next_iteration >= MAX_TOOL_ITERATIONS
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
@@ -1099,32 +1035,16 @@ class LLMHandler(ABC):
)
})
logger.info("Context limit reached - instructing agent to wrap up")
elif cap_reached:
logger.warning(
"agent tool loop hit cap (%d); forcing finalize",
MAX_TOOL_ITERATIONS,
)
messages.append(
{"role": "system", "content": _FINALIZE_INSTRUCTION},
)
# See note above on agent.model_id vs llm.model_id.
response = agent.llm.gen_stream(
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=(
None
if cap_reached
or getattr(agent, "context_limit_reached", False)
else agent.tools
),
tools=agent.tools if not agent.context_limit_reached else None,
)
self.llm_calls.append(build_stack_data(agent.llm))
yield from self.handle_streaming(
agent, response, tools_dict, messages,
_iteration=next_iteration,
)
yield from self.handle_streaming(agent, response, tools_dict, messages)
return
if parsed.content:
buffer += parsed.content

View File

@@ -1,35 +1,9 @@
import base64
import binascii
import uuid
from typing import Any, Dict, Generator, Optional, Union
from typing import Any, Dict, Generator
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
def _encode_thought_signature(sig: Optional[Union[bytes, str]]) -> Optional[str]:
# Gemini's Python SDK returns thought_signature as raw bytes, but the
# field is typed Optional[str] downstream and gets json.dumps'd into
# SSE events. Encode once at ingress so callers only ever see a str.
if isinstance(sig, bytes):
return base64.b64encode(sig).decode("ascii")
return sig
def _decode_thought_signature(
sig: Optional[Union[bytes, str]],
) -> Optional[Union[bytes, str]]:
# Reverse of _encode_thought_signature — Gemini's SDK expects bytes
# back when we replay a tool call. ``validate=True`` keeps ASCII
# strings that happen to be loosely decodable from being silently
# turned into bytes; non-base64 inputs pass through unchanged.
if isinstance(sig, str):
try:
return base64.b64decode(sig.encode("ascii"), validate=True)
except (binascii.Error, ValueError):
return sig
return sig
class GoogleLLMHandler(LLMHandler):
"""Handler for Google's GenAI API."""
@@ -49,7 +23,7 @@ class GoogleLLMHandler(LLMHandler):
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = _encode_thought_signature(part.thought_signature) if has_sig else None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
@@ -76,7 +50,7 @@ class GoogleLLMHandler(LLMHandler):
tool_calls = []
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = _encode_thought_signature(response.thought_signature) if has_sig else None
thought_sig = response.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
@@ -96,15 +70,8 @@ class GoogleLLMHandler(LLMHandler):
"""Create a tool result message in the standard internal format."""
import json as _json
from application.storage.db.serialization import PGNativeJSONEncoder
# PostgresTool results commonly include PG-native types
# (datetime / UUID / Decimal / bytea) when SELECT touches
# timestamptz / numeric / uuid / bytea columns. The shared
# encoder handles all five — bytes get base64 (lossless) instead
# of the ``str(b'...')`` repr that ``default=str`` would emit.
content = (
_json.dumps(result, cls=PGNativeJSONEncoder)
_json.dumps(result)
if not isinstance(result, str)
else result
)

View File

@@ -40,15 +40,8 @@ class OpenAILLMHandler(LLMHandler):
"""Create a tool result message in the standard internal format."""
import json as _json
from application.storage.db.serialization import PGNativeJSONEncoder
# PostgresTool results commonly include PG-native types
# (datetime / UUID / Decimal / bytea) when SELECT touches
# timestamptz / numeric / uuid / bytea columns. The shared
# encoder handles all five — bytes get base64 (lossless) instead
# of the ``str(b'...')`` repr that ``default=str`` would emit.
content = (
_json.dumps(result, cls=PGNativeJSONEncoder)
_json.dumps(result)
if not isinstance(result, str)
else result
)

View File

@@ -26,8 +26,6 @@ class LlamaSingleton:
class LlamaCpp(BaseLLM):
provider_name = "llama_cpp"
def __init__(
self,
api_key=None,

View File

@@ -5,8 +5,6 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
class NovitaLLM(OpenAILLM):
provider_name = "novita"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,

View File

@@ -5,8 +5,6 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
provider_name = "openrouter"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,

View File

@@ -61,7 +61,6 @@ def _truncate_base64_for_logging(messages):
class OpenAILLM(BaseLLM):
provider_name = "openai"
def __init__(
self,

View File

@@ -3,7 +3,6 @@ from application.core.settings import settings
class PremAILLM(BaseLLM):
provider_name = "premai"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from premai import Prem

View File

@@ -59,7 +59,6 @@ class LineIterator:
class SagemakerAPILLM(BaseLLM):
provider_name = "sagemaker"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
import boto3

View File

@@ -1,13 +1,11 @@
import datetime
import functools
import inspect
import time
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core import log_context
from application.storage.db.repositories.stack_logs import StackLogsRepository
from application.storage.db.session import db_session
@@ -24,15 +22,6 @@ class LogContext:
self.api_key = api_key
self.query = query
self.stacks = []
# Per-activity response aggregates populated by ``_consume_and_log``
# while it forwards stream items, then flushed onto the
# ``activity_finished`` event so every Flask request gets the
# same summary that ``run_agent_logic`` used to log only for the
# Celery webhook path.
self.answer_length = 0
self.thought_length = 0
self.source_count = 0
self.tool_call_count = 0
def build_stack_data(
@@ -89,125 +78,25 @@ def log_activity() -> Callable:
user = data.get("user", "local")
api_key = data.get("user_api_key", "")
query = kwargs.get("query", getattr(args[0], "query", ""))
agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
conversation_id = (
kwargs.get("conversation_id")
or getattr(args[0], "conversation_id", None)
)
model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
# Capture the surrounding activity_id before overlaying ours,
# so nested activities record the parent → child link.
parent_activity_id = log_context.snapshot().get("activity_id")
context = LogContext(endpoint, activity_id, user, api_key, query)
kwargs["log_context"] = context
ctx_token = log_context.bind(
activity_id=activity_id,
parent_activity_id=parent_activity_id,
user_id=user,
agent_id=agent_id,
conversation_id=conversation_id,
endpoint=endpoint,
model=model,
)
started_at = time.monotonic()
logging.info(
"activity_started",
extra={
"activity_id": activity_id,
"parent_activity_id": parent_activity_id,
"user_id": user,
"agent_id": agent_id,
"conversation_id": conversation_id,
"endpoint": endpoint,
"model": model,
},
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
)
error: BaseException | None = None
try:
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
except Exception as exc:
# Only ``Exception`` counts as an activity error; ``GeneratorExit``
# (consumer disconnected mid-stream) and ``KeyboardInterrupt``
# flow through the finally as ``status="ok"``, matching
# ``_consume_and_log``.
error = exc
raise
finally:
_emit_activity_finished(
context=context,
parent_activity_id=parent_activity_id,
started_at=started_at,
error=error,
)
log_context.reset(ctx_token)
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
return wrapper
return decorator
def _emit_activity_finished(
*,
context: "LogContext",
parent_activity_id: str | None,
started_at: float,
error: BaseException | None,
) -> None:
"""Emit the paired ``activity_finished`` event with duration, outcome,
and per-activity response aggregates accumulated in ``_consume_and_log``.
"""
duration_ms = int((time.monotonic() - started_at) * 1000)
logging.info(
"activity_finished",
extra={
"activity_id": context.activity_id,
"parent_activity_id": parent_activity_id,
"user_id": context.user,
"endpoint": context.endpoint,
"duration_ms": duration_ms,
"status": "error" if error is not None else "ok",
"error_class": type(error).__name__ if error is not None else None,
"answer_length": context.answer_length,
"thought_length": context.thought_length,
"source_count": context.source_count,
"tool_call_count": context.tool_call_count,
},
)
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
"""Mirror the per-line aggregation that ``run_agent_logic`` did for the
Celery webhook path, but at the generator-consumption layer so every
``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
gets the same summary.
"""
if not isinstance(item, dict):
return
if "answer" in item:
context.answer_length += len(str(item["answer"]))
return
if "thought" in item:
context.thought_length += len(str(item["thought"]))
return
sources = item.get("sources") if "sources" in item else None
if isinstance(sources, list):
context.source_count += len(sources)
return
tool_calls = item.get("tool_calls") if "tool_calls" in item else None
if isinstance(tool_calls, list):
context.tool_call_count += len(tool_calls)
def _consume_and_log(generator: Generator, context: "LogContext"):
try:
for item in generator:
_accumulate_response_summary(item, context)
yield item
except Exception as e:
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")

View File

@@ -1,28 +1,12 @@
import os
import logging
from typing import Any, List, Optional
from typing import List, Any
from retry import retry
from tqdm import tqdm
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
from application.storage.db.session import db_session
from application.vectorstore.vector_creator import VectorCreator
class EmbeddingPipelineError(Exception):
"""Raised when the per-chunk embed loop produces a partial index.
Escapes into Celery's ``autoretry_for`` so a transient cause (rate
limit, network blip) gets another shot. The chunk-progress
checkpoint makes retries cheap — only the failed-and-after chunks
re-run. After ``MAX_TASK_ATTEMPTS`` the poison-loop guard in
``with_idempotency`` finalises the row as ``failed``.
"""
def sanitize_content(content: str) -> str:
"""
Remove NUL characters that can cause vector store ingestion to fail.
@@ -38,11 +22,7 @@ def sanitize_content(content: str) -> str:
return content.replace('\x00', '')
# Per-chunk inline retry. Aggressive defaults (tries=10, delay=60) blocked
# the loop for up to 9 min per chunk and wedged the heartbeat: lower the
# tail so a transient failure fails-fast and the chunk-progress checkpoint
# resumes cleanly on next dispatch.
@retry(tries=3, delay=5, backoff=2)
@retry(tries=10, delay=60)
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
"""Add a document's text and metadata to the vector store with retry logic.
@@ -65,124 +45,21 @@ def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
raise
def _init_progress_and_resume_index(
source_id: str, total_chunks: int, attempt_id: Optional[str],
) -> int:
"""Upsert the progress row and return the next chunk index to embed.
The repository's upsert preserves ``last_index`` only when the
incoming ``attempt_id`` matches the stored one (a Celery autoretry
of the same task). On a fresh attempt — including any caller that
doesn't pass an ``attempt_id``, e.g. legacy code or tests — the
row's checkpoint is reset so the loop starts from chunk 0. This
is what prevents a completed checkpoint from any prior run
silently no-op'ing the next sync/reingest.
Best-effort: a DB outage falls back to ``0`` (fresh run from
chunk 0). The embed loop's own re-raise still ensures partial
runs don't get cached as complete.
"""
try:
with db_session() as conn:
progress = IngestChunkProgressRepository(conn).init_progress(
source_id, total_chunks, attempt_id,
)
except Exception as e:
logging.warning(
f"Could not init ingest progress for {source_id}: {e}",
exc_info=True,
)
return 0
if not progress:
return 0
last_index = progress.get("last_index", -1)
if last_index is None or last_index < 0:
return 0
return int(last_index) + 1
def _record_progress(source_id: str, last_index: int, embedded_chunks: int) -> None:
"""Best-effort checkpoint after each chunk; logged but never raised."""
try:
with db_session() as conn:
IngestChunkProgressRepository(conn).record_chunk(
source_id, last_index=last_index, embedded_chunks=embedded_chunks
)
except Exception as e:
logging.warning(
f"Could not record ingest progress for {source_id}: {e}", exc_info=True
)
def assert_index_complete(source_id: str) -> None:
"""Raise ``EmbeddingPipelineError`` if ``ingest_chunk_progress``
shows a partial embed for ``source_id``.
Defense-in-depth tripwire that workers run after
``embed_and_store_documents`` to catch any future swallow path
that bypasses the function's own re-raise — the chunk-progress
row is the authoritative record of how many chunks landed.
No-op when no row exists (zero-doc validation raised before init,
or progress repo was unreachable).
"""
try:
with db_session() as conn:
progress = IngestChunkProgressRepository(conn).get_progress(source_id)
except Exception as e:
logging.warning(
f"assert_index_complete: progress lookup failed for "
f"{source_id}: {e}",
exc_info=True,
)
return
if not progress:
return
embedded = int(progress.get("embedded_chunks") or 0)
total = int(progress.get("total_chunks") or 0)
if embedded < total:
raise EmbeddingPipelineError(
f"partial index for source {source_id}: "
f"{embedded}/{total} chunks embedded"
)
def embed_and_store_documents(
docs: List[Any],
folder_name: str,
source_id: str,
task_status: Any,
*,
attempt_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> None:
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
"""Embeds documents and stores them in a vector store.
Resumable across Celery autoretries of the *same* task: when
``attempt_id`` matches the stored checkpoint's ``attempt_id``,
the loop resumes from ``last_index + 1``. A different
``attempt_id`` (a fresh sync / reingest invocation) resets the
checkpoint so the index is rebuilt from chunk 0 — this is what
keeps a completed checkpoint from poisoning the next sync.
Args:
docs: List of documents to be embedded and stored.
folder_name: Directory to save the vector store.
source_id: Unique identifier for the source.
task_status: Task state manager for progress updates.
attempt_id: Stable id of the current task invocation,
typically ``self.request.id`` from the Celery task body.
``None`` is treated as a fresh attempt every time.
user_id: When provided, per-percent SSE progress events are
published to ``user:{user_id}`` for the in-app upload toast.
``None`` is the safe default — workers without a user
context (e.g. background syncs) skip the publish.
Returns:
None
Raises:
OSError: If unable to create folder or save vector store.
EmbeddingPipelineError: If a chunk fails after retries.
Exception: If vector store creation or document embedding fails.
"""
# Ensure the folder exists
if not os.path.exists(folder_name):
@@ -192,108 +69,41 @@ def embed_and_store_documents(
if not docs:
raise ValueError("No documents to embed - check file format and extension")
total_docs = len(docs)
# Atomic upsert that preserves checkpoint state on attempt-id match
# (autoretry of same task) and resets it on mismatch (fresh sync /
# reingest). Returns the new resume index — 0 means "start fresh".
resume_index = _init_progress_and_resume_index(
source_id, total_docs, attempt_id,
)
is_resume = resume_index > 0
# Initialize vector store
if settings.VECTOR_STORE == "faiss":
if is_resume:
# Load the existing FAISS index from storage so chunks
# already embedded by the prior attempt survive the
# save_local rewrite at the end of this run.
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
loop_start = resume_index
else:
# FAISS requires at least one doc to construct the store;
# seed with ``docs[0]`` and let the loop pick up at index 1.
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=[docs[0]],
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
# Record the seeded chunk so single-doc ingests don't fail
# ``assert_index_complete`` — the loop never runs for
# ``total_docs == 1`` and would otherwise leave
# ``embedded_chunks`` at 0 / ``last_index`` at -1. The loop
# body's per-iteration ``_record_progress`` overshoots
# correctly for multi-chunk runs (counts seed + iterations),
# so writing this checkpoint up-front is a no-op for those.
_record_progress(source_id, last_index=0, embedded_chunks=1)
loop_start = 1
docs_init = [docs.pop(0)]
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
# Only wipe the index on a fresh run — a resume must keep the
# chunks that earlier attempts already embedded.
if not is_resume:
store.delete_index()
loop_start = resume_index
store.delete_index()
if is_resume and loop_start >= total_docs:
# Nothing left to do; the loop runs zero iterations and
# downstream finalize logic still executes. This is only
# reachable on a same-attempt retry of a task whose previous
# attempt finished — typically a Celery acks_late redelivery
# after the task already returned. The ``assert_index_complete``
# tripwire still validates ``embedded == total`` afterwards.
loop_start = total_docs
total_docs = len(docs)
# Process and embed documents
chunk_error: Exception | None = None
failed_idx: int | None = None
last_published_pct = -1
source_id_str = str(source_id)
for idx in tqdm(
range(loop_start, total_docs),
for idx, doc in tqdm(
enumerate(docs),
desc="Embedding 🦖",
unit="docs",
total=total_docs - loop_start,
total=total_docs,
bar_format="{l_bar}{bar}| Time Left: {remaining}",
):
doc = docs[idx]
try:
# Update task status for progress tracking
progress = int(((idx + 1) / total_docs) * 100)
task_status.update_state(state="PROGRESS", meta={"current": progress})
# SSE push for sub-second upload-toast updates. Throttled to one
# event per percent so a 10k-chunk ingest emits ~100 events,
# not 10k. The Celery update_state above stays the source of
# truth for the polling-fallback path.
if user_id and progress > last_published_pct:
publish_user_event(
user_id,
"source.ingest.progress",
{
"current": progress,
"total": total_docs,
"embedded_chunks": idx + 1,
"stage": "embedding",
},
scope={"kind": "source", "id": source_id_str},
)
last_published_pct = progress
# Add document to vector store
add_text_to_store_with_retry(store, doc, source_id)
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)
except Exception as e:
chunk_error = e
failed_idx = idx
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
logging.info(f"Saving progress at document {idx} out of {total_docs}")
try:
@@ -314,16 +124,3 @@ def embed_and_store_documents(
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
else:
logging.info("Vector store saved successfully.")
# Re-raise after the partial save: the chunks that *did* embed are
# flushed to disk and recorded in ``ingest_chunk_progress``, so a
# Celery autoretry resumes via ``_read_resume_index`` and only
# re-runs the failed-and-after chunks. Without the raise, the
# task body returns success and ``with_idempotency`` finalises
# ``task_dedup`` as ``completed`` for a partial index — poisoning
# the cache for 24h.
if chunk_error is not None:
raise EmbeddingPipelineError(
f"embed failure at chunk {failed_idx}/{total_docs} "
f"for source {source_id}"
) from chunk_error

View File

@@ -60,9 +60,6 @@ class ClassicRAG(BaseRetriever):
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
# Query-rephrase LLM is a side channel — tag it so its rows
# land as ``source='rag_condense'`` in cost-attribution.
self.llm._token_usage_source = "rag_condense"
if "active_docs" in source and source["active_docs"] is not None:
if isinstance(source["active_docs"], list):

View File

@@ -11,8 +11,6 @@ import re
from typing import Any, Mapping
from uuid import UUID
from application.storage.db.serialization import coerce_pg_native
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
@@ -36,17 +34,12 @@ def looks_like_uuid(value: Any) -> bool:
def row_to_dict(row: Any) -> dict:
"""Convert a SQLAlchemy ``Row`` to a plain JSON-safe dict.
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
Normalises PG-native types at the SELECT boundary: UUID, datetime,
date, Decimal, and bytes are coerced to JSON-safe forms via
:func:`coerce_pg_native`. Downstream serialisation (SSE events,
JSONB writes, API responses) becomes safe by default — repository
consumers no longer need to know that PG returns a different type
set than Mongo did.
Also emits ``_id`` alongside ``id`` for the duration of the Mongo→PG
cutover so legacy serializers expecting Mongo's shape keep working.
During the migration window, API responses and downstream code still
expect a string ``_id`` field (matching the Mongo shape). This helper
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
existing serializers keep working unchanged.
Args:
row: A SQLAlchemy ``Row`` object, or ``None``.
@@ -59,9 +52,10 @@ def row_to_dict(row: Any) -> dict:
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
out = coerce_pg_native(dict(mapping))
out = dict(mapping)
if "id" in out and out["id"] is not None:
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
out["_id"] = out["id"]
return out

View File

@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
metadata = MetaData()
# --- Users, prompts, tools, logs --------------------------------------------
# --- Phase 1, Tier 1 --------------------------------------------------------
users_table = Table(
"users",
@@ -91,16 +91,6 @@ token_usage_table = Table(
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
Column("generated_tokens", Integer, nullable=False, server_default="0"),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
# Added in ``0004_durability_foundation``. Distinguishes
# ``agent_stream`` (primary completion) from side-channel inserts
# (``title`` / ``compression`` / ``rag_condense`` / ``fallback``)
# so cost attribution dashboards can group by call source.
Column("source", Text, nullable=False, server_default="agent_stream"),
# Added in ``0005_token_usage_request_id``. Stream-scoped UUID stamped
# on the agent's primary LLM so multi-call agent runs (which produce
# N rows) count as a single request via DISTINCT in the repository
# query. NULL on side-channel sources by design.
Column("request_id", Text),
)
user_logs_table = Table(
@@ -138,7 +128,7 @@ app_metadata_table = Table(
)
# --- Agents, sources, attachments, artifacts --------------------------------
# --- Phase 2, Tier 2 --------------------------------------------------------
agent_folders_table = Table(
"agent_folders",
@@ -307,7 +297,7 @@ connector_sessions_table = Table(
)
# --- Conversations, messages, workflows -------------------------------------
# --- Phase 3, Tier 3 --------------------------------------------------------
conversations_table = Table(
"conversations",
@@ -355,44 +345,9 @@ conversation_messages_table = Table(
Column("feedback", JSONB),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
# Added in 0004_durability_foundation. ``status`` is the WAL state
# machine (pending|streaming|complete|failed); ``request_id`` ties a
# row to a specific HTTP request for log correlation.
Column("status", Text, nullable=False, server_default="complete"),
Column("request_id", Text),
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
)
# Per-yield journal of chat-stream events, used by the snapshot+tail
# reconnect: the route's GET reconnect endpoint reads
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
# tailing the live ``channel:{message_id}`` pub/sub. See
# ``application/streaming/event_replay.py`` and migration 0007.
message_events_table = Table(
"message_events",
metadata,
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
# the snapshot read index (covering range scan on
# ``WHERE message_id = ? AND sequence_no > ?``).
Column(
"message_id",
UUID(as_uuid=True),
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
),
# Strictly monotonic per ``message_id``. Allocated by the route as it
# yields, so the writer is single-threaded for the lifetime of one
# stream — no contention, no SERIAL needed.
Column("sequence_no", Integer, primary_key=True, nullable=False),
Column("event_type", Text, nullable=False),
Column("payload", JSONB, nullable=False, server_default="{}"),
Column(
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
),
)
shared_conversations_table = Table(
"shared_conversations",
metadata,
@@ -422,101 +377,9 @@ pending_tool_state_table = Table(
Column("client_tools", JSONB),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("expires_at", DateTime(timezone=True), nullable=False),
# Added in ``0004_durability_foundation``. ``status`` is the
# ``pending|resuming`` claim flag for the resumed-run path;
# ``resumed_at`` stamps when ``mark_resuming`` flipped the row so
# the cleanup janitor can revert stale claims after the grace
# window.
Column("status", Text, nullable=False, server_default="pending"),
Column("resumed_at", DateTime(timezone=True)),
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
)
# --- Durability foundation (idempotency / journals, migration 0004) ---------
# CHECK constraints (status enums) and partial indexes are intentionally
# omitted from these declarations — the DB is the authority. Repositories
# use raw ``text(...)`` SQL against these tables, not the Core objects.
task_dedup_table = Table(
"task_dedup",
metadata,
Column("idempotency_key", Text, primary_key=True),
Column("task_name", Text, nullable=False),
Column("task_id", Text, nullable=False),
Column("result_json", JSONB),
# CHECK (status IN ('pending', 'completed', 'failed')) lives in 0004.
Column("status", Text, nullable=False),
# Bumped each time the per-Celery-task wrapper re-enters; the
# poison-loop guard (``MAX_TASK_ATTEMPTS=5``) refuses to run fn once
# this exceeds the threshold.
Column("attempt_count", Integer, nullable=False, server_default="0"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
# Added in ``0006_idempotency_lease``. Per-invocation random id
# written by the wrapper at lease claim; refreshed every 30 s by a
# heartbeat thread. Other workers seeing a fresh lease (NOT NULL
# AND ``lease_expires_at > now()``) refuse to run the task body.
Column("lease_owner_id", Text),
Column("lease_expires_at", DateTime(timezone=True)),
)
webhook_dedup_table = Table(
"webhook_dedup",
metadata,
Column("idempotency_key", Text, primary_key=True),
Column("agent_id", UUID(as_uuid=True), nullable=False),
Column("task_id", Text, nullable=False),
Column("response_json", JSONB),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# Three-phase tool-call journal: ``proposed → executed → confirmed``
# (terminal: ``failed``; ``compensated`` is grandfathered in the CHECK
# from migration 0004 but no code writes it). The reconciler sweeps
# stuck rows via the partial ``tool_call_attempts_pending_ts_idx``.
tool_call_attempts_table = Table(
"tool_call_attempts",
metadata,
Column("call_id", Text, primary_key=True),
# ON DELETE SET NULL preserves the journal even after the parent
# message is deleted — useful for cost-attribution / compliance.
Column(
"message_id",
UUID(as_uuid=True),
ForeignKey("conversation_messages.id", ondelete="SET NULL"),
),
Column("tool_id", UUID(as_uuid=True)),
Column("tool_name", Text, nullable=False),
Column("action_name", Text, nullable=False),
Column("arguments", JSONB, nullable=False),
Column("result", JSONB),
Column("error", Text),
# CHECK (status IN ('proposed', 'executed', 'confirmed',
# 'compensated', 'failed')) lives in 0004.
Column("status", Text, nullable=False),
Column("attempted_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# Per-source ingest checkpoint. Heartbeat thread bumps ``last_updated``
# every 30s while a worker embeds; the reconciler escalates when it
# stops ticking.
ingest_chunk_progress_table = Table(
"ingest_chunk_progress",
metadata,
Column("source_id", UUID(as_uuid=True), primary_key=True),
Column("total_chunks", Integer, nullable=False),
Column("embedded_chunks", Integer, nullable=False, server_default="0"),
Column("last_index", Integer, nullable=False, server_default="-1"),
Column("last_updated", DateTime(timezone=True), nullable=False, server_default=func.now()),
# Added in ``0005_ingest_attempt_id``. Stamped from
# ``self.request.id`` (Celery's stable task id) so a retry of the
# same task resumes from the checkpoint, but a separate invocation
# (manual reingest, scheduled sync) resets to a clean re-index.
Column("attempt_id", Text),
)
workflows_table = Table(
"workflows",
metadata,

View File

@@ -17,21 +17,6 @@ _UPDATABLE_SCALARS = {
_UPDATABLE_JSONB = {"metadata"}
def _attachment_to_dict(row: Any) -> dict:
"""row_to_dict + ``upload_path``→``path`` alias.
Pre-Postgres, the Mongo attachment shape used ``path``. The PG column
is ``upload_path``; LLM provider code (google_ai/openai/anthropic and
handlers/base) still reads ``attachment.get("path")``. Mirroring the
``id``/``_id`` dual-emit in row_to_dict so consumers don't need to
know which storage backend produced the dict.
"""
out = row_to_dict(row)
if "upload_path" in out and out.get("path") is None:
out["path"] = out["upload_path"]
return out
class AttachmentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@@ -81,7 +66,7 @@ class AttachmentsRepository:
"legacy_mongo_id": legacy_mongo_id,
},
)
return _attachment_to_dict(result.fetchone())
return row_to_dict(result.fetchone())
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
@@ -91,7 +76,7 @@ class AttachmentsRepository:
{"id": attachment_id, "user_id": user_id},
)
row = result.fetchone()
return _attachment_to_dict(row) if row is not None else None
return row_to_dict(row) if row is not None else None
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
@@ -170,14 +155,14 @@ class AttachmentsRepository:
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return _attachment_to_dict(row) if row is not None else None
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [_attachment_to_dict(r) for r in result.fetchall()]
return [row_to_dict(r) for r in result.fetchall()]
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
"""Partial update. Used by the LLM providers to cache their

View File

@@ -25,7 +25,6 @@ from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.serialization import PGNativeJSONEncoder
_UPDATABLE_SCALARS = {
@@ -37,7 +36,7 @@ _UPDATABLE_JSONB = {"session_data", "token_info"}
def _jsonb(value: Any) -> Any:
if value is None:
return None
return json.dumps(value, cls=PGNativeJSONEncoder)
return json.dumps(value, default=str)
class ConnectorSessionsRepository:

View File

@@ -15,7 +15,6 @@ Covers every operation the legacy Mongo code performs on
from __future__ import annotations
import json
from enum import Enum
from typing import Optional
from sqlalchemy import Connection, text
@@ -23,23 +22,6 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import conversations_table, conversation_messages_table
from application.storage.db.serialization import PGNativeJSONEncoder
class MessageUpdateOutcome(str, Enum):
"""Discriminated result of ``update_message_by_id``.
Distinguishes the row-actually-updated case from the row-already-at-
the-requested-terminal-state case so an abort handler can journal
``end`` instead of ``error`` when the normal-path finalize already
flipped the row to ``complete``.
"""
UPDATED = "updated"
ALREADY_COMPLETE = "already_complete"
ALREADY_FAILED = "already_failed"
NOT_FOUND = "not_found"
INVALID = "invalid"
def _message_row_to_dict(row) -> dict:
@@ -75,8 +57,8 @@ class ConversationsRepository:
- Already-UUID-shaped → returned as-is.
- Otherwise treated as a Mongo ObjectId and looked up via
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
exists yet (e.g. the agent was created before the backfill
ran).
exists yet (e.g. the agent was created before Phase 1
backfill).
"""
if not agent_id_raw:
return None
@@ -470,7 +452,7 @@ class ConversationsRepository:
),
{
"id": conversation_id,
"point": json.dumps(point, cls=PGNativeJSONEncoder),
"point": json.dumps(point, default=str),
"max_points": int(max_points),
},
)
@@ -650,233 +632,6 @@ class ConversationsRepository:
result = self._conn.execute(text(sql), params)
return result.rowcount > 0
def reserve_message(
self,
conversation_id: str,
*,
prompt: str,
placeholder_response: str,
request_id: str | None = None,
status: str = "pending",
attachments: list[str] | None = None,
model_id: str | None = None,
metadata: dict | None = None,
) -> dict:
"""Pre-persist a placeholder assistant message before the LLM call."""
self._conn.execute(
text(
"SELECT id FROM conversations "
"WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
),
{"conv_id": conversation_id},
)
next_pos = self._conn.execute(
text(
"SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
"FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid)"
),
{"conv_id": conversation_id},
).scalar()
values = {
"conversation_id": conversation_id,
"position": next_pos,
"prompt": prompt,
"response": placeholder_response,
"status": status,
"request_id": request_id,
"model_id": model_id,
"message_metadata": metadata or {},
}
if attachments:
resolved = self._resolve_attachment_refs(
[str(a) for a in attachments],
)
if resolved:
values["attachments"] = resolved
stmt = (
pg_insert(conversation_messages_table)
.values(**values)
.returning(conversation_messages_table)
)
result = self._conn.execute(stmt)
self._conn.execute(
text(
"UPDATE conversations SET updated_at = now() "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
return _message_row_to_dict(result.fetchone())
def update_message_by_id(
self, message_id: str, fields: dict,
*, only_if_non_terminal: bool = False,
) -> MessageUpdateOutcome:
"""Update specific fields on a message identified by its UUID.
``metadata`` is merged into the existing JSONB rather than
overwritten, so a reconciler-set ``reconcile_attempts`` survives
a successful late finalize. When ``only_if_non_terminal`` is
True, the update is gated so a late finalize cannot retract a
reconciler-set ``failed`` (or a prior ``complete``).
The return value discriminates "I updated the row" from "the
row was already at a terminal state" so the abort handler can
journal ``end`` when the normal-path finalize already ran.
"""
if not looks_like_uuid(message_id):
return MessageUpdateOutcome.INVALID
allowed = {
"prompt", "response", "thought", "sources", "tool_calls",
"attachments", "model_id", "metadata", "timestamp", "status",
"request_id", "feedback", "feedback_timestamp",
}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return MessageUpdateOutcome.INVALID
api_to_col = {"metadata": "message_metadata"}
set_parts = []
params: dict = {"id": message_id}
for key, val in filtered.items():
col = api_to_col.get(key, key)
if key == "metadata":
if val is None:
set_parts.append(f"{col} = NULL")
else:
set_parts.append(
f"{col} = COALESCE({col}, '{{}}'::jsonb) "
f"|| CAST(:{col} AS jsonb)"
)
params[col] = (
json.dumps(val) if not isinstance(val, str) else val
)
elif key in ("sources", "tool_calls", "feedback"):
set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
if val is None:
params[col] = None
else:
params[col] = (
json.dumps(val) if not isinstance(val, str) else val
)
elif key == "attachments":
set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
params[col] = self._resolve_attachment_refs(
[str(a) for a in val] if val else [],
)
else:
set_parts.append(f"{col} = :{col}")
params[col] = val
set_parts.append("updated_at = now()")
update_where = ["id = CAST(:id AS uuid)"]
if only_if_non_terminal:
update_where.append("status NOT IN ('complete', 'failed')")
# Single-statement attempt + prior-status probe. Both CTEs see
# the same MVCC snapshot, so ``prior.status`` reflects the row
# state before the UPDATE — exactly what we need to tell
# ``ALREADY_COMPLETE`` apart from ``ALREADY_FAILED`` apart from
# ``NOT_FOUND`` without a follow-up SELECT.
sql = (
"WITH attempted AS ("
f" UPDATE conversation_messages SET {', '.join(set_parts)} "
f" WHERE {' AND '.join(update_where)} "
" RETURNING 1 AS updated"
"), "
"prior AS ("
" SELECT status FROM conversation_messages "
" WHERE id = CAST(:id AS uuid)"
") "
"SELECT (SELECT updated FROM attempted) AS updated, "
" (SELECT status FROM prior) AS prior_status"
)
row = self._conn.execute(text(sql), params).fetchone()
if row is None:
return MessageUpdateOutcome.NOT_FOUND
updated, prior_status = row[0], row[1]
if updated:
return MessageUpdateOutcome.UPDATED
if prior_status is None:
return MessageUpdateOutcome.NOT_FOUND
if prior_status == "complete":
return MessageUpdateOutcome.ALREADY_COMPLETE
if prior_status == "failed":
return MessageUpdateOutcome.ALREADY_FAILED
# ``only_if_non_terminal=False`` always updates an existing row,
# so reaching here means the gate excluded it for some status
# the terminal set doesn't cover — treat as "not found" rather
# than inventing a new variant.
return MessageUpdateOutcome.NOT_FOUND
def update_message_status(
self, message_id: str, status: str,
) -> bool:
"""Cheap status-only transition (e.g. pending → streaming).
Only flips non-terminal rows: a reconciler-set ``failed`` row
stays put so the late streaming chunk doesn't silently retract
the alert.
"""
if not looks_like_uuid(message_id):
return False
result = self._conn.execute(
text(
"UPDATE conversation_messages SET status = :status, "
"updated_at = now() "
"WHERE id = CAST(:id AS uuid) "
"AND status NOT IN ('complete', 'failed')"
),
{"id": message_id, "status": status},
)
return result.rowcount > 0
def heartbeat_message(self, message_id: str) -> bool:
"""Stamp ``message_metadata.last_heartbeat_at`` with ``clock_timestamp()``.
The reconciler's staleness check uses ``GREATEST(timestamp,
last_heartbeat_at)``, so this call extends a long-running
stream's effective freshness without touching ``timestamp`` (the
creation time, used for history sort) or ``status`` (the WAL
marker). Skips terminal rows so a late heartbeat can't silently
retract a reconciler-set ``failed``.
"""
if not looks_like_uuid(message_id):
return False
result = self._conn.execute(
text(
"""
UPDATE conversation_messages
SET message_metadata = jsonb_set(
COALESCE(message_metadata, '{}'::jsonb),
'{last_heartbeat_at}',
to_jsonb(clock_timestamp())
)
WHERE id = CAST(:id AS uuid)
AND status NOT IN ('complete', 'failed')
"""
),
{"id": message_id},
)
return result.rowcount > 0
def confirm_executed_tool_calls(self, message_id: str) -> int:
"""Flip ``tool_call_attempts.status='executed''confirmed'`` for the message."""
if not looks_like_uuid(message_id):
return 0
result = self._conn.execute(
text(
"UPDATE tool_call_attempts SET status = 'confirmed', "
"updated_at = now() "
"WHERE message_id = CAST(:mid AS uuid) AND status = 'executed'"
),
{"mid": message_id},
)
return result.rowcount or 0
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
"""Delete messages with position > keep_up_to.

View File

@@ -1,346 +0,0 @@
"""Repository for ``webhook_dedup`` and ``task_dedup``; 24h TTL enforced at read."""
from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.serialization import PGNativeJSONEncoder
# 24h TTL is the contract surfaced in the upload/webhook docstrings; the
# read filters and the stale-row replacement predicate must agree, or the
# upsert can fall into a window where the row is "fresh" to the writer
# but "expired" to the reader (or vice versa). Keep one constant so any
# future change moves both directions in lockstep.
DEDUP_TTL_INTERVAL = "24 hours"
def _jsonb(value: Any) -> Any:
if value is None:
return None
return json.dumps(value, cls=PGNativeJSONEncoder)
class IdempotencyRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
# --- webhook_dedup -----------------------------------------------------
def get_webhook(self, key: str) -> Optional[dict]:
"""Return the cached webhook row for ``key`` if still within the 24h window."""
row = self._conn.execute(
text(
"""
SELECT * FROM webhook_dedup
WHERE idempotency_key = :key
AND created_at > now() - CAST(:ttl AS interval)
"""
),
{"key": key, "ttl": DEDUP_TTL_INTERVAL},
).fetchone()
return row_to_dict(row) if row is not None else None
def record_webhook(
self,
key: str,
agent_id: str,
task_id: str,
response_json: dict,
) -> Optional[dict]:
"""Insert a webhook dedup row; return None if another writer raced and won.
``ON CONFLICT`` replaces an existing row only when its ``created_at``
is past TTL — atomic stale-row recycling under the row lock. A
within-TTL conflict yields no row; the caller resolves it via
:meth:`get_webhook`.
"""
result = self._conn.execute(
text(
"""
INSERT INTO webhook_dedup (
idempotency_key, agent_id, task_id, response_json
)
VALUES (
:key, CAST(:agent_id AS uuid), :task_id,
CAST(:response_json AS jsonb)
)
ON CONFLICT (idempotency_key) DO UPDATE
SET agent_id = EXCLUDED.agent_id,
task_id = EXCLUDED.task_id,
response_json = EXCLUDED.response_json,
created_at = now()
WHERE webhook_dedup.created_at
<= now() - CAST(:ttl AS interval)
RETURNING *
"""
),
{
"key": key,
"agent_id": agent_id,
"task_id": task_id,
"response_json": _jsonb(response_json),
"ttl": DEDUP_TTL_INTERVAL,
},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
# --- task_dedup --------------------------------------------------------
def get_task(self, key: str) -> Optional[dict]:
"""Return the cached task row for ``key`` if still within the 24h window."""
row = self._conn.execute(
text(
"""
SELECT * FROM task_dedup
WHERE idempotency_key = :key
AND created_at > now() - CAST(:ttl AS interval)
"""
),
{"key": key, "ttl": DEDUP_TTL_INTERVAL},
).fetchone()
return row_to_dict(row) if row is not None else None
def claim_task(
self,
key: str,
task_name: str,
task_id: str,
) -> Optional[dict]:
"""Claim ``key`` for this task. Returns the inserted row, or None if
another writer raced and won. The HTTP entry must call this *before*
``.delay()`` so only the winner enqueues the Celery task.
``ON CONFLICT`` replaces an existing row in two cases:
- **status='failed'**: the worker's poison-loop guard or the
reconciler's stuck-pending sweep finalised the prior attempt
as failed. Both explicitly intend a same-key retry to re-run
(see ``run_reconciliation`` Q5 docstring) — letting the row
block for 24 h would silently undo that intent.
- **created_at past TTL**: a stale claim from any status no
longer represents a meaningful dedup signal.
``status='completed'`` rows still block within TTL — that's the
cached-success contract callers rely on. ``status='pending'``
rows still block within TTL so concurrent same-key requests
collapse onto the in-flight task. Result/attempt fields are
reset to their fresh-claim defaults during replacement.
"""
result = self._conn.execute(
text(
"""
INSERT INTO task_dedup (
idempotency_key, task_name, task_id, result_json, status
)
VALUES (
:key, :task_name, :task_id, NULL, 'pending'
)
ON CONFLICT (idempotency_key) DO UPDATE
SET task_name = EXCLUDED.task_name,
task_id = EXCLUDED.task_id,
result_json = NULL,
status = 'pending',
attempt_count = 0,
created_at = now()
WHERE task_dedup.status = 'failed'
OR task_dedup.created_at
<= now() - CAST(:ttl AS interval)
RETURNING *
"""
),
{
"key": key,
"task_name": task_name,
"task_id": task_id,
"ttl": DEDUP_TTL_INTERVAL,
},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def try_claim_lease(
self,
key: str,
task_name: str,
task_id: str,
owner_id: str,
ttl_seconds: int = 60,
) -> Optional[int]:
"""Atomically claim the running lease for ``key``.
Returns the new ``attempt_count`` if this caller now owns the
lease (fresh insert OR existing row whose lease was empty/expired),
or ``None`` if a different worker holds a live lease.
The conflict path also bumps ``attempt_count`` so the
poison-loop guard in :func:`with_idempotency` can fire after
:data:`MAX_TASK_ATTEMPTS` reclaims. ``status='completed'`` rows
are deliberately untouched — :func:`_lookup_completed` is the
cache short-circuit and runs before this. Uses
``clock_timestamp()`` so a same-transaction refresh actually
moves the expiry forward (``now()`` is frozen at txn start).
"""
result = self._conn.execute(
text(
"""
INSERT INTO task_dedup (
idempotency_key, task_name, task_id, status, attempt_count,
lease_owner_id, lease_expires_at
) VALUES (
:key, :task_name, :task_id, 'pending', 1,
:owner,
clock_timestamp() + make_interval(secs => :ttl)
)
ON CONFLICT (idempotency_key) DO UPDATE
SET attempt_count = task_dedup.attempt_count + 1,
task_name = EXCLUDED.task_name,
lease_owner_id = EXCLUDED.lease_owner_id,
lease_expires_at = EXCLUDED.lease_expires_at
WHERE task_dedup.status <> 'completed'
AND (task_dedup.lease_expires_at IS NULL
OR task_dedup.lease_expires_at <= clock_timestamp())
RETURNING attempt_count
"""
),
{
"key": key,
"task_name": task_name,
"task_id": task_id,
"owner": owner_id,
"ttl": int(ttl_seconds),
},
)
row = result.fetchone()
return int(row[0]) if row is not None else None
def refresh_lease(
self,
key: str,
owner_id: str,
ttl_seconds: int = 60,
) -> bool:
"""Bump ``lease_expires_at`` if this caller still owns the lease.
Returns False when ownership was lost (lease stolen by another
worker after expiry, or row finalised). The heartbeat thread
logs that as a warning but doesn't try to abort the running
task — at-most-one-worker is bounded by ``ttl_seconds``, the
damage from a brief overlap window is unavoidable in this case.
"""
result = self._conn.execute(
text(
"""
UPDATE task_dedup
SET lease_expires_at =
clock_timestamp() + make_interval(secs => :ttl)
WHERE idempotency_key = :key
AND lease_owner_id = :owner
AND status = 'pending'
"""
),
{
"key": key,
"owner": owner_id,
"ttl": int(ttl_seconds),
},
)
return result.rowcount > 0
def release_lease(self, key: str, owner_id: str) -> bool:
"""Clear ``lease_owner_id`` / ``lease_expires_at`` on the
wrapper's exception path so Celery's autoretry_for doesn't have
to wait the full ``ttl_seconds`` before the next worker can
re-claim. No-op if a different worker has since taken over the
lease — that case is benign (we'd just be acknowledging we
weren't the owner anymore).
"""
result = self._conn.execute(
text(
"""
UPDATE task_dedup
SET lease_owner_id = NULL,
lease_expires_at = NULL
WHERE idempotency_key = :key
AND lease_owner_id = :owner
AND status = 'pending'
"""
),
{"key": key, "owner": owner_id},
)
return result.rowcount > 0
def finalize_task(
self,
key: str,
*,
result_json: Optional[dict],
status: str,
) -> bool:
"""Promote ``status='pending'`` → ``completed|failed`` with the
recorded result. Also clears the lease columns so a stale
``lease_expires_at`` doesn't show up in operator dashboards.
No-op if the row is already terminal — preserves the first
writer's outcome on a crash + retry.
"""
if status not in ("completed", "failed"):
raise ValueError(f"finalize_task: invalid status {status!r}")
result = self._conn.execute(
text(
"""
UPDATE task_dedup
SET status = :status,
result_json = CAST(:result_json AS jsonb),
lease_owner_id = NULL,
lease_expires_at = NULL
WHERE idempotency_key = :key
AND status = 'pending'
"""
),
{
"key": key,
"status": status,
"result_json": _jsonb(result_json),
},
)
return result.rowcount > 0
# --- housekeeping ------------------------------------------------------
def cleanup_expired(self) -> dict:
"""Delete rows past TTL from both dedup tables; return per-table counts.
The TTL-aware upserts already prevent stale rows from blocking new
work, so this is purely housekeeping — bounds table growth and
keeps test isolation cheap. Safe to run concurrently with other
writers: a same-key INSERT racing the DELETE will either find no
row (acts as a fresh insert) or find a fresh row (re-created
between DELETE and conflict-check), neither of which is wrong.
"""
task_deleted = self._conn.execute(
text(
"""
DELETE FROM task_dedup
WHERE created_at <= now() - CAST(:ttl AS interval)
"""
),
{"ttl": DEDUP_TTL_INTERVAL},
).rowcount
webhook_deleted = self._conn.execute(
text(
"""
DELETE FROM webhook_dedup
WHERE created_at <= now() - CAST(:ttl AS interval)
"""
),
{"ttl": DEDUP_TTL_INTERVAL},
).rowcount
return {
"task_dedup_deleted": int(task_deleted or 0),
"webhook_dedup_deleted": int(webhook_deleted or 0),
}

View File

@@ -1,127 +0,0 @@
"""Repository for ``ingest_chunk_progress``; per-source resume + heartbeat."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class IngestChunkProgressRepository:
"""Read/write helpers for ``ingest_chunk_progress``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def init_progress(
self,
source_id: str,
total_chunks: int,
attempt_id: Optional[str] = None,
) -> dict:
"""Upsert the progress row, scoped by ``attempt_id``.
On conflict the upsert distinguishes two cases:
- **Same attempt** (``attempt_id`` matches the stored value):
this is a Celery autoretry of the same task — preserve
``last_index`` / ``embedded_chunks`` so the embed loop resumes
from the checkpoint. Only ``total_chunks`` and
``last_updated`` get refreshed.
- **Different attempt** (a fresh invocation: manual reingest,
scheduled sync, or any caller that didn't pass an
``attempt_id``): reset ``last_index`` to ``-1`` and
``embedded_chunks`` to ``0`` so the loop starts from chunk 0.
This prevents a completed checkpoint from any prior run
poisoning the index.
``IS NOT DISTINCT FROM`` treats two NULLs as equal — so legacy
rows with NULL ``attempt_id`` resume against another NULL
caller (e.g. test fixtures), but get reset the moment a real
``attempt_id`` arrives.
"""
result = self._conn.execute(
text(
"""
INSERT INTO ingest_chunk_progress (
source_id, total_chunks, embedded_chunks, last_index,
attempt_id, last_updated
)
VALUES (
CAST(:source_id AS uuid), :total_chunks, 0, -1,
:attempt_id, now()
)
ON CONFLICT (source_id) DO UPDATE SET
total_chunks = EXCLUDED.total_chunks,
last_updated = now(),
last_index = CASE
WHEN ingest_chunk_progress.attempt_id
IS NOT DISTINCT FROM EXCLUDED.attempt_id
THEN ingest_chunk_progress.last_index
ELSE -1
END,
embedded_chunks = CASE
WHEN ingest_chunk_progress.attempt_id
IS NOT DISTINCT FROM EXCLUDED.attempt_id
THEN ingest_chunk_progress.embedded_chunks
ELSE 0
END,
attempt_id = EXCLUDED.attempt_id
RETURNING *
"""
),
{
"source_id": str(source_id),
"total_chunks": int(total_chunks),
"attempt_id": attempt_id,
},
)
return row_to_dict(result.fetchone())
def record_chunk(
self, source_id: str, last_index: int, embedded_chunks: int
) -> None:
"""Persist progress after a chunk is embedded."""
self._conn.execute(
text(
"""
UPDATE ingest_chunk_progress
SET last_index = :last_index,
embedded_chunks = :embedded_chunks,
last_updated = now()
WHERE source_id = CAST(:source_id AS uuid)
"""
),
{
"source_id": str(source_id),
"last_index": int(last_index),
"embedded_chunks": int(embedded_chunks),
},
)
def get_progress(self, source_id: str) -> Optional[dict]:
"""Return the progress row for ``source_id`` if it exists."""
result = self._conn.execute(
text(
"SELECT * FROM ingest_chunk_progress "
"WHERE source_id = CAST(:source_id AS uuid)"
),
{"source_id": str(source_id)},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def bump_heartbeat(self, source_id: str) -> None:
"""Refresh ``last_updated`` so the row looks alive to the reconciler."""
self._conn.execute(
text(
"""
UPDATE ingest_chunk_progress
SET last_updated = now()
WHERE source_id = CAST(:source_id AS uuid)
"""
),
{"source_id": str(source_id)},
)

View File

@@ -1,248 +0,0 @@
"""Repository for ``message_events`` — the chat-stream snapshot journal.
``record`` / ``bulk_record`` write per-yield events; ``read_after``
replays rows past a cursor for reconnect snapshots. Composite PK
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
Callers must use short-lived per-call transactions — long-lived
transactions hide writes from reconnecting clients on a separate
connection and turn one bad row into ``InFailedSqlTransaction``.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
logger = logging.getLogger(__name__)
class MessageEventsRepository:
"""Read/write helpers for ``message_events``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def record(
self,
message_id: str,
sequence_no: int,
event_type: str,
payload: Optional[Any] = None,
) -> None:
"""Append a single event to the journal.
At this raw repo layer ``payload`` is preserved as-is when not
``None`` (lists, scalars, and dicts all round-trip via JSONB);
``None`` substitutes an empty object so the column's NOT NULL
invariant holds. The streaming-route wrapper
``application/streaming/message_journal.py::record_event``
tightens this contract to dicts only — the live and replay
paths reconstruct non-dict payloads differently, so the wrapper
rejects them at the gate. Direct callers of this repo method
(cleanup tasks, tests, future ad-hoc consumers) keep the wider
JSONB-compatible surface.
Raises ``sqlalchemy.exc.IntegrityError`` on duplicate
``(message_id, sequence_no)`` and ``DataError`` on a malformed
``message_id`` UUID. Both abort the surrounding transaction —
callers must run inside a short-lived per-event session
(see module docstring).
"""
if not event_type:
raise ValueError("event_type must be a non-empty string")
materialised_payload = payload if payload is not None else {}
self._conn.execute(
text(
"""
INSERT INTO message_events (
message_id, sequence_no, event_type, payload
) VALUES (
CAST(:message_id AS uuid), :sequence_no, :event_type,
CAST(:payload AS jsonb)
)
"""
),
{
"message_id": str(message_id),
"sequence_no": int(sequence_no),
"event_type": event_type,
"payload": json.dumps(materialised_payload),
},
)
def bulk_record(
self,
message_id: str,
events: list[tuple[int, str, dict]],
) -> None:
"""Append multiple events for ``message_id`` in one INSERT.
``events`` is a list of ``(sequence_no, event_type, payload)``
tuples. SQLAlchemy ``executemany`` issues one bulk INSERT;
Postgres treats the whole batch as one statement, so an
IntegrityError on any row aborts the entire batch.
Caller contract: on IntegrityError, do NOT retry this method
with the same batch — fall back to per-row ``record()`` calls
(each in its own short-lived session) so a single colliding
seq doesn't drop the rest of the batch. ``BatchedJournalWriter``
in ``application/streaming/message_journal.py`` is the canonical
consumer.
"""
if not events:
return
params = [
{
"message_id": str(message_id),
"sequence_no": int(seq),
"event_type": event_type,
"payload": json.dumps(payload if payload is not None else {}),
}
for seq, event_type, payload in events
]
self._conn.execute(
text(
"""
INSERT INTO message_events (
message_id, sequence_no, event_type, payload
) VALUES (
CAST(:message_id AS uuid), :sequence_no, :event_type,
CAST(:payload AS jsonb)
)
"""
),
params,
)
def read_after(
self,
message_id: str,
last_sequence_no: Optional[int] = None,
) -> list[dict]:
"""Return events with ``sequence_no > last_sequence_no``.
``last_sequence_no=None`` returns the full backlog. Rows are
returned in ascending ``sequence_no`` order. The composite PK
is the snapshot read index for this scan — Postgres typically
picks an in-order index range scan, though for highly mixed
data the planner may pick a bitmap+sort. Either way the result
is sorted on ``sequence_no``.
Returns a ``list`` (not a generator) so the underlying
``Result`` is fully drained before the caller can issue
another query on the same connection.
"""
cursor = -1 if last_sequence_no is None else int(last_sequence_no)
rows = self._conn.execute(
text(
"""
SELECT message_id, sequence_no, event_type, payload, created_at
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
AND sequence_no > :cursor
ORDER BY sequence_no ASC
"""
),
{"message_id": str(message_id), "cursor": cursor},
).fetchall()
return [row_to_dict(row) for row in rows]
def cleanup_older_than(self, ttl_days: int) -> int:
"""Delete journal rows older than ``ttl_days``. Returns row count.
Reconnect-replay is meaningful only for streams the client
could plausibly still be waiting on, so old rows are dead
weight. The ``message_events_created_at_idx`` btree makes the
range delete a cheap index scan even on large tables.
"""
if ttl_days <= 0:
raise ValueError("ttl_days must be positive")
result = self._conn.execute(
text(
"""
DELETE FROM message_events
WHERE created_at < now() - make_interval(days => :ttl_days)
"""
),
{"ttl_days": int(ttl_days)},
)
return int(result.rowcount or 0)
def reconstruct_partial(self, message_id: str) -> dict:
"""Rebuild partial response/thought/sources/tool_calls from journal events.
``answer``/``thought`` chunks concat in seq order; ``source``/
``tool_calls`` carry the full list at emit time (last-wins).
"""
rows = self._conn.execute(
text(
"""
SELECT sequence_no, event_type, payload
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
ORDER BY sequence_no ASC
"""
),
{"message_id": str(message_id)},
).fetchall()
response_parts: list[str] = []
thought_parts: list[str] = []
sources: list = []
tool_calls: list = []
for row in rows:
payload = row.payload
if not isinstance(payload, dict):
continue
etype = row.event_type
if etype == "answer":
chunk = payload.get("answer")
if isinstance(chunk, str):
response_parts.append(chunk)
elif etype == "thought":
chunk = payload.get("thought")
if isinstance(chunk, str):
thought_parts.append(chunk)
elif etype == "source":
src = payload.get("source")
if isinstance(src, list):
sources = src
elif etype == "tool_calls":
tcs = payload.get("tool_calls")
if isinstance(tcs, list):
tool_calls = tcs
return {
"response": "".join(response_parts),
"thought": "".join(thought_parts),
"sources": sources,
"tool_calls": tool_calls,
}
def latest_sequence_no(self, message_id: str) -> Optional[int]:
"""Largest ``sequence_no`` recorded for ``message_id``, or ``None``.
Used by the route to seed the per-stream allocator on retry /
process restart so a re-run continues numbering instead of
trampling earlier entries with duplicate sequence_no.
"""
# ``MAX`` always returns one row — NULL when the journal is
# empty — so we test the value, not the row presence.
row = self._conn.execute(
text(
"""
SELECT MAX(sequence_no) AS s
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
"""
),
{"message_id": str(message_id)},
).first()
value = row[0] if row is not None else None
return int(value) if value is not None else None

View File

@@ -7,11 +7,6 @@ Mirrors the continuation service's three operations on
- load_state → find_one by (conversation_id, user_id)
- delete_state → delete_one by (conversation_id, user_id)
Adds ``mark_resuming`` so a resumed run can claim a row without
deleting it; a separate ``revert_stale_resuming`` flips abandoned
``resuming`` rows back to ``pending`` so a crashed worker doesn't
strand the user.
Plus a cleanup method for the Celery beat task that replaces Mongo's
TTL index.
"""
@@ -25,7 +20,6 @@ from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.serialization import PGNativeJSONEncoder
PENDING_STATE_TTL_SECONDS = 30 * 60 # 1800 seconds
@@ -77,24 +71,19 @@ class PendingToolStateRepository:
agent_config = EXCLUDED.agent_config,
client_tools = EXCLUDED.client_tools,
created_at = EXCLUDED.created_at,
expires_at = EXCLUDED.expires_at,
status = 'pending',
resumed_at = NULL
expires_at = EXCLUDED.expires_at
RETURNING *
"""
),
{
"conv_id": conversation_id,
"user_id": user_id,
"messages": json.dumps(messages, cls=PGNativeJSONEncoder),
"pending": json.dumps(pending_tool_calls, cls=PGNativeJSONEncoder),
"tools_dict": json.dumps(tools_dict, cls=PGNativeJSONEncoder),
"schemas": json.dumps(tool_schemas, cls=PGNativeJSONEncoder),
"agent_config": json.dumps(agent_config, cls=PGNativeJSONEncoder),
"client_tools": (
json.dumps(client_tools, cls=PGNativeJSONEncoder)
if client_tools is not None else None
),
"messages": json.dumps(messages),
"pending": json.dumps(pending_tool_calls),
"tools_dict": json.dumps(tools_dict),
"schemas": json.dumps(tool_schemas),
"agent_config": json.dumps(agent_config),
"client_tools": json.dumps(client_tools) if client_tools is not None else None,
"created_at": now,
"expires_at": expires,
},
@@ -124,45 +113,6 @@ class PendingToolStateRepository:
)
return result.rowcount > 0
def mark_resuming(self, conversation_id: str, user_id: str) -> bool:
"""Flip a pending row to ``resuming`` and stamp ``resumed_at``."""
result = self._conn.execute(
text(
"""
UPDATE pending_tool_state
SET status = 'resuming', resumed_at = clock_timestamp()
WHERE conversation_id = CAST(:conv_id AS uuid)
AND user_id = :user_id
AND status = 'pending'
"""
),
{"conv_id": conversation_id, "user_id": user_id},
)
return result.rowcount > 0
def revert_stale_resuming(
self,
grace_seconds: int = 600,
ttl_extension_seconds: int = PENDING_STATE_TTL_SECONDS,
) -> int:
"""Revert ``resuming`` rows older than ``grace_seconds`` to ``pending``; bump TTL."""
result = self._conn.execute(
text(
"""
UPDATE pending_tool_state
SET status = 'pending',
resumed_at = NULL,
expires_at = clock_timestamp()
+ make_interval(secs => :ttl)
WHERE status = 'resuming'
AND resumed_at
< clock_timestamp() - make_interval(secs => :grace)
"""
),
{"grace": grace_seconds, "ttl": ttl_extension_seconds},
)
return result.rowcount
def cleanup_expired(self) -> int:
"""Delete rows where ``expires_at < now()``.

View File

@@ -1,273 +0,0 @@
"""Repository for reconciliation sweeps over stuck durability rows."""
from __future__ import annotations
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class ReconciliationRepository:
"""Sweeps and terminal writes for the reconciler beat task."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def find_and_lock_stuck_messages(
self, *, age_minutes: int = 5, limit: int = 100,
) -> list[dict]:
"""Lock stuck pending/streaming messages skipping live resumes.
Staleness rides on the **later of** ``cm.timestamp`` (creation)
and ``message_metadata.last_heartbeat_at`` (route heartbeat). An
in-flight stream that re-stamps the heartbeat each minute stays
out of the sweep; reconciler-side writes deliberately don't
touch either column so the per-row attempts counter advances
across ticks. Liveness exemption covers both ``pending`` (paused
waiting for resume) and ``resuming`` (actively executing)
``pending_tool_state`` rows so a paused message survives until
the PT row's own TTL retires it.
"""
result = self._conn.execute(
text(
"""
SELECT cm.id, cm.conversation_id, cm.user_id, cm.timestamp,
cm.message_metadata
FROM conversation_messages cm
WHERE cm.status IN ('pending', 'streaming')
AND cm.timestamp < now() - make_interval(mins => :age)
AND COALESCE(
(cm.message_metadata->>'last_heartbeat_at')::timestamptz,
cm.timestamp
) < now() - make_interval(mins => :age)
AND NOT EXISTS (
SELECT 1
FROM pending_tool_state pts
WHERE pts.conversation_id = cm.conversation_id
AND (
(pts.status = 'pending'
AND pts.expires_at > now())
OR
(pts.status = 'resuming'
AND pts.resumed_at
> now() - interval '10 minutes')
)
)
ORDER BY cm.timestamp ASC
LIMIT :limit
FOR UPDATE OF cm SKIP LOCKED
"""
),
{"age": age_minutes, "limit": limit},
)
return [row_to_dict(r) for r in result.fetchall()]
def find_and_lock_proposed_tool_calls(
self, *, age_minutes: int = 5, limit: int = 100,
) -> list[dict]:
"""Lock tool_call_attempts that never advanced past ``proposed``."""
result = self._conn.execute(
text(
"""
SELECT call_id, message_id, tool_id, tool_name, action_name,
arguments, attempted_at, updated_at
FROM tool_call_attempts
WHERE status = 'proposed'
AND attempted_at < now() - make_interval(mins => :age)
ORDER BY attempted_at ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
"""
),
{"age": age_minutes, "limit": limit},
)
return [row_to_dict(r) for r in result.fetchall()]
def find_and_lock_executed_tool_calls(
self, *, age_minutes: int = 15, limit: int = 100,
) -> list[dict]:
"""Lock tool_call_attempts stuck in ``executed`` past confirm window."""
result = self._conn.execute(
text(
"""
SELECT call_id, message_id, tool_id, tool_name, action_name,
arguments, result, attempted_at, updated_at
FROM tool_call_attempts
WHERE status = 'executed'
AND updated_at < now() - make_interval(mins => :age)
ORDER BY updated_at ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
"""
),
{"age": age_minutes, "limit": limit},
)
return [row_to_dict(r) for r in result.fetchall()]
def find_and_lock_stalled_ingests(
self, *, age_minutes: int = 30, limit: int = 100,
) -> list[dict]:
"""Lock ingest checkpoints whose heartbeat hasn't ticked recently."""
result = self._conn.execute(
text(
"""
SELECT source_id, total_chunks, embedded_chunks,
last_index, last_updated
FROM ingest_chunk_progress
WHERE last_updated < now() - make_interval(mins => :age)
AND embedded_chunks < total_chunks
ORDER BY last_updated ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
"""
),
{"age": age_minutes, "limit": limit},
)
return [row_to_dict(r) for r in result.fetchall()]
def touch_ingest_progress(self, source_id: str) -> bool:
"""Bump ``last_updated`` so a once-stalled ingest re-enters the watch window."""
result = self._conn.execute(
text(
"UPDATE ingest_chunk_progress SET last_updated = now() "
"WHERE source_id = CAST(:sid AS uuid)"
),
{"sid": str(source_id)},
)
return result.rowcount > 0
def increment_message_reconcile_attempts(self, message_id: str) -> int:
"""Bump ``message_metadata.reconcile_attempts`` and return the new count."""
result = self._conn.execute(
text(
"""
UPDATE conversation_messages
SET message_metadata = jsonb_set(
COALESCE(message_metadata, '{}'::jsonb),
'{reconcile_attempts}',
to_jsonb(
COALESCE(
(message_metadata->>'reconcile_attempts')::int,
0
) + 1
)
)
WHERE id = CAST(:message_id AS uuid)
RETURNING (message_metadata->>'reconcile_attempts')::int
AS new_count
"""
),
{"message_id": message_id},
)
row = result.fetchone()
return int(row[0]) if row is not None else 0
def mark_message_failed(self, message_id: str, *, error: str) -> bool:
"""Flip a message to ``status='failed'`` and stash ``error`` in metadata."""
result = self._conn.execute(
text(
"""
UPDATE conversation_messages
SET status = 'failed',
message_metadata = jsonb_set(
COALESCE(message_metadata, '{}'::jsonb),
'{error}',
to_jsonb(CAST(:error AS text))
)
WHERE id = CAST(:message_id AS uuid)
"""
),
{"message_id": message_id, "error": error},
)
return result.rowcount > 0
def mark_tool_call_failed(self, call_id: str, *, error: str) -> bool:
"""Flip a tool_call_attempts row to ``failed`` with ``error``."""
result = self._conn.execute(
text(
"UPDATE tool_call_attempts SET status = 'failed', "
"error = :error WHERE call_id = :call_id"
),
{"call_id": call_id, "error": error},
)
return result.rowcount > 0
def find_stuck_idempotency_pending(
self,
*,
max_attempts: int,
lease_grace_seconds: int = 60,
limit: int = 100,
) -> list[dict]:
"""Lock ``task_dedup`` rows abandoned past the lease + retry budget.
A row is "stuck" when:
- ``status='pending'`` (lease was claimed but never finalised)
- ``lease_expires_at`` is past by at least ``lease_grace_seconds``
(the heartbeat thread is gone — the lease isn't going to come
back)
- ``attempt_count >= max_attempts`` (the poison-loop guard
should already have escalated this; if it hasn't, the wrapper
died before getting there)
These rows would otherwise sit in ``pending`` until the 24 h
TTL aged them out, blocking same-key retries via
``_lookup_completed`` returning None for the whole window.
"""
result = self._conn.execute(
text(
"""
SELECT idempotency_key, task_name, task_id, attempt_count,
lease_owner_id, lease_expires_at, created_at
FROM task_dedup
WHERE status = 'pending'
AND lease_expires_at IS NOT NULL
AND lease_expires_at
< now() - make_interval(secs => :grace)
AND attempt_count >= :max_attempts
ORDER BY created_at ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
"""
),
{
"max_attempts": int(max_attempts),
"grace": int(lease_grace_seconds),
"limit": int(limit),
},
)
return [row_to_dict(r) for r in result.fetchall()]
def mark_idempotency_pending_failed(
self, key: str, *, error: str,
) -> bool:
"""Promote a stuck pending ``task_dedup`` row to ``failed``."""
from application.storage.db.serialization import PGNativeJSONEncoder
import json
result = self._conn.execute(
text(
"""
UPDATE task_dedup
SET status = 'failed',
result_json = CAST(:result AS jsonb),
lease_owner_id = NULL,
lease_expires_at = NULL
WHERE idempotency_key = :key
AND status = 'pending'
"""
),
{
"key": key,
"result": json.dumps(
{
"success": False,
"error": error,
"reconciled": True,
},
cls=PGNativeJSONEncoder,
),
},
)
return result.rowcount > 0

View File

@@ -13,8 +13,6 @@ import json
from datetime import datetime
from typing import Optional
from application.storage.db.serialization import PGNativeJSONEncoder
from sqlalchemy import Connection, text
@@ -54,7 +52,7 @@ class StackLogsRepository:
"user_id": user_id,
"api_key": api_key,
"query": query,
"stacks": json.dumps(stacks or [], cls=PGNativeJSONEncoder),
"stacks": json.dumps(stacks or []),
"timestamp": timestamp,
},
)

View File

@@ -31,8 +31,6 @@ class TokenUsageRepository:
agent_id: Optional[str] = None,
prompt_tokens: int = 0,
generated_tokens: int = 0,
source: str = "agent_stream",
request_id: Optional[str] = None,
timestamp: Optional[datetime] = None,
) -> None:
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
@@ -56,16 +54,12 @@ class TokenUsageRepository:
self._conn.execute(
text(
"""
INSERT INTO token_usage (
user_id, api_key, agent_id,
prompt_tokens, generated_tokens,
source, request_id, timestamp
)
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
VALUES (
:user_id, :api_key,
CAST(:agent_id AS uuid),
:prompt_tokens, :generated_tokens,
:source, :request_id, COALESCE(:timestamp, now())
COALESCE(:timestamp, now())
)
"""
),
@@ -75,8 +69,6 @@ class TokenUsageRepository:
"agent_id": agent_id_uuid,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"source": source,
"request_id": request_id,
"timestamp": timestamp,
},
)
@@ -181,22 +173,8 @@ class TokenUsageRepository:
user_id: Optional[str] = None,
api_key: Optional[str] = None,
) -> int:
"""Count user-initiated requests in the given time range.
A request = one ``agent_stream`` invocation. Multi-tool agent
runs produce multiple rows (one per LLM call) tagged with the
same ``request_id``; we DISTINCT on that to count the request
once. Pre-migration rows have ``request_id=NULL`` and are
counted one-per-row via the second branch (back-compat).
Side-channel sources (``title`` / ``compression`` /
``rag_condense`` / ``fallback``) are excluded — they aren't
user-initiated and shouldn't tick the request limit.
"""
clauses = [
"timestamp >= :start",
"timestamp <= :end",
"source = 'agent_stream'",
]
"""Count of token_usage rows in the given time range (for request limiting)."""
clauses = ["timestamp >= :start", "timestamp <= :end"]
params: dict = {"start": start, "end": end}
if user_id is not None:
clauses.append("user_id = :user_id")
@@ -206,15 +184,7 @@ class TokenUsageRepository:
params["api_key"] = api_key
where = " AND ".join(clauses)
result = self._conn.execute(
text(
f"""
SELECT
COUNT(DISTINCT request_id) FILTER (WHERE request_id IS NOT NULL)
+ COUNT(*) FILTER (WHERE request_id IS NULL)
FROM token_usage
WHERE {where}
"""
),
text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
params,
)
return result.scalar()

View File

@@ -1,144 +0,0 @@
"""Repository for ``tool_call_attempts``; executor's proposed/executed/failed writes."""
from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.serialization import PGNativeJSONEncoder
class ToolCallAttemptsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def record_proposed(
self,
call_id: str,
tool_name: str,
action_name: str,
arguments: Any,
*,
tool_id: Optional[str] = None,
) -> bool:
"""Insert a ``proposed`` row before the tool executes.
Returns True if a new row was created. ``ON CONFLICT DO NOTHING``
guards against the LLM emitting a duplicate ``call_id``: the
existing row stays put rather than a re-insert raising
``IntegrityError``.
"""
result = self._conn.execute(
text(
"""
INSERT INTO tool_call_attempts
(call_id, tool_id, tool_name, action_name, arguments, status)
VALUES
(:call_id, CAST(:tool_id AS uuid), :tool_name,
:action_name, CAST(:arguments AS jsonb), 'proposed')
ON CONFLICT (call_id) DO NOTHING
"""
),
{
"call_id": call_id,
"tool_id": tool_id,
"tool_name": tool_name,
"action_name": action_name,
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
},
)
return result.rowcount > 0
def upsert_executed(
self,
call_id: str,
tool_name: str,
action_name: str,
arguments: Any,
result: Any,
*,
tool_id: Optional[str] = None,
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> None:
"""Insert OR upgrade a row to ``executed``.
Used as a fallback when ``record_proposed`` failed (DB outage)
and the tool ran anyway — preserves the journal so the
reconciler can still see the attempt.
"""
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
self._conn.execute(
text(
"""
INSERT INTO tool_call_attempts
(call_id, tool_id, tool_name, action_name, arguments,
result, message_id, status)
VALUES
(:call_id, CAST(:tool_id AS uuid), :tool_name,
:action_name, CAST(:arguments AS jsonb),
CAST(:result AS jsonb), CAST(:message_id AS uuid),
'executed')
ON CONFLICT (call_id) DO UPDATE
SET status = 'executed',
result = EXCLUDED.result,
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
"""
),
{
"call_id": call_id,
"tool_id": tool_id,
"tool_name": tool_name,
"action_name": action_name,
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
"message_id": message_id,
},
)
def mark_executed(
self,
call_id: str,
result: Any,
*,
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> bool:
"""Flip ``proposed`` → ``executed`` with the tool result.
``artifact_id`` (when present) is stored alongside ``result`` in
the JSONB as audit data — the reconciler reads it for diagnostic
alerts when escalating stuck rows to ``failed``.
"""
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
sql = (
"UPDATE tool_call_attempts SET "
"status = 'executed', result = CAST(:result AS jsonb)"
)
params: dict[str, Any] = {
"call_id": call_id,
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
}
if message_id is not None:
sql += ", message_id = CAST(:message_id AS uuid)"
params["message_id"] = message_id
sql += " WHERE call_id = :call_id"
result_proxy = self._conn.execute(text(sql), params)
return result_proxy.rowcount > 0
def mark_failed(self, call_id: str, error: str) -> bool:
"""Flip ``proposed`` → ``failed`` with the exception text."""
result = self._conn.execute(
text(
"UPDATE tool_call_attempts SET status = 'failed', error = :error "
"WHERE call_id = :call_id"
),
{"call_id": call_id, "error": error},
)
return result.rowcount > 0

View File

@@ -20,7 +20,6 @@ from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.serialization import PGNativeJSONEncoder
class UserLogsRepository:
@@ -47,7 +46,7 @@ class UserLogsRepository:
{
"user_id": user_id,
"endpoint": endpoint,
"data": json.dumps(data, cls=PGNativeJSONEncoder) if data is not None else None,
"data": json.dumps(data, default=str) if data is not None else None,
"timestamp": timestamp,
},
)

View File

@@ -1,93 +0,0 @@
"""JSON-safe coercion for PG-native Python types.
Postgres (via psycopg) returns native Python types — ``uuid.UUID``,
``datetime.datetime``/``datetime.date``, ``decimal.Decimal``, ``bytes``
— that ``json.dumps`` rejects. This module is the single place those
coercion rules live; everywhere else should call into it.
Two interfaces with identical coverage:
* :func:`coerce_pg_native` — recursive walk returning a JSON-safe copy.
Use when you need to inspect the dict yourself or pass it to a
serializer that doesn't accept a custom encoder (e.g. SQLAlchemy
parameter binding for a JSONB column).
* :class:`PGNativeJSONEncoder` — ``JSONEncoder`` subclass. Use as
``json.dumps(obj, cls=PGNativeJSONEncoder)`` for serialise-once flows
where the extra recursive walk is wasted work.
Coercion rules:
* ``UUID`` → canonical hex string.
* ``datetime`` / ``date`` → ISO 8601 string.
* ``Decimal`` → numeric string (preserves precision; ``float()`` would not).
* ``bytes`` → base64 string. Lossless and universally JSON-safe;
prior code used UTF-8 with ``errors="replace"`` which silently
corrupted binary payloads (e.g. Gemini's ``thought_signature``).
"""
from __future__ import annotations
import base64
import binascii
import json
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
def _coerce_scalar(obj: Any) -> Any:
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, (datetime, date)):
return obj.isoformat()
if isinstance(obj, Decimal):
return str(obj)
if isinstance(obj, bytes):
return base64.b64encode(obj).decode("ascii")
return obj
def coerce_pg_native(obj: Any) -> Any:
"""Recursively coerce PG-native types to JSON-safe equivalents.
Recurses into ``dict`` (stringifying keys, matching prior helper
behavior) and ``list``/``tuple`` (tuples flatten to lists since JSON
has no tuple type). Any other type passes through unchanged.
"""
if isinstance(obj, dict):
return {str(k): coerce_pg_native(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [coerce_pg_native(v) for v in obj]
return _coerce_scalar(obj)
def decode_base64_bytes(value: Any) -> Any:
"""Reverse ``coerce_pg_native``'s bytes-to-base64 step.
Useful at egress points that need the original bytes back (e.g.
sending Gemini's ``thought_signature`` to the SDK on resume). Uses
``validate=True`` so plain ASCII strings that happen to be
permissively decodable (e.g. ``"abcd"``) are not silently turned
into bytes — the original value passes through.
"""
if isinstance(value, str):
try:
return base64.b64decode(value.encode("ascii"), validate=True)
except (binascii.Error, ValueError):
return value
return value
class PGNativeJSONEncoder(json.JSONEncoder):
"""``JSONEncoder`` covering UUID / datetime / date / Decimal / bytes.
Use as ``json.dumps(obj, cls=PGNativeJSONEncoder)``. Equivalent in
coverage to :func:`coerce_pg_native` but skips the eager walk.
"""
def default(self, obj: Any) -> Any:
coerced = _coerce_scalar(obj)
if coerced is obj:
return super().default(obj)
return coerced

View File

@@ -1,23 +0,0 @@
"""Deterministic source-id derivation for idempotent ingest.
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
idempotency keys.
"""
from __future__ import annotations
import uuid
# DO NOT CHANGE. See module docstring.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")
def derive_source_id(idempotency_key) -> uuid.UUID:
"""``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise.
A non-string / empty key falls back to ``uuid4()`` so the caller
always gets a fresh id rather than a TypeError mid-route.
"""
if isinstance(idempotency_key, str) and idempotency_key:
return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
return uuid.uuid4()

View File

@@ -1,126 +0,0 @@
"""Redis pub/sub Topic abstraction for SSE fan-out.
A Topic is a named channel for one-shot live event delivery. Canonical uses:
- ``user:{user_id}`` for per-user notifications
- ``channel:{message_id}`` for per-chat-message streams
Subscription is race-free via ``on_subscribe``: the callback fires only
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
the callback cannot lose its first event to a not-yet-registered
subscriber.
The subscribe iterator yields ``None`` on poll timeout so the caller can
emit SSE keepalive comments without spawning a separate timer thread.
"""
from __future__ import annotations
import logging
from typing import Callable, Iterator, Optional
from application.cache import get_redis_instance
logger = logging.getLogger(__name__)
class Topic:
"""A pub/sub channel identified by a string name."""
def __init__(self, name: str) -> None:
self.name = name
def publish(self, payload: str | bytes) -> int:
"""Fan out a payload to currently subscribed clients.
Returns the number Redis reports as receiving the message (limited
to subscribers connected to *this* Redis instance), or 0 if Redis
is unavailable. Never raises.
"""
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; dropping publish to %s", self.name)
return 0
try:
return int(redis.publish(self.name, payload))
except Exception:
logger.exception("Topic.publish failed for %s", self.name)
return 0
def subscribe(
self,
on_subscribe: Optional[Callable[[], None]] = None,
poll_timeout: float = 1.0,
) -> Iterator[Optional[bytes]]:
"""Subscribe to the topic; yield raw payloads or ``None`` on tick.
Yields ``None`` every ``poll_timeout`` seconds while idle so the
caller can emit keepalive frames or check cancellation. Yields
``bytes`` for each delivered message.
``on_subscribe`` runs synchronously after Redis acknowledges the
SUBSCRIBE — use it to seed any state (e.g. read backlog) that
must be ordered after the subscriber is live but before the
first pub/sub message is processed.
If Redis is unavailable, returns immediately without yielding.
Cleanly unsubscribes on ``GeneratorExit`` (client disconnect).
"""
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
return
pubsub = None
on_subscribe_fired = False
try:
pubsub = redis.pubsub()
try:
pubsub.subscribe(self.name)
except Exception:
# Subscribe failure (transient Redis hiccup, conn reset, etc.)
# is treated like "Redis unavailable": yield nothing, let the
# caller fall back to its own resilience strategy. The finally
# block will still tear down the pubsub object cleanly.
logger.exception("pubsub.subscribe failed for %s", self.name)
return
while True:
try:
msg = pubsub.get_message(timeout=poll_timeout)
except Exception:
logger.exception("pubsub.get_message failed for %s", self.name)
return
if msg is None:
yield None
continue
msg_type = msg.get("type")
if msg_type == "subscribe":
if not on_subscribe_fired and on_subscribe is not None:
try:
on_subscribe()
except Exception:
logger.exception(
"on_subscribe callback failed for %s", self.name
)
on_subscribe_fired = True
continue
if msg_type != "message":
continue
data = msg.get("data")
if data is None:
continue
yield data if isinstance(data, bytes) else str(data).encode("utf-8")
finally:
if pubsub is not None:
if on_subscribe_fired:
try:
pubsub.unsubscribe(self.name)
except Exception:
logger.debug(
"pubsub unsubscribe error for %s",
self.name,
exc_info=True,
)
try:
pubsub.close()
except Exception:
logger.debug("pubsub close error for %s", self.name, exc_info=True)

View File

@@ -1,434 +0,0 @@
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.
Subscribe to ``channel:{message_id}``, snapshot ``message_events``
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from sqlalchemy import text as sql_text
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly
from application.streaming.broadcast_channel import Topic
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
DEFAULT_KEEPALIVE_SECONDS = 15.0
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
# When the live tail has no events and no terminal in snapshot, fall
# back to checking ``conversation_messages`` directly. If the row has
# already gone terminal (worker journaled ``end``/``error`` to the DB
# but the matching pub/sub publish was lost, or the row was finalized
# without a journal write at all) we surface a terminal event so the
# client doesn't hang on keepalives. If the row is still non-terminal
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
# the producer is presumed dead (worker crash / recycle between chunks
# and finalize) and we emit a terminal ``error`` so the UI can recover.
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
# 1.5× the route's 60s heartbeat interval — long enough that a normal
# heartbeat skew doesn't false-positive, short enough that a stuck
# stream surfaces before the 5-minute reconciler sweep escalates.
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0
# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
# can't smuggle a record boundary into the wire format.
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")
# Event types that mark the end of a chat answer. After delivering one
# we close the reconnect stream — keeping the connection open past a
# terminal event would leak both the client's reconnect promise and
# the server's WSGI thread waiting on keepalives that the user no
# longer cares about. The agent loop emits ``end`` for normal /
# tool-paused completion and ``error`` for the catch-all failure path
# (which doesn't get a trailing ``end``).
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})
def _payload_is_terminal(
payload: object, event_type: Optional[str] = None
) -> bool:
"""True if ``payload['type']`` or ``event_type`` is a terminal sentinel."""
if isinstance(payload, dict) and payload.get("type") in _TERMINAL_EVENT_TYPES:
return True
return event_type in _TERMINAL_EVENT_TYPES
def format_sse_event(payload: dict, sequence_no: int) -> str:
"""Encode a journal event as one ``id:``/``data:`` SSE record.
The body is the payload's JSON serialisation. ``complete_stream``
payloads are flat JSON dicts with no embedded newlines, so a
single ``data:`` line is sufficient — but we still split on any
line terminator in case a future caller passes a multi-line string.
"""
body = json.dumps(payload)
lines = [f"id: {sequence_no}"]
for line in _SSE_LINE_SPLIT_PATTERN.split(body):
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
def _check_producer_liveness(
message_id: str, idle_seconds: float
) -> Optional[dict]:
"""Inspect ``conversation_messages`` and return a terminal SSE
payload when the producer is no longer alive, else ``None``.
Three terminal cases collapse into a single DB round-trip:
- ``status='complete'`` — the live finalize ran but its journal
terminal write didn't reach us (or never happened). Synthesise
``end`` so the client closes cleanly on the row's user-visible
state.
- ``status='failed'`` — same, but for the failure path. Carry the
stashed ``error`` from ``message_metadata`` so the UI shows the
real reason.
- non-terminal status and ``last_heartbeat_at`` (or ``timestamp``)
older than ``idle_seconds`` — the producing worker is gone.
Synthesise ``error`` so the client doesn't hang on keepalives
until the proxy idle-timeout kicks in.
"""
try:
with db_readonly() as conn:
row = conn.execute(
sql_text(
"""
SELECT
status,
message_metadata->>'error' AS err,
GREATEST(
timestamp,
COALESCE(
(message_metadata->>'last_heartbeat_at')
::timestamptz,
timestamp
)
) < now() - make_interval(secs => :idle_secs)
AS is_stale
FROM conversation_messages
WHERE id = CAST(:id AS uuid)
"""
),
{"id": message_id, "idle_secs": float(idle_seconds)},
).first()
except Exception:
logger.exception(
"Watchdog liveness check failed for message_id=%s", message_id
)
return None
if row is None:
# Row deleted out from under us — treat as terminal so the
# client doesn't keep tailing a message that no longer exists.
return {
"type": "error",
"error": "Message no longer exists; please refresh.",
"code": "message_missing",
"message_id": message_id,
}
status, err, is_stale = row[0], row[1], bool(row[2])
if status == "complete":
return {"type": "end"}
if status == "failed":
return {
"type": "error",
"error": err or "Stream failed; please try again.",
"code": "producer_failed",
"message_id": message_id,
}
if is_stale:
return {
"type": "error",
"error": (
"Stream producer is no longer responding; please try again."
),
"code": "producer_stale",
"message_id": message_id,
}
return None
def build_message_event_stream(
message_id: str,
last_event_id: Optional[int] = None,
*,
keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
) -> Iterator[str]:
"""Yield SSE-formatted lines for one ``message_id`` reconnect stream.
First frame is ``: connected``; subsequent frames are snapshot rows,
live-tail events, or ``: keepalive`` comments. Runs until the client
disconnects.
"""
yield ": connected\n\n"
# Replay buffer — populated inside ``_on_subscribe`` (or the
# Redis-unavailable fallback below), drained on the first iteration
# of the subscribe loop after the callback runs.
replay_buffer: list[str] = []
# Dedup floor: seeded with the client's cursor so an empty snapshot
# still rejects re-published live events with seq <= last_event_id.
# Advanced by snapshot rows AND by yielded live events, so any
# republish past the snapshot ceiling is also dropped.
max_replayed_seq: Optional[int] = last_event_id
replay_done = False
replay_failed = False
# Set when a snapshot row carries a terminal ``end`` / ``error``
# event. After flushing the buffer the generator returns; if we
# kept tailing we'd loop on keepalives forever for a stream that
# already finished.
terminal_in_snapshot = False
def _read_snapshot_into_buffer() -> None:
nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
try:
with db_readonly() as conn:
rows = MessageEventsRepository(conn).read_after(
message_id, last_sequence_no=last_event_id
)
for row in rows:
seq = int(row["sequence_no"])
payload = row.get("payload")
if not isinstance(payload, dict):
# ``record_event`` rejects non-dict payloads at the
# write gate, so this can only be a legacy row from
# before that contract or a direct SQL insert. The
# original synthetic fallback (``{"type": event_type}``)
# used to ship a malformed envelope here — drop the
# row instead so a corrupt journal entry doesn't
# poison a reconnect.
logger.warning(
"Skipping non-dict payload from message_events: "
"message_id=%s seq=%s type=%s",
message_id,
seq,
row.get("event_type"),
)
continue
replay_buffer.append(format_sse_event(payload, seq))
if max_replayed_seq is None or seq > max_replayed_seq:
max_replayed_seq = seq
if _payload_is_terminal(payload, row.get("event_type")):
terminal_in_snapshot = True
except Exception:
logger.exception(
"Snapshot read failed for message_id=%s last_event_id=%s",
message_id,
last_event_id,
)
replay_failed = True
def _on_subscribe() -> None:
# SUBSCRIBE has been acked — Postgres reads from this point
# capture every row that's been committed. Pub/sub messages
# published after this point are queued at the connection level
# until the outer loop calls ``get_message`` again.
nonlocal replay_done
try:
_read_snapshot_into_buffer()
finally:
# Flip even on failure so the outer loop continues to live
# tail and the client doesn't hang waiting for a snapshot
# flush that will never come.
replay_done = True
topic = Topic(message_topic_name(message_id))
last_keepalive = time.monotonic()
# Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
# tick after replay_done fire immediately so a snapshot-already-
# terminal-in-DB case is surfaced before any keepalive cadence.
# Subsequent checks are gated by ``watchdog_interval_seconds``.
last_watchdog_check = float("-inf")
# Synthetic terminal events emitted by the watchdog use the same
# ``sequence_no=-1`` convention as the snapshot-failure path so the
# frontend's strict ``\d+`` cursor regex rejects them as a
# ``Last-Event-ID`` for any future reconnect. The chosen
# discriminator ensures a manual page refresh after a watchdog-fired
# error doesn't loop on the same synthetic id.
watchdog_synthetic_seq = -1
try:
for payload in topic.subscribe(
on_subscribe=_on_subscribe,
poll_timeout=poll_timeout_seconds,
):
# Flush snapshot exactly once after the SUBSCRIBE callback
# has run and produced a buffer.
if replay_done and replay_buffer:
for line in replay_buffer:
yield line
replay_buffer.clear()
if terminal_in_snapshot:
# The original stream already finished; tailing
# would just emit keepalives forever and pin both a
# client reconnect promise and a server WSGI thread.
return
if replay_failed:
# Snapshot read failed (DB blip / transient timeout). Emit a
# terminal ``error`` event and return — the client only
# reconnects after the original stream has already moved on,
# so without a snapshot there's nothing live left to tail and
# holding the connection open would just emit keepalives
# until the proxy idle-timeout fires. ``code`` preserves the
# snapshot-vs-agent-loop distinction so a future client can
# opt into a refetch instead of a hard failure.
yield format_sse_event(
{
"type": "error",
"error": "Stream replay failed; please refresh to load the latest state.",
"code": "snapshot_failed",
"message_id": message_id,
},
sequence_no=-1,
)
return
now = time.monotonic()
if payload is None:
# Idle tick — check both keepalive and watchdog. The
# watchdog only kicks in once the snapshot half has been
# flushed (``replay_done``) so we don't race the
# snapshot read on the first iteration.
if (
replay_done
and watchdog_interval_seconds >= 0
and now - last_watchdog_check >= watchdog_interval_seconds
):
last_watchdog_check = now
terminal_payload = _check_producer_liveness(
message_id, producer_idle_seconds
)
if terminal_payload is not None:
yield format_sse_event(
terminal_payload,
sequence_no=watchdog_synthetic_seq,
)
return
if now - last_keepalive >= keepalive_seconds:
yield ": keepalive\n\n"
last_keepalive = now
continue
envelope = _decode_pubsub_message(payload)
if envelope is None:
continue
seq = envelope.get("sequence_no")
inner = envelope.get("payload")
if (
not isinstance(seq, int)
or isinstance(seq, bool)
or not isinstance(inner, dict)
):
continue
if max_replayed_seq is not None and seq <= max_replayed_seq:
# Snapshot already covered this id — drop the duplicate.
continue
yield format_sse_event(inner, seq)
# Advance the dedup floor on the live path too, so a stale
# republish of an already-yielded seq (process restart, retry
# tool, etc.) is dropped on a later iteration.
max_replayed_seq = seq
last_keepalive = now
if _payload_is_terminal(inner, envelope.get("event_type")):
# Live tail just delivered the terminal event — close
# out the reconnect stream so the client's drain
# promise resolves and the WSGI thread is freed.
return
# Subscribe exited without ever yielding (Redis unavailable,
# ``pubsub.subscribe`` raised, or the inner loop died between
# SUBSCRIBE-ack and the first poll). The snapshot half is in
# Postgres and is still serviceable — read it directly so a
# Redis-only outage doesn't cost the client their reconnect
# backlog. Gate the read on ``replay_done`` rather than
# ``subscribe_started``: if ``_on_subscribe`` already populated
# the buffer, re-reading would append the same rows twice and
# double the answer chunks on the client (the per-message
# reconnect dispatcher does not dedup by ``id``).
if not replay_done:
_read_snapshot_into_buffer()
replay_done = True
for line in replay_buffer:
yield line
replay_buffer.clear()
if replay_failed:
# Mirror the live-tail branch: emit a terminal ``error`` so
# the frontend's existing end/error handling drives the UI
# to a failed state instead of relying on the proxy timeout.
yield format_sse_event(
{
"type": "error",
"error": "Stream replay failed; please refresh to load the latest state.",
"code": "snapshot_failed",
"message_id": message_id,
},
sequence_no=-1,
)
return
# Same close-on-terminal contract as the live-tail branch.
# Without it a Redis-down + already-completed-stream client
# would also hang on a never-ending generator.
if terminal_in_snapshot:
return
except GeneratorExit:
# Client disconnect — let the underlying ``Topic.subscribe``
# ``finally`` block tear down its pubsub cleanly.
return
def _decode_pubsub_message(raw) -> Optional[dict]:
"""Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.
Returns ``None`` for malformed messages (drop silently — the
journal is still authoritative on reconnect).
"""
try:
if isinstance(raw, (bytes, bytearray)):
text_value = raw.decode("utf-8")
else:
text_value = str(raw)
envelope = json.loads(text_value)
except Exception:
return None
if not isinstance(envelope, dict):
return None
return envelope
def encode_pubsub_message(
message_id: str,
sequence_no: int,
event_type: str,
payload: dict,
) -> str:
"""Build the JSON envelope used for ``channel:{message_id}`` publishes.
Kept here (not in ``message_journal.py``) so the encode/decode pair
stays in one file — replay's ``_decode_pubsub_message`` and the
journal's publish must agree on the shape exactly.
"""
return json.dumps(
{
"message_id": str(message_id),
"sequence_no": int(sequence_no),
"event_type": event_type,
"payload": payload,
}
)

View File

@@ -1,19 +0,0 @@
"""Per-chat-message stream key derivations.
Single source of truth for the Redis pub/sub topic name and any
auxiliary keys that the chat-stream snapshot+tail reconnect path
shares between the writer (``complete_stream`` + journal) and the
reader (``/api/messages/<id>/events`` reconnect endpoint).
"""
from __future__ import annotations
def message_topic_name(message_id: str) -> str:
"""Redis pub/sub channel for live fan-out of one chat message.
Subscribers tail this topic for every event that ``complete_stream``
yielded after the SUBSCRIBE-ack arrived; older events are recovered
from the ``message_events`` snapshot half of the pattern.
"""
return f"channel:{message_id}"

View File

@@ -1,400 +0,0 @@
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
``record_event`` inserts into ``message_events`` and publishes to
``channel:{message_id}``. Both are best-effort; the INSERT commits
before the publish so a fast reconnect sees the row. See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import logging
import time
from typing import Any, Optional
from sqlalchemy.exc import IntegrityError
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
# of ``answer`` chunks per response; without batching, that's one PG
# transaction per yield in the WSGI thread. With these defaults, ~10x
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
# any event still sitting in the buffer.
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100
def _strip_null_bytes(value: Any) -> Any:
"""Recursively strip ``\\x00`` from string keys/values in ``value``.
Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
in a chunk would otherwise raise ``DataError`` at INSERT and the row
would be lost from the journal (live stream proceeds, reconnect
snapshot misses the chunk). Mirrors the strip already done in
``parser/embedding_pipeline.py`` and
``api/user/attachments/routes.py``.
"""
if isinstance(value, str):
return value.replace("\x00", "") if "\x00" in value else value
if isinstance(value, dict):
return {
(k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
_strip_null_bytes(v)
for k, v in value.items()
}
if isinstance(value, list):
return [_strip_null_bytes(item) for item in value]
if isinstance(value, tuple):
return tuple(_strip_null_bytes(item) for item in value)
return value
def record_event(
message_id: str,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Journal one SSE event and publish it live. Best-effort.
``payload`` must be a ``dict`` or ``None`` (non-dicts are dropped so
live and replay envelopes stay byte-identical). Returns ``True`` when
the journal INSERT committed. Never raises.
"""
if not message_id or not event_type:
logger.warning(
"record_event called without message_id/event_type "
"(message_id=%r, event_type=%r)",
message_id,
event_type,
)
return False
if payload is None:
materialised_payload: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised_payload = _strip_null_bytes(payload)
else:
logger.warning(
"record_event called with non-dict payload "
"(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
journal_committed = False
# The seq we actually managed to write. Diverges from
# ``sequence_no`` only on the IntegrityError-retry path below.
materialised_seq = sequence_no
try:
# Short-lived per-event transaction. Critical for visibility:
# the reconnect endpoint reads the journal from a separate
# connection and only sees committed rows.
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id, sequence_no, event_type, materialised_payload
)
journal_committed = True
except IntegrityError:
# Composite-PK collision on (message_id, sequence_no). Most
# likely cause is a stale ``latest_sequence_no`` seed on a
# continuation retry — the route read MAX(seq) from a separate
# connection before another writer committed past it. Look up
# the live latest and retry once with latest+1 so the event is
# not silently lost. Bounded to a single retry — if two
# writers keep racing in lockstep the route-level retry will
# converge them across attempts.
try:
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
message_id
)
materialised_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id,
materialised_seq,
event_type,
materialised_payload,
)
journal_committed = True
logger.info(
"record_event: collision at seq=%s recovered → wrote at "
"seq=%s message_id=%s type=%s",
sequence_no,
materialised_seq,
message_id,
event_type,
)
except IntegrityError:
# Second collision under the same retry — give up and log.
# The route's nonlocal counter will continue at
# ``sequence_no+1`` on the next emit; the next call may
# land cleanly past the contended window.
logger.warning(
"record_event: IntegrityError persists after seq+1 retry; "
"dropping. message_id=%s original_seq=%s retry_seq=%s "
"type=%s",
message_id,
sequence_no,
materialised_seq,
event_type,
)
except Exception:
logger.exception(
"record_event: retry path failed unexpectedly "
"(message_id=%s seq=%s type=%s)",
message_id,
sequence_no,
event_type,
)
except Exception:
logger.exception(
"message_events INSERT failed: message_id=%s seq=%s type=%s",
message_id,
sequence_no,
event_type,
)
try:
# Publish using ``materialised_seq`` so the live pubsub frame
# matches the journal row that other clients will snapshot on
# reconnect. The original POST stream's SSE ``id:`` still
# carries the caller's ``sequence_no`` — a reconnect from that
# client will receive the same event at ``materialised_seq``
# on the snapshot, which is a benign duplicate (the slice's
# ``max_replayed_seq`` advances past it). No-collision case:
# ``materialised_seq == sequence_no`` and this is identical to
# the prior behaviour.
wire = encode_pubsub_message(
message_id, materialised_seq, event_type, materialised_payload
)
Topic(message_topic_name(message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
message_id,
materialised_seq,
event_type,
)
return journal_committed
class BatchedJournalWriter:
"""Per-stream journal writer that batches PG INSERTs.
One writer per ``message_id``; ``record()`` buffers events and flushes
on size/time/``close()`` triggers. Pubsub publishes fire only after the
INSERT commits. On ``IntegrityError`` falls back to per-row writes.
"""
def __init__(
self,
message_id: str,
*,
batch_size: int = DEFAULT_BATCH_SIZE,
batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
) -> None:
self._message_id = message_id
self._batch_size = batch_size
self._batch_interval_ms = batch_interval_ms
self._buffer: list[tuple[int, str, dict[str, Any]]] = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
self._closed = False
def record(
self,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Buffer one event; maybe flush. Publish happens after journal commit."""
if self._closed:
logger.warning(
"BatchedJournalWriter.record after close: "
"message_id=%s seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
return False
if not event_type:
logger.warning(
"BatchedJournalWriter.record without event_type: "
"message_id=%s seq=%s",
self._message_id,
sequence_no,
)
return False
if payload is None:
materialised: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised = _strip_null_bytes(payload)
else:
# Same contract as ``record_event`` — non-dict payloads
# are rejected so the live and replay paths can't diverge
# on envelope reconstruction.
logger.warning(
"BatchedJournalWriter.record with non-dict payload: "
"message_id=%s seq=%s type=%s payload_type=%s — dropping",
self._message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
self._buffer.append((sequence_no, event_type, materialised))
if self._should_flush():
self.flush()
return True
def _should_flush(self) -> bool:
if len(self._buffer) >= self._batch_size:
return True
elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0
def flush(self) -> None:
"""Commit buffered events to PG. Best-effort.
Tries one bulk INSERT first; on ``IntegrityError`` (composite
PK collision — typically a stale continuation seed) falls back
to per-row ``record_event`` so one bad seq doesn't drop the
rest of the batch. Always clears the buffer to bound memory,
even on failure — a journaled event missing from a snapshot
is degraded UX, but a runaway buffer is corruption.
"""
if not self._buffer:
self._last_flush_mono_ms = time.monotonic() * 1000.0
return
# Snapshot and clear before the I/O so a concurrent record()
# call would land in a fresh buffer rather than racing the
# flush. ``complete_stream`` is single-threaded per stream, so
# this is belt-and-suspenders for any future change.
pending = self._buffer
self._buffer = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
try:
with db_session() as conn:
MessageEventsRepository(conn).bulk_record(
self._message_id, pending
)
except IntegrityError:
logger.info(
"BatchedJournalWriter: bulk INSERT collided for "
"message_id=%s n=%d; falling back to per-row writes",
self._message_id,
len(pending),
)
self._flush_per_row(pending)
return
except Exception:
logger.exception(
"BatchedJournalWriter: bulk INSERT failed for "
"message_id=%s n=%d; events dropped from journal",
self._message_id,
len(pending),
)
return
# Bulk INSERT committed — publish each frame in order. Best-effort:
# one failed publish must not poison the rest of the batch.
for seq, event_type, payload in pending:
self._publish(seq, event_type, payload)
def _flush_per_row(
self, pending: list[tuple[int, str, dict[str, Any]]]
) -> None:
"""Per-row fallback after a bulk collision. Publishes after each commit."""
for seq, event_type, payload in pending:
committed_seq: Optional[int] = None
try:
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, seq, event_type, payload
)
committed_seq = seq
except IntegrityError:
try:
with db_readonly() as conn:
latest = MessageEventsRepository(
conn
).latest_sequence_no(self._message_id)
retry_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, retry_seq, event_type, payload
)
committed_seq = retry_seq
except IntegrityError:
logger.warning(
"BatchedJournalWriter: IntegrityError persists "
"after seq+1 retry; dropping. message_id=%s "
"original_seq=%s type=%s",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row retry failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row INSERT failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
if committed_seq is not None:
self._publish(committed_seq, event_type, payload)
def _publish(
self, sequence_no: int, event_type: str, payload: dict[str, Any]
) -> None:
"""Publish one frame to the per-message pubsub channel. Best-effort."""
try:
wire = encode_pubsub_message(
self._message_id, sequence_no, event_type, payload
)
Topic(message_topic_name(self._message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
def close(self) -> None:
"""Final flush. Idempotent — safe to call from multiple
finally clauses.
"""
if self._closed:
return
self.flush()
self._closed = True

View File

@@ -1,5 +1,6 @@
import sys
import logging
import time
from datetime import datetime
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.session import db_session
@@ -19,15 +20,6 @@ def _serialize_for_token_count(value):
if value is None:
return ""
# Raw binary payloads (image/file attachments arrive as ``bytes`` from
# ``GoogleLLM.prepare_messages_with_attachments``) — without this
# branch they fall through to ``str(value)`` below, which produces a
# multi-megabyte ``"b'\\x89PNG...'"`` repr-string and inflates
# ``prompt_tokens`` by orders of magnitude. Same intent as the
# data-URL skip above.
if isinstance(value, (bytes, bytearray, memoryview)):
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
@@ -91,62 +83,33 @@ def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs)
return prompt_tokens
def _persist_call_usage(llm, call_usage):
"""Write one ``token_usage`` row per LLM call. Always-on; no flag.
Source defaults to ``agent_stream`` and can be overridden per
instance via ``_token_usage_source`` (set on side-channel LLMs:
title / compression / rag_condense / fallback). A ``_request_id``
stamped on the LLM lets ``count_in_range`` deduplicate the multiple
rows produced by a single multi-tool agent run.
"""
if call_usage["prompt_tokens"] == 0 and call_usage["generated_tokens"] == 0:
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
if "pytest" in sys.modules:
return
decoded_token = getattr(llm, "decoded_token", None)
user_id = (
decoded_token.get("sub") if isinstance(decoded_token, dict) else None
)
user_api_key = getattr(llm, "user_api_key", None)
agent_id = getattr(llm, "agent_id", None)
if not user_id and not user_api_key:
# Repository would raise on the attribution check — log instead
# so operators see the gap rather than crashing the stream.
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
normalized_agent_id = str(agent_id) if agent_id else None
if not user_id and not user_api_key and not normalized_agent_id:
logger.warning(
"token_usage skip: no user_id/api_key on LLM instance",
extra={
"source": getattr(llm, "_token_usage_source", "agent_stream"),
},
"Skipping token usage insert: missing user_id, api_key, and agent_id"
)
return
try:
with db_session() as conn:
# ``timestamp`` is omitted so Postgres ``server_default
# = func.now()`` populates a tz-aware UTC value; passing
# naive ``datetime.now()`` would silently shift on
# non-UTC servers.
TokenUsageRepository(conn).insert(
user_id=user_id,
api_key=user_api_key,
agent_id=str(agent_id) if agent_id else None,
prompt_tokens=call_usage["prompt_tokens"],
generated_tokens=call_usage["generated_tokens"],
source=(
getattr(llm, "_token_usage_source", None) or "agent_stream"
),
request_id=getattr(llm, "_request_id", None),
agent_id=normalized_agent_id,
prompt_tokens=token_usage["prompt_tokens"],
generated_tokens=token_usage["generated_tokens"],
timestamp=datetime.now(),
)
except Exception:
logger.exception("token_usage persist failed")
except Exception as e:
logger.error(f"Failed to record token usage: {e}", exc_info=True)
def gen_token_usage(func):
"""Accumulate per-call token counts and write a ``token_usage`` row.
The accumulator on ``self.token_usage`` stays in place for code
paths that introspect it (e.g., logging, response payloads). DB
persistence happens here for every call so primary streams,
side-channel LLMs, and no-save flows all produce rows uniformly.
"""
def wrapper(self, model, messages, stream, tools, **kwargs):
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@@ -160,14 +123,18 @@ def gen_token_usage(func):
call_usage["generated_tokens"] += _count_tokens(result)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
_persist_call_usage(self, call_usage)
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return result
return wrapper
def stream_token_usage(func):
"""Stream variant of ``gen_token_usage``. Same persistence contract."""
def wrapper(self, model, messages, stream, tools, **kwargs):
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@@ -178,36 +145,19 @@ def stream_token_usage(func):
**kwargs,
)
batch = []
started_at = time.monotonic()
error: BaseException | None = None
try:
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
except Exception as exc:
# ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
# flow through as ``status="ok"`` — same convention as
# ``application.logging._consume_and_log``.
error = exc
raise
finally:
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
_persist_call_usage(self, call_usage)
emit = getattr(self, "_emit_stream_finished_log", None)
if callable(emit):
try:
emit(
model,
prompt_tokens=call_usage["prompt_tokens"],
completion_tokens=call_usage["generated_tokens"],
latency_ms=int((time.monotonic() - started_at) * 1000),
error=error,
)
except Exception:
logger.exception("Failed to emit llm_stream_finished")
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return wrapper

View File

@@ -1,8 +1,5 @@
import logging
from typing import List, Optional, Any, Dict
from psycopg.types.json import Jsonb
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
from application.vectorstore.document_class import Document
@@ -178,7 +175,7 @@ class PGVectorStore(BaseVectorStore):
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, Jsonb(metadata), self._source_id)
(text, embedding, metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
@@ -269,7 +266,7 @@ class PGVectorStore(BaseVectorStore):
cursor.execute(
insert_query,
(text, embeddings[0], Jsonb(final_metadata), self._source_id)
(text, embeddings[0], final_metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()

View File

@@ -6,7 +6,6 @@ import os
import shutil
import string
import tempfile
import threading
from typing import Any, Dict
import zipfile
@@ -19,14 +18,11 @@ import requests
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import (
assert_index_complete,
embed_and_store_documents,
)
from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.parser.remote.remote_creator import RemoteCreator
@@ -36,9 +32,6 @@ from application.retriever.retriever_creator import RetrieverCreator
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
@@ -50,51 +43,6 @@ from application.utils import count_tokens_docs, num_tokens_from_string, safe_fi
MIN_TOKENS = 150
MAX_TOKENS = 1250
RECURSION_DEPTH = 2
INGEST_HEARTBEAT_INTERVAL_SECONDS = 30
# Re-exported here for backward-compatible imports
# (``from application.worker import _derive_source_id`` /
# ``DOCSGPT_INGEST_NAMESPACE``) from tests and any other in-tree callers.
# New code should import from ``application.storage.db.source_ids``
# directly to avoid pulling this Celery worker module into the API
# process at import time.
from application.storage.db.source_ids import ( # noqa: E402, F401
DOCSGPT_INGEST_NAMESPACE,
derive_source_id as _derive_source_id,
)
def _ingest_heartbeat_loop(source_id, stop_event, interval=INGEST_HEARTBEAT_INTERVAL_SECONDS):
"""Bump ``ingest_chunk_progress.last_updated`` until ``stop_event`` is set."""
while not stop_event.wait(interval):
try:
with db_session() as conn:
IngestChunkProgressRepository(conn).bump_heartbeat(source_id)
except Exception as e:
logging.warning(
f"Heartbeat failed for {source_id}: {e}", exc_info=True
)
def _start_ingest_heartbeat(source_id):
"""Spawn the heartbeat daemon and return ``(thread, stop_event)``."""
stop_event = threading.Event()
thread = threading.Thread(
target=_ingest_heartbeat_loop,
args=(str(source_id), stop_event),
daemon=True,
name=f"ingest-heartbeat-{source_id}",
)
thread.start()
return thread, stop_event
def _stop_ingest_heartbeat(thread, stop_event):
"""Signal the heartbeat daemon to exit and wait briefly for it."""
if stop_event is not None:
stop_event.set()
if thread is not None:
thread.join(timeout=5)
# Define a function to extract metadata from a given filename.
@@ -484,10 +432,7 @@ def run_agent_logic(agent_config, input_data):
"tool_calls": tool_calls,
"thought": thought,
}
# Per-activity summary fields (answer_length, thought_length,
# source_count, tool_call_count) now ride on the inner
# ``activity_finished`` event emitted by ``log_activity`` around
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
logging.info(f"Agent response: {result}")
return result
except Exception as e:
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)
@@ -507,8 +452,6 @@ def ingest_worker(
user,
retriever="classic",
file_name_map=None,
idempotency_key=None,
source_id=None,
):
"""
Ingest and process documents.
@@ -523,14 +466,6 @@ def ingest_worker(
user (str): Identifier for the user initiating the ingestion (original, unsanitized).
retriever (str): Type of retriever to use for processing the documents.
file_name_map (dict|str|None): Optional mapping of safe relative paths to original filenames.
idempotency_key (str|None): When provided, the ``source_id`` is derived
deterministically from the key so a retried task reuses the same
source row instead of duplicating it.
source_id (str|None): UUID minted by the HTTP route and returned in
its response. When supplied, the worker uses it verbatim so SSE
envelopes carry the same id the frontend already has — required
for non-idempotent uploads where the route can't predict
``_derive_source_id(idempotency_key)``.
Returns:
dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
@@ -545,41 +480,10 @@ def ingest_worker(
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
# Source id resolution order:
# 1. Caller-supplied ``source_id`` (HTTP route minted + returned to
# the frontend) — keeps the route response and the SSE event
# payloads in lockstep on the non-idempotent path.
# 2. Deterministic uuid5 from ``idempotency_key`` — retried tasks
# reuse the original source row instead of duplicating it.
# 3. Fresh uuid4 (caller has neither) — opaque, single-shot only.
if source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# Only emit ``queued`` on the original attempt. Celery retries re-run
# the body, and re-publishing here would oscillate the toast through
# ``queued`` again between ``failed`` and ``completed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"job_name": job_name,
"filename": filename,
"source_id": source_id_for_events,
"operation": "upload",
},
scope={"kind": "source", "id": source_id_for_events},
)
# Create temporary working directory
# Wrap the entire body in try/except so a failure between the
# ``queued`` publish above and the inner work (e.g. tempdir
# creation, OS-level resource exhaustion) still emits a terminal
# ``failed`` event rather than leaving the toast wedged on
# 'training' until the polling fallback rescues it 30s later.
try:
with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory() as temp_dir:
try:
os.makedirs(temp_dir, exist_ok=True)
if storage.is_directory(file_path):
@@ -668,22 +572,12 @@ def ingest_worker(
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
id = uuid.uuid4()
vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True)
heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(source_uuid)
try:
embed_and_store_documents(
docs, vector_store_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
finally:
_stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
# Defense-in-depth: chunk-progress is the authoritative
# record of how many chunks landed; mismatch raises so the
# task fails loud rather than caching a partial index.
assert_index_complete(source_uuid)
embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs)
@@ -698,7 +592,7 @@ def ingest_worker(
"user": user,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": "local",
"file_path": file_path,
"directory_structure": json.dumps(directory_structure),
@@ -707,36 +601,9 @@ def ingest_worker(
file_data["file_name_map"] = json.dumps(file_name_map)
upload_index(vector_store_path, file_data)
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"filename": filename,
"tokens": tokens,
"operation": "upload",
# Forward-looking contract: ``limited`` is always
# ``False`` today but is carried on the wire so a
# future token-cap detection path can flip it and
# the frontend slice / UploadToast already react.
"limited": False,
},
scope={"kind": "source", "id": source_id_for_events},
)
except Exception as e:
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"filename": filename,
"operation": "upload",
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
except Exception as e:
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
raise
return {
"directory": directory,
"formats": formats,
@@ -760,23 +627,7 @@ def reingest_source_worker(self, source_id, user):
Returns:
dict: Information about the re-ingestion task
Note:
Reingest does its own ``vector_store.add_chunk`` work rather
than going through ``embed_and_store_documents`` so it does
*not* emit per-percent SSE progress events — only ``queued``,
``completed`` (carrying ``chunks_added`` / ``chunks_deleted``),
or ``failed``. v1 limitation; revisit if reingest gains a
progress-driven UI.
"""
# Declared at the function scope so the outer except can include
# ``name`` in the failed event payload when the failure happens
# after the source lookup. Empty string until the lookup succeeds.
source_name = ""
# Tracks inner-block failures so a ``completed`` event reflects
# partial-success accurately rather than masking it.
inner_warnings: list[str] = []
try:
from application.vectorstore.vector_creator import VectorCreator
@@ -790,27 +641,6 @@ def reingest_source_worker(self, source_id, user):
if not source:
raise ValueError(f"Source {source_id} not found or access denied")
source_id = str(source["id"])
source_name = source.get("name") or ""
# Publish ``queued`` *after* canonicalising ``source_id`` so the
# event references the same id as the source row. Trade-off
# documented: a Celery-backend or PG-lookup hiccup before this
# publish means the toast may see only a ``failed`` event with
# no preceding ``queued`` — acceptable for v1 since both
# conditions also imply broader system trouble. Gate on first
# attempt only so Celery retries don't re-emit ``queued`` after
# a prior attempt already published ``failed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id,
"name": source_name,
"operation": "reingest",
},
scope={"kind": "source", "id": source_id},
)
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -908,19 +738,6 @@ def reingest_source_worker(self, source_id, user):
try:
if not added_files and not removed_files:
logging.info("No changes detected.")
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id,
"name": source_name,
"operation": "reingest",
"no_changes": True,
"chunks_added": 0,
"chunks_deleted": 0,
},
scope={"kind": "source", "id": source_id},
)
return {
"source_id": source_id,
"user": user,
@@ -972,9 +789,6 @@ def reingest_source_worker(self, source_id, user):
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
inner_warnings.append(
f"deletion failed: {str(e)[:200]}"
)
# 2) Add chunks from new files
added = 0
@@ -1067,9 +881,6 @@ def reingest_source_worker(self, source_id, user):
logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
inner_warnings.append(
f"add failed: {str(e)[:200]}"
)
# 3) Update source directory structure timestamp
try:
@@ -1098,25 +909,6 @@ def reingest_source_worker(self, source_id, user):
meta={"current": 100, "status": "Re-ingestion completed"},
)
completed_payload: dict = {
"source_id": source_id,
"name": source_name,
"operation": "reingest",
"chunks_added": added,
"chunks_deleted": deleted,
"tokens": int(total_tokens) if "total_tokens" in locals() else 0,
}
if inner_warnings:
# Surface the per-block failures so the toast can warn
# rather than claim a clean success.
completed_payload["warnings"] = inner_warnings
publish_user_event(
user,
"source.ingest.completed",
completed_payload,
scope={"kind": "source", "id": source_id},
)
return {
"source_id": source_id,
"user": user,
@@ -1134,17 +926,6 @@ def reingest_source_worker(self, source_id, user):
except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": str(source_id),
"name": source_name,
"operation": "reingest",
"error": str(e)[:1024],
},
scope={"kind": "source", "id": str(source_id)},
)
raise
@@ -1159,52 +940,12 @@ def remote_worker(
sync_frequency="never",
operation_mode="upload",
doc_id=None,
idempotency_key=None,
source_id=None,
):
safe_user = safe_filename(user)
full_path = os.path.join(directory, safe_user, uuid.uuid4().hex)
os.makedirs(full_path, exist_ok=True)
# Source id resolution order matches ``ingest_worker``:
# 1. ``operation_mode == "sync"`` reuses the existing source's ``doc_id``.
# 2. Caller-supplied ``source_id`` (the HTTP route minted it and
# already returned it to the frontend) — keeps the route
# response and the SSE event payloads in lockstep on the
# no-idempotency-key path.
# 3. Deterministic uuid5 from ``idempotency_key`` — retried tasks
# reuse the original source row instead of duplicating it.
# 4. Fresh uuid4 — opaque, single-shot only.
if operation_mode == "sync" and doc_id:
source_uuid = str(doc_id)
elif source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# Emit the queued event before any work that could fail (including
# ``update_state``) so the toast UI always sees a queued envelope
# before any subsequent failed event. Gated on first attempt so
# Celery retries don't re-emit ``queued`` after a prior ``failed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
},
scope={"kind": "source", "id": source_id_for_events},
)
# Wrap ``update_state`` plus the entire body so any pre-loader
# failure (Celery backend down, OS resource issue) still emits a
# terminal ``failed`` event rather than wedging the toast.
self.update_state(state="PROGRESS", meta={"current": 1})
try:
self.update_state(state="PROGRESS", meta={"current": 1})
logging.info("Initializing remote loader with type: %s", loader)
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
@@ -1291,22 +1032,14 @@ def remote_worker(
)
if operation_mode == "upload":
embed_and_store_documents(
docs, full_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
id = uuid.uuid4()
embed_and_store_documents(docs, full_path, id, self)
elif operation_mode == "sync":
if not doc_id:
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
raise ValueError("doc_id must be provided for sync operation.")
embed_and_store_documents(
docs, full_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
id = str(doc_id)
embed_and_store_documents(docs, full_path, id, self)
self.update_state(state="PROGRESS", meta={"current": 100})
# Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.)
@@ -1318,7 +1051,7 @@ def remote_worker(
"user": user,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": loader,
"remote_data": remote_data_serialized,
"sync_frequency": sync_frequency,
@@ -1332,49 +1065,23 @@ def remote_worker(
try:
with db_session() as conn:
repo = SourcesRepository(conn)
src = repo.get_any(source_id_for_events, user)
src = repo.get_any(str(id), user)
if src is not None:
repo.update(str(src["id"]), user, {"date": last_sync_now})
except Exception as upd_err:
logging.warning(
f"Failed to update last_sync for source {source_id_for_events}: {upd_err}"
f"Failed to update last_sync for source {id}: {upd_err}"
)
upload_index(full_path, file_data)
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
"tokens": tokens,
# Forward-looking contract: see ingest_worker.
"limited": False,
},
scope={"kind": "source", "id": source_id_for_events},
)
except Exception as e:
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
finally:
if os.path.exists(full_path):
shutil.rmtree(full_path)
logging.info("remote_worker task completed successfully")
return {
"id": source_id_for_events,
"id": str(id),
"urls": source_data,
"name_job": name_job,
"user": user,
@@ -1457,13 +1164,6 @@ def attachment_worker(self, file_info, user):
relative_path = file_info["path"]
metadata = file_info.get("metadata", {})
publish_user_event(
user,
"attachment.queued",
{"attachment_id": str(attachment_id), "filename": filename},
scope={"kind": "attachment", "id": str(attachment_id)},
)
try:
self.update_state(state="PROGRESS", meta={"current": 10})
storage = StorageCreator.get_storage()
@@ -1471,17 +1171,6 @@ def attachment_worker(self, file_info, user):
self.update_state(
state="PROGRESS", meta={"current": 30, "status": "Processing content"}
)
publish_user_event(
user,
"attachment.progress",
{
"attachment_id": str(attachment_id),
"filename": filename,
"current": 30,
"stage": "processing",
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
file_extractor = get_default_file_extractor(
ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
@@ -1514,17 +1203,6 @@ def attachment_worker(self, file_info, user):
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
)
publish_user_event(
user,
"attachment.progress",
{
"attachment_id": str(attachment_id),
"filename": filename,
"current": 80,
"stage": "storing",
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
@@ -1549,18 +1227,6 @@ def attachment_worker(self, file_info, user):
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
publish_user_event(
user,
"attachment.completed",
{
"attachment_id": str(attachment_id),
"filename": filename,
"token_count": token_count,
"mime_type": mime_type,
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
return {
"filename": filename,
"path": relative_path,
@@ -1575,24 +1241,20 @@ def attachment_worker(self, file_info, user):
extra={"user": user},
exc_info=True,
)
publish_user_event(
user,
"attachment.failed",
{
"attachment_id": str(attachment_id),
"filename": filename,
"error": str(e)[:1024],
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
raise
def agent_webhook_worker(self, agent_id, payload):
"""Process the webhook payload for an agent.
"""
Process the webhook payload for an agent.
Raises on failure: Celery treats a returned dict as success and
would skip retries, leaving the caller with a stale 200.
Args:
self: Reference to the instance of the task.
agent_id (str): Unique identifier for the agent.
payload (dict): The payload data from the webhook.
Returns:
dict: Information about the processed webhook.
"""
self.update_state(state="PROGRESS", meta={"current": 1})
try:
@@ -1618,13 +1280,13 @@ def agent_webhook_worker(self, agent_id, payload):
input_data = json.dumps(payload)
except Exception as e:
logging.error(f"Error processing agent webhook: {e}", exc_info=True)
raise
return {"status": "error", "error": str(e)}
self.update_state(state="PROGRESS", meta={"current": 50})
try:
result = run_agent_logic(agent_config, input_data)
except Exception as e:
logging.error(f"Error running agent logic: {e}", exc_info=True)
raise
return {"status": "error"}
else:
logging.info(
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
@@ -1647,8 +1309,6 @@ def ingest_connector(
operation_mode: str = "upload",
doc_id=None,
sync_frequency: str = "never",
idempotency_key=None,
source_id=None,
) -> Dict[str, Any]:
"""
Ingestion for internal knowledge bases (GoogleDrive, etc.).
@@ -1665,52 +1325,14 @@ def ingest_connector(
operation_mode: "upload" for initial ingestion, "sync" for incremental sync
doc_id: Document ID for sync operations (required when operation_mode="sync")
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
idempotency_key: When provided, the ``source_id`` is derived
deterministically so a retried upload reuses the same source row.
source_id: When supplied, the worker uses it verbatim so SSE envelopes
carry the same id the HTTP route already returned to the frontend
— required for non-idempotent uploads where the route can't
predict ``_derive_source_id(idempotency_key)``.
"""
logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
# Source id resolution mirrors ``ingest_worker`` / ``remote_worker``:
# sync mode reuses ``doc_id``; otherwise the caller-supplied
# ``source_id`` (minted by the HTTP route and already echoed to the
# client) wins; fall back to ``_derive_source_id`` only when neither
# is supplied. Without rule (2) the no-idempotency-key path would
# mint a fresh uuid4 here that the frontend has no way to correlate
# SSE envelopes to.
if operation_mode == "sync" and doc_id:
source_uuid = str(doc_id)
elif source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# First-attempt gate: Celery retries re-run the body, and a
# repeated ``queued`` here would oscillate the toast through
# ``queued`` again between ``failed`` and ``completed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
},
scope={"kind": "source", "id": source_id_for_events},
)
self.update_state(state="PROGRESS", meta={"current": 1})
try:
with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Step 1: Initialize the appropriate loader
self.update_state(
state="PROGRESS",
@@ -1748,22 +1370,6 @@ def ingest_connector(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}")
# Connector returned no files — surface as a benign
# ``completed`` event with zero tokens so the toast
# closes out cleanly instead of waiting on polling.
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"tokens": 0,
"no_changes": True,
},
scope={"kind": "source", "id": source_id_for_events},
)
# Create empty result directly instead of calling a separate method
return {
"name": job_name,
@@ -1813,16 +1419,16 @@ def ingest_connector(
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
# Validate operation_mode here too (the source_uuid path
# at the top of the function only branches on the
# sync+doc_id combination; surfacing the wrong-mode error
# this far in matches the legacy behaviour).
if operation_mode == "sync" and not doc_id:
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
if operation_mode not in ("upload", "sync"):
if operation_mode == "upload":
id = uuid.uuid4()
elif operation_mode == "sync":
if not doc_id:
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
id = str(doc_id)
else:
raise ValueError(f"Invalid operation_mode: {operation_mode}")
vector_store_path = os.path.join(temp_dir, "vector_store")
@@ -1831,12 +1437,7 @@ def ingest_connector(
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(
docs, vector_store_path, source_uuid, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs)
@@ -1846,7 +1447,7 @@ def ingest_connector(
"name": job_name,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": "connector:file",
"remote_data": json.dumps(
{"provider": source_type, **api_source_config}
@@ -1855,13 +1456,16 @@ def ingest_connector(
"sync_frequency": sync_frequency,
}
file_data["last_sync"] = datetime.datetime.now()
if operation_mode == "sync":
file_data["last_sync"] = datetime.datetime.now()
else:
file_data["last_sync"] = datetime.datetime.now()
if operation_mode == "sync":
try:
with db_session() as conn:
repo = SourcesRepository(conn)
src = repo.get_any(source_id_for_events, user)
src = repo.get_any(str(id), user)
if src is not None:
repo.update(
str(src["id"]), user,
@@ -1869,9 +1473,7 @@ def ingest_connector(
)
except Exception as upd_err:
logging.warning(
"Failed to update last_sync for source %s: %s",
source_id_for_events,
upd_err,
f"Failed to update last_sync for source {id}: {upd_err}"
)
upload_index(vector_store_path, file_data)
@@ -1883,104 +1485,45 @@ def ingest_connector(
logging.info(f"Remote ingestion completed: {job_name}")
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"tokens": tokens,
},
scope={"kind": "source", "id": source_id_for_events},
)
return {
"user": user,
"name": job_name,
"tokens": tokens,
"type": source_type,
"id": source_id_for_events,
"id": str(id),
"status": "complete",
}
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously.
Publishes SSE events at each phase boundary so the frontend can
drive the OAuth popup directly from the push channel. The
``mcp.oauth.awaiting_redirect`` envelope carries the
``authorization_url`` once the upstream OAuth client surfaces it,
eliminating the prior polling-only path for that URL.
"""
# Bind ``task_id`` and the publish helpers OUTSIDE the outer try so
# the ``except`` handler at the bottom can reach them even when an
# early statement raises. Without this, ``publish_oauth`` would
# UnboundLocalError on top of the original failure.
task_id = self.request.id if getattr(self, "request", None) else None
def publish_oauth(event_type: str, payload: Dict[str, Any]) -> None:
# MCP OAuth can be invoked without a route-bound user_id by
# legacy paths. Skip the SSE publish in that case \u2014 the caller
# has no per-user channel to subscribe to, and the status is
# surfaced via the task's return value.
if not user_id or task_id is None:
return
publish_user_event(
user_id,
event_type,
{"task_id": task_id, **payload},
scope={"kind": "mcp_oauth", "id": task_id},
)
def publish_awaiting_redirect(authorization_url: str) -> None:
"""Callback invoked by ``DocsGPTOAuth.redirect_handler`` once
the OAuth client has minted the authorization URL.
Carrying the URL on the SSE envelope lets the frontend open the
popup directly from the event \u2014 the prior polling-only path
for the URL is gone.
"""
publish_oauth(
"mcp.oauth.awaiting_redirect",
{
"message": "Awaiting OAuth redirect...",
"authorization_url": authorization_url,
},
)
"""Worker to handle MCP OAuth flow asynchronously."""
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
publish_oauth("mcp.oauth.in_progress", {"message": "Starting OAuth..."})
task_id = self.request.id
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
# Inject the awaiting-redirect publish callback. ``MCPTool`` pops
# it out of the config and threads it into ``DocsGPTOAuth`` so
# the publish fires synchronously from inside
# ``redirect_handler`` \u2014 the only point where the URL is known.
tool_config["oauth_redirect_publish"] = publish_awaiting_redirect
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
@@ -1988,6 +1531,14 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
mcp_tool._setup_client()
return await mcp_tool._execute_with_client("list_tools")
update_status(
{
"status": "awaiting_redirect",
"message": "Awaiting OAuth redirect...",
"task_id": task_id,
}
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -1995,21 +1546,49 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
loop.run_until_complete(run_oauth_discovery())
tools = mcp_tool.get_actions_metadata()
publish_oauth(
"mcp.oauth.completed",
{"tools": tools, "tools_count": len(tools)},
update_status(
{
"status": "completed",
"message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth failed: {str(e)}"
logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True)
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
update_status(
{
"status": "error",
"message": error_msg,
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
loop.close()
except Exception as e:
error_msg = f"OAuth init failed: {str(e)}"
logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True)
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
update_status(
{
"status": "error",
"message": error_msg,
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Check the status of an MCP OAuth flow."""
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
return json.loads(status_data)
return {"status": "not_found", "message": "Status not found"}

15
docs/package-lock.json generated
View File

@@ -4448,9 +4448,10 @@
}
},
"node_modules/@xmldom/xmldom": {
"version": "0.9.10",
"resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.9.10.tgz",
"integrity": "sha512-A9gOqLdi6cV4ibazAjcQufGj0B1y/vDqYrcuP6d/6x8P27gRS8643Dj9o1dEKtB6O7fwxb2FgBmJS2mX7gpvdw==",
"version": "0.9.9",
"resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.9.9.tgz",
"integrity": "sha512-qycIHAucxy/LXAYIjmLmtQ8q9GPnMbnjG1KXhWm9o5sCr6pOYDATkMPiTNa6/v8eELyqOQ2FsEqeoFYmgv/gJg==",
"deprecated": "this version has critical issues, please update to the latest version",
"license": "MIT",
"engines": {
"node": ">=14.6"
@@ -11834,12 +11835,12 @@
}
},
"node_modules/speech-rule-engine": {
"version": "4.1.4",
"resolved": "https://registry.npmjs.org/speech-rule-engine/-/speech-rule-engine-4.1.4.tgz",
"integrity": "sha512-i/VCLG1fvRc95pMHRqG4aQNscv+9aIsqA2oI7ZQS51sTdUcDHYX6cpT8/tqZ+enjs1tKVwbRBWgxut9SWn+f9g==",
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/speech-rule-engine/-/speech-rule-engine-4.1.3.tgz",
"integrity": "sha512-SBMgkuJYvP4F62daRfBNwYC2nXTEhNXAfsBZ/BB7Ly85/KnbnjmKM7/45ZrFbH6jIMiAliDUDPSZFUuXDvcg6A==",
"license": "Apache-2.0",
"dependencies": {
"@xmldom/xmldom": "0.9.10",
"@xmldom/xmldom": "0.9.9",
"commander": "13.1.0",
"wicked-good-xpath": "1.3.0"
},

View File

@@ -1,385 +0,0 @@
# SSE Notifications Runbook
> Operations guide for "user says they didn't get a notification" — and
> the related "the bell never lights up" / "my upload toast hangs" /
> "the chat answer doesn't reconnect" symptoms.
The user-facing notifications channel is the SSE pipe at
`/api/events` plus per-message reconnects at
`/api/messages/<id>/events`. This document maps a user complaint to
the diagnostic that surfaces the cause.
---
## TL;DR — first 60 seconds
Run these three commands in parallel before anything else:
```bash
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
redis-cli -n 2 PING
# 2) Anyone subscribed to the channel right now? Numbers per channel.
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>
# 3) Is the user's backlog populated? Returns the count of journaled events.
redis-cli -n 2 XLEN user:<user_id>:stream
```
- `PING` failing → Redis is the problem. Skip to "Redis-down".
- `NUMSUB user:<user_id>` returns 0 → no client connected. Skip to "Client never connects".
- `XLEN user:<user_id>:stream` returns 0 or low → publisher isn't writing. Skip to "Publisher silent".
- All three look healthy → the events are flowing on the wire; the issue is downstream of the slice (UI rendering, toast suppression, etc.). Skip to "Events flowing but UI silent".
---
## Architecture cheat-sheet
```
Worker (publish_user_event) Frontend tab
│ ▲
▼ │ GET /api/events SSE
Redis Streams: XADD Flask route
user:<id>:stream ──────────────► replay_backlog (snapshot)
│ +
▼ Topic.subscribe (live tail)
Redis pub/sub: PUBLISH │
user:<id> ────────────────────────────────┘
```
**Source of truth:**
- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
`EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~`. ~24h
at typical event rates.
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
subscribers must be attached at publish time.
The chat-stream pipe is separate, parallel infrastructure:
- Journal: Postgres `message_events` table.
- Live fan-out: Redis pub/sub `channel:<message_id>`.
Same patterns, different durability layer. This doc covers both;
they share most diagnostic commands.
---
## Symptom → diagnostic map
### A. "I uploaded a source and the toast never appeared"
User flow: chat → upload → expect toast.
| Step | Command | Expect |
| ------------------------------------------------- | ------------------------------------------------------------- | ----------------------------------------------- |
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
| Slice updated | Redux DevTools → state.upload.tasks | Task with matching `sourceId`, `status:'training'` |
If the worker's queued log line is there but the XADD didn't land →
look for a `publish_user_event payload not JSON-serializable` warning
in the worker log (the publisher swallows `TypeError`).
If the XADD landed but the frontend never received it → check
`PUBSUB NUMSUB user:<id>` while the user is on the page. If 0, the
SSE connection isn't subscribed; skip to "Client never connects".
If the frontend received it but the toast didn't render → the
`uploadSlice` extraReducer requires `task.sourceId` to match the
event's `scope.id`. Check the upload route returned `source_id` in
its POST response (the upload, connector, and reingest paths all
include it). Idempotent / cached responses must also include
`source_id` (`_claim_task_or_get_cached`).
### B. "The bell badge never goes up"
There is no bell — the global notifications surface is per-event
toasts, not an aggregated counter. If the user is on an old build,
`Cmd-Shift-R` to bypass cache. The surfaces they're looking for are
`UploadToast` for source uploads and `ToolApprovalToast` for
tool-approval events.
### C. "My chat answer froze mid-stream and never recovered"
User flow: ask question → answer streaming → network blip → answer
stops; should reconnect.
```bash
# Was the original message reserved in PG?
psql -c "SELECT id, status, prompt FROM conversation_messages \
WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"
# Did the journal capture events past the user's last-seen seq?
psql -c "SELECT sequence_no, event_type FROM message_events \
WHERE message_id = '<id>' ORDER BY sequence_no;"
# Is the live tail still producing? (subscribe and watch)
redis-cli -n 2 SUBSCRIBE channel:<message_id>
```
The frontend should reconnect via `GET /api/messages/<id>/events`
when the original POST stream closes without a typed `end` or
`error` event. If it's not reconnecting, `console.warn('Stream
reconnect failed', ...)` will be in the browser console — the
reconnect HTTP errored. Common cases:
- The user's JWT rotated mid-stream → 401 on the GET. Frontend
doesn't auto-refresh; the user reloads.
- The user is on a different host than the API and CORS is rejecting
the GET → check `application/asgi.py` allow-headers.
### D. "The dev install never delivers any notifications at all"
Default `AUTH_TYPE` unset means `decoded_token = {"sub": "local"}`
for every request. The SSE client connects without the
`Authorization` header in this case, and `user:local:stream` is
the shared channel everything goes to. If the user has multiple dev
machines pointing at the same Redis, they will see each other's
events. Confirm with:
```bash
redis-cli -n 2 KEYS 'user:local:*'
```
If multiple deployments share the Redis, document that as a known
multi-user-on-local-channel limitation. Set `AUTH_TYPE=simple_jwt`
to scope per-user.
### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"
Likely path: `backlog.truncated` event fired, the slice cleared
`lastEventId` to null, the closure was carrying the same stale id and
re-tripped the same truncation on every reconnect. **Verify the user
is on a current build — `eventStreamClient.ts` must re-read
`lastEventId = opts.getLastEventId();` without a truthy guard so the
null clear propagates into the next reconnect.**
### F. "I keep getting 429 on /api/events"
The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
default 8) refused the connection. User has too many tabs open or a
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
shows the live counter; the TTL is 1h from the last connection
attempt (rolling — every INCR re-seeds it), so the key only ages
out after the user stops reconnecting for a full hour.
If the count is wedged high without explanation, the
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
for the TTL or `redis-cli -n 2 DEL user:<id>:sse_count` to reset.
### G. "Replay snapshot stops at 200 events"
The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
(default 200). The cap is intentionally **silent** — the route does NOT
emit a `backlog.truncated` notice for cap-hit. The 200 entries each
carry their own `id:` header, so the frontend's slice cursor
advances to the most-recent delivered id. Next reconnect sends
`last_event_id=<max_replayed>` and the snapshot resumes from there.
A user that was 1000 entries behind catches up over ~5 reconnects.
If the user reports getting HTTP 429 on `/api/events` despite being
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
30 / `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` 60s). The route refuses
the connection so the slice cursor stays pinned at whatever value
it had; the frontend backs off and the next reconnect (after the
window rolls) gets the proper snapshot. Serving the live tail
without a snapshot used to be the behavior here, but that let the
client advance `lastEventId` past entries it never received,
permanently stranding the un-replayed window — so the route now
429s instead. `redis-cli -n 2 GET user:<id>:replay_count` shows the
current counter; TTL is the window size.
`backlog.truncated` is emitted ONLY when the client's
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
is genuinely gone past the cursor and the frontend should clear the
slice cursor and refetch state. Treating cap-hit or
budget-exhaustion the same way would lock the user into re-receiving
the oldest 200 entries on every reconnect (the cursor would clear,
the snapshot would re-serve from the start, the cap would re-trip).
### H. "User says push notifications stopped after a deploy"
- Pull `event.published topic=user:<id> type=...` from the worker
logs to confirm the publisher is still firing.
- Pull `event.connect user=<id>` from the API logs to confirm the
client is reconnecting.
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)`
if the deploy reduced worker count, the per-user cap is still 8
but total concurrent SSE connections are bounded by `gunicorn
workers × 32`. A capacity miss looks like users randomly getting
429'd.
---
## Common failure modes
### Redis-down
Symptoms: `/api/events` returns 200 but emits only `: connected`
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
`record_event` swallows the failure and returns False; the live tail
publish also drops on the floor. Frontend retries forever with
exponential backoff.
Resolution: bring Redis back. The journal is gone (was in-memory
only — Streams persist within a single Redis instance, no replication
configured). New events flow as soon as Redis comes back.
### `AUTH_TYPE` misconfigured = sub:"local" cross-stream
Symptoms: every user shares `user:local:stream`. Any user sees
everyone else's notifications.
Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
The events route logs a one-time WARNING per process when
`sub == "local"` is observed. A repeat WARNING after a restart
confirms the misconfiguration.
### MAXLEN trimmed past Last-Event-ID
Symptoms: client reconnects with `last_event_id=X`, snapshot returns
the entire MAXLEN'd backlog (because X is older than the oldest
retained entry). Old events appear duplicated.
Detection: the route's `_oldest_retained_id` check emits
`backlog.truncated` when this case fires. Frontend's
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
fresh.
If the WARNING isn't firing but symptoms match: the user's client
may have a corrupt cached `lastEventId`. `localStorage` doesn't
store this state; check Redux state via DevTools.
### Stale event-stream client
Symptoms: events visible in `XRANGE` but the frontend slice doesn't
update.
```bash
# Is the client subscribed?
redis-cli -n 2 PUBSUB NUMSUB user:<id>
# When did its connection start?
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
```
If `NUMSUB` is 0 and no recent `event.connect`, the user's tab is
closed or the connection died and never reconnected. Push them to
reload.
### Publisher silent
Symptoms: worker is processing the task (Celery says SUCCESS), but
no XADD and no PUBLISH. User sees no events.
```bash
# Was the publisher import error suppressed?
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20
# Is push disabled?
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
```
`ENABLE_SSE_PUSH=False` in `.env` would silence the publisher
globally. Useful for incident response if a runaway publisher is
DoS'ing Redis; toggle off, fix root cause, toggle on.
---
## Useful one-liners
```bash
# Watch a user's live event stream in real time (all events, all types)
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"
# Last 10 events the user would see on reconnect
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10
# Live count of subscribed clients per user
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')
# Trim a runaway stream (CAREFUL — destroys backlog for all current
# subscribers; OK after explaining to the user)
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0
# Clear a wedged concurrent-connection counter
redis-cli -n 2 DEL user:<id>:sse_count
# Force-flip every client to re-snapshot (drop the stream key entirely
# — destroys the backlog; clients reconnect with their last id and
# get a backlog.truncated)
redis-cli -n 2 DEL user:<id>:stream
```
---
## Settings reference
Everything in `application/core/settings.py`:
| Setting | Default | Purpose |
| --------------------------------------------- | ------- | --------------------------------------------- |
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves "push_disabled" comment. |
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under reverse-proxy idle close. |
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; `cleanup_message_events` beat task deletes older rows. |
---
## Known limitations
### Each tab runs its own SSE connection
There is no cross-tab dedup. Every tab open to the app holds its
own SSE connection and dispatches every received event into its
own Redux store, so a user with N tabs open will see N copies of
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
heavy multi-tab user can also hit the connection cap and start
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
`navigator.locks`-based leader election is tracked as future work.
### `/c/<unknown-id>` normalises to `/c/new`
If a user navigates to a conversation id that isn't in their
loaded list, the conversation route rewrites the URL to `/c/new`.
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
so for the brief window after the rewrite the toast may surface
for a conversation the user *thought* they were already viewing.
Pre-existing route behaviour; not a notifications regression.
### Terminal events un-dismiss running uploads
`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
an upload reaches `completed` or `failed`. If the user dismissed a
running task and the terminal SSE arrives later, the toast pops
back. Intentional ("notify the user it's done"); revisit if the
re-surface UX is too aggressive for v2.
### Werkzeug doesn't auto-reload route files
The dev server (`flask run`) doesn't watch
`application/api/events/routes.py` for changes by default.
After editing the route, restart Flask manually — `--reload`
isn't on. (Production gunicorn reloads via deploy.)
### MCP OAuth completion can fall outside the user stream's MAXLEN window
`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries via `XREVRANGE`. If the user has a high-rate ingest running concurrent with the OAuth handshake, the `mcp.oauth.completed` envelope can be trimmed off the back before they click Save. Symptom: backend returns "OAuth failed or not completed" even though the popup completed successfully.
Mitigation today: bump `EVENTS_STREAM_MAXLEN` per-deployment if your users routinely flood the channel during OAuth flows. A dedicated short-TTL Redis key for OAuth task results is tracked as a follow-up.
### React StrictMode double-mounts SSE
In dev, React 18 StrictMode mounts → unmounts → remounts every
component, briefly opening two SSE connections per tab before the
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8` and 45
tabs open concurrently you can transiently hit the cap and see
HTTP 429 on cold-load. The first connection's counter increment
fires before the AbortController-induced disconnect can decrement
it. Production (single mount, no StrictMode) is unaffected; raise
the cap in dev or accept transient 429s.

View File

@@ -19,7 +19,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.0",
"dompurify": "^3.1.5",
"flow-bin": "^0.311.0",
"flow-bin": "^0.309.0",
"markdown-it": "^14.1.0",
"react": "^19.2.5",
"react-dom": "^19.2.5",
@@ -44,7 +44,7 @@
"eslint-plugin-prettier": "^5.5.5",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-unused-imports": "^4.4.1",
"globals": "^17.5.0",
"globals": "^15.15.0",
"parcel": "^2.16.4",
"prettier": "^3.8.1",
"process": "^0.11.10",
@@ -546,13 +546,12 @@
}
},
"node_modules/@babel/plugin-syntax-jsx": {
"version": "7.28.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-syntax-jsx/-/plugin-syntax-jsx-7.28.6.tgz",
"integrity": "sha512-wgEmr06G6sIpqr8YDwA2dSRTE3bJ+V0IfpzfSY3Lfgd7YWOaAdlykvJi13ZKBt8cZHfgH1IXN+CL656W3uUa4w==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-syntax-jsx/-/plugin-syntax-jsx-7.24.6.tgz",
"integrity": "sha512-lWfvAIFNWMlCsU0DRUun2GpFwZdGTukLaHJqRh1JRb80NdAP5Sb1HDHB5X9P9OtgZHQl089UzQkpYlBq2VTPRw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/helper-plugin-utils": "^7.28.6"
"@babel/helper-plugin-utils": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -1253,13 +1252,12 @@
}
},
"node_modules/@babel/plugin-transform-react-display-name": {
"version": "7.28.0",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-display-name/-/plugin-transform-react-display-name-7.28.0.tgz",
"integrity": "sha512-D6Eujc2zMxKjfa4Zxl4GHMsmhKKZ9VpcqIchJLvwTxad9zWIYulwYItBovpDOoNLISpcZSXoDJ5gaGbQUDqViA==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-display-name/-/plugin-transform-react-display-name-7.24.6.tgz",
"integrity": "sha512-/3iiEEHDsJuj9QU09gbyWGSUxDboFcD7Nj6dnHIlboWSodxXAoaY/zlNMHeYAC0WsERMqgO9a7UaM77CsYgWcg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/helper-plugin-utils": "^7.27.1"
"@babel/helper-plugin-utils": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -1269,17 +1267,16 @@
}
},
"node_modules/@babel/plugin-transform-react-jsx": {
"version": "7.28.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.28.6.tgz",
"integrity": "sha512-61bxqhiRfAACulXSLd/GxqmAedUSrRZIu/cbaT18T1CetkTmtDN15it7i80ru4DVqRK1WMxQhXs+Lf9kajm5Ow==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.24.6.tgz",
"integrity": "sha512-pCtPHhpRZHfwdA5G1Gpk5mIzMA99hv0R8S/Ket50Rw+S+8hkt3wBWqdqHaPw0CuUYxdshUgsPiLQ5fAs4ASMhw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/helper-annotate-as-pure": "^7.27.3",
"@babel/helper-module-imports": "^7.28.6",
"@babel/helper-plugin-utils": "^7.28.6",
"@babel/plugin-syntax-jsx": "^7.28.6",
"@babel/types": "^7.28.6"
"@babel/helper-annotate-as-pure": "^7.24.6",
"@babel/helper-module-imports": "^7.24.6",
"@babel/helper-plugin-utils": "^7.24.6",
"@babel/plugin-syntax-jsx": "^7.24.6",
"@babel/types": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -1289,13 +1286,12 @@
}
},
"node_modules/@babel/plugin-transform-react-jsx-development": {
"version": "7.27.1",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-development/-/plugin-transform-react-jsx-development-7.27.1.tgz",
"integrity": "sha512-ykDdF5yI4f1WrAolLqeF3hmYU12j9ntLQl/AOG1HAS21jxyg1Q0/J/tpREuYLfatGdGmXp/3yS0ZA76kOlVq9Q==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-development/-/plugin-transform-react-jsx-development-7.24.6.tgz",
"integrity": "sha512-F7EsNp5StNDouSSdYyDSxh4J+xvj/JqG+Cb6s2fA+jCyHOzigG5vTwgH8tU2U8Voyiu5zCG9bAK49wTr/wPH0w==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/plugin-transform-react-jsx": "^7.27.1"
"@babel/plugin-transform-react-jsx": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -1305,14 +1301,13 @@
}
},
"node_modules/@babel/plugin-transform-react-pure-annotations": {
"version": "7.27.1",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-pure-annotations/-/plugin-transform-react-pure-annotations-7.27.1.tgz",
"integrity": "sha512-JfuinvDOsD9FVMTHpzA/pBLisxpv1aSf+OIV8lgH3MuWrks19R27e6a6DipIg4aX1Zm9Wpb04p8wljfKrVSnPA==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-pure-annotations/-/plugin-transform-react-pure-annotations-7.24.6.tgz",
"integrity": "sha512-0HoDQlFJJkXRyV2N+xOpUETbKHcouSwijRQbKWVtxsPoq5bbB30qZag9/pSc5xcWVYjTHlLsBsY+hZDnzQTPNw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/helper-annotate-as-pure": "^7.27.1",
"@babel/helper-plugin-utils": "^7.27.1"
"@babel/helper-annotate-as-pure": "^7.24.6",
"@babel/helper-plugin-utils": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -1618,18 +1613,17 @@
}
},
"node_modules/@babel/preset-react": {
"version": "7.28.5",
"resolved": "https://registry.npmjs.org/@babel/preset-react/-/preset-react-7.28.5.tgz",
"integrity": "sha512-Z3J8vhRq7CeLjdC58jLv4lnZ5RKFUJWqH5emvxmv9Hv3BD1T9R/Im713R4MTKwvFaV74ejZ3sM01LyEKk4ugNQ==",
"version": "7.24.6",
"resolved": "https://registry.npmjs.org/@babel/preset-react/-/preset-react-7.24.6.tgz",
"integrity": "sha512-8mpzh1bWvmINmwM3xpz6ahu57mNaWavMm+wBNjQ4AFu1nghKBiIRET7l/Wmj4drXany/BBGjJZngICcD98F1iw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@babel/helper-plugin-utils": "^7.27.1",
"@babel/helper-validator-option": "^7.27.1",
"@babel/plugin-transform-react-display-name": "^7.28.0",
"@babel/plugin-transform-react-jsx": "^7.27.1",
"@babel/plugin-transform-react-jsx-development": "^7.27.1",
"@babel/plugin-transform-react-pure-annotations": "^7.27.1"
"@babel/helper-plugin-utils": "^7.24.6",
"@babel/helper-validator-option": "^7.24.6",
"@babel/plugin-transform-react-display-name": "^7.24.6",
"@babel/plugin-transform-react-jsx": "^7.24.6",
"@babel/plugin-transform-react-jsx-development": "^7.24.6",
"@babel/plugin-transform-react-pure-annotations": "^7.24.6"
},
"engines": {
"node": ">=6.9.0"
@@ -4546,17 +4540,17 @@
"devOptional": true
},
"node_modules/@typescript-eslint/eslint-plugin": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.59.1.tgz",
"integrity": "sha512-BOziFIfE+6osHO9FoJG4zjoHUcvI7fTNBSpdAwrNH0/TLvzjsk2oo8XSSOT2HhqUyhZPfHv4UOffoJ9oEEQ7Ag==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.59.0.tgz",
"integrity": "sha512-HyAZtpdkgZwpq8Sz3FSUvCR4c+ScbuWa9AksK2Jweub7w4M3yTz4O11AqVJzLYjy/B9ZWPyc81I+mOdJU/bDQw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@eslint-community/regexpp": "^4.12.2",
"@typescript-eslint/scope-manager": "8.59.1",
"@typescript-eslint/type-utils": "8.59.1",
"@typescript-eslint/utils": "8.59.1",
"@typescript-eslint/visitor-keys": "8.59.1",
"@typescript-eslint/scope-manager": "8.59.0",
"@typescript-eslint/type-utils": "8.59.0",
"@typescript-eslint/utils": "8.59.0",
"@typescript-eslint/visitor-keys": "8.59.0",
"ignore": "^7.0.5",
"natural-compare": "^1.4.0",
"ts-api-utils": "^2.5.0"
@@ -4569,22 +4563,22 @@
"url": "https://opencollective.com/typescript-eslint"
},
"peerDependencies": {
"@typescript-eslint/parser": "^8.59.1",
"@typescript-eslint/parser": "^8.59.0",
"eslint": "^8.57.0 || ^9.0.0 || ^10.0.0",
"typescript": ">=4.8.4 <6.1.0"
}
},
"node_modules/@typescript-eslint/parser": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.59.1.tgz",
"integrity": "sha512-HDQH9O/47Dxi1ceDhBXdaldtf/WV9yRYMjbjCuNk3qnaTD564qwv61Y7+gTxwxRKzSrgO5uhtw584igXVuuZkA==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.59.0.tgz",
"integrity": "sha512-TI1XGwKbDpo9tRW8UDIXCOeLk55qe9ZFGs8MTKU6/M08HWTw52DD/IYhfQtOEhEdPhLMT26Ka/x7p70nd3dzDg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/scope-manager": "8.59.1",
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/typescript-estree": "8.59.1",
"@typescript-eslint/visitor-keys": "8.59.1",
"@typescript-eslint/scope-manager": "8.59.0",
"@typescript-eslint/types": "8.59.0",
"@typescript-eslint/typescript-estree": "8.59.0",
"@typescript-eslint/visitor-keys": "8.59.0",
"debug": "^4.4.3"
},
"engines": {
@@ -4600,14 +4594,14 @@
}
},
"node_modules/@typescript-eslint/project-service": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.59.1.tgz",
"integrity": "sha512-+MuHQlHiEr00Of/IQbE/MmEoi44znZHbR/Pz7Opq4HryUOlRi+/44dro9Ycy8Fyo+/024IWtw8m4JUMCGTYxDg==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.59.0.tgz",
"integrity": "sha512-Lw5ITrR5s5TbC19YSvlr63ZfLaJoU6vtKTHyB0GQOpX0W7d5/Ir6vUahWi/8Sps/nOukZQ0IB3SmlxZnjaKVnw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/tsconfig-utils": "^8.59.1",
"@typescript-eslint/types": "^8.59.1",
"@typescript-eslint/tsconfig-utils": "^8.59.0",
"@typescript-eslint/types": "^8.59.0",
"debug": "^4.4.3"
},
"engines": {
@@ -4622,14 +4616,14 @@
}
},
"node_modules/@typescript-eslint/scope-manager": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.59.1.tgz",
"integrity": "sha512-LwuHQI4pDOYVKvmH2dkaJo6YZCSgouVgnS/z7yBPKBMvgtBvyLqiLy9Z6b7+m/TRcX1NFYUqZetI5Y+aT4GEfg==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.59.0.tgz",
"integrity": "sha512-UzR16Ut8IpA3Mc4DbgAShlPPkVm8xXMWafXxB0BocaVRHs8ZGakAxGRskF7FId3sdk9lgGD73GSFaWmWFDE4dg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/visitor-keys": "8.59.1"
"@typescript-eslint/types": "8.59.0",
"@typescript-eslint/visitor-keys": "8.59.0"
},
"engines": {
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
@@ -4640,9 +4634,9 @@
}
},
"node_modules/@typescript-eslint/tsconfig-utils": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.59.1.tgz",
"integrity": "sha512-/0nEyPbX7gRsk0Uwfe4ALwwgxuA66d/l2mhRDNlAvaj4U3juhUtJNq0DsY8M2AYwwb9rEq2hrC3IcIcEt++iJA==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.59.0.tgz",
"integrity": "sha512-91Sbl3s4Kb3SybliIY6muFBmHVv+pYXfybC4Oolp3dvk8BvIE3wOPc+403CWIT7mJNkfQRGtdqghzs2+Z91Tqg==",
"dev": true,
"license": "MIT",
"engines": {
@@ -4657,15 +4651,15 @@
}
},
"node_modules/@typescript-eslint/type-utils": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.59.1.tgz",
"integrity": "sha512-klWPBR2ciQHS3f++ug/mVnWKPjBUo7icEL3FAO1lhAR1Z1i5NQYZ1EannMSRYcq5qCv5wNALlXr6fksRHyYl7w==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.59.0.tgz",
"integrity": "sha512-3TRiZaQSltGqGeNrJzzr1+8YcEobKH9rHnqIp/1psfKFmhRQDNMGP5hBufanYTGznwShzVLs3Mz+gDN7HkWfXg==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/typescript-estree": "8.59.1",
"@typescript-eslint/utils": "8.59.1",
"@typescript-eslint/types": "8.59.0",
"@typescript-eslint/typescript-estree": "8.59.0",
"@typescript-eslint/utils": "8.59.0",
"debug": "^4.4.3",
"ts-api-utils": "^2.5.0"
},
@@ -4682,9 +4676,9 @@
}
},
"node_modules/@typescript-eslint/types": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.59.1.tgz",
"integrity": "sha512-ZDCjgccSdYPw5Bxh+my4Z0lJU96ZDN7jbBzvmEn0FZx3RtU1C7VWl6NbDx94bwY3V5YsgwRzJPOgeY2Q/nLG8A==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.59.0.tgz",
"integrity": "sha512-nLzdsT1gdOgFxxxwrlNVUBzSNBEEHJ86bblmk4QAS6stfig7rcJzWKqCyxFy3YRRHXDWEkb2NralA1nOYkkm/A==",
"dev": true,
"license": "MIT",
"engines": {
@@ -4696,16 +4690,16 @@
}
},
"node_modules/@typescript-eslint/typescript-estree": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.59.1.tgz",
"integrity": "sha512-OUd+vJS05sSkOip+BkZ/2NS8RMxrAAJemsC6vU3kmfLyeaJT0TftHkV9mcx2107MmsBVXXexhVu4F0TZXyMl4g==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.59.0.tgz",
"integrity": "sha512-O9Re9P1BmBLFJyikRbQpLku/QA3/AueZNO9WePLBwQrvkixTmDe8u76B6CYUAITRl/rHawggEqUGn5QIkVRLMw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/project-service": "8.59.1",
"@typescript-eslint/tsconfig-utils": "8.59.1",
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/visitor-keys": "8.59.1",
"@typescript-eslint/project-service": "8.59.0",
"@typescript-eslint/tsconfig-utils": "8.59.0",
"@typescript-eslint/types": "8.59.0",
"@typescript-eslint/visitor-keys": "8.59.0",
"debug": "^4.4.3",
"minimatch": "^10.2.2",
"semver": "^7.7.3",
@@ -4737,16 +4731,16 @@
}
},
"node_modules/@typescript-eslint/utils": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.59.1.tgz",
"integrity": "sha512-3pIeoXhCeYH9FSCBI8P3iNwJlGuzPlYKkTlen2O9T1DSeeg8UG8jstq6BLk+Mda0qup7mgk4z4XL4OzRaxZ8LA==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.59.0.tgz",
"integrity": "sha512-I1R/K7V07XsMJ12Oaxg/O9GfrysGTmCRhvZJBv0RE0NcULMzjqVpR5kRRQjHsz3J/bElU7HwCO7zkqL+MSUz+g==",
"dev": true,
"license": "MIT",
"dependencies": {
"@eslint-community/eslint-utils": "^4.9.1",
"@typescript-eslint/scope-manager": "8.59.1",
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/typescript-estree": "8.59.1"
"@typescript-eslint/scope-manager": "8.59.0",
"@typescript-eslint/types": "8.59.0",
"@typescript-eslint/typescript-estree": "8.59.0"
},
"engines": {
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
@@ -4761,13 +4755,13 @@
}
},
"node_modules/@typescript-eslint/visitor-keys": {
"version": "8.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.59.1.tgz",
"integrity": "sha512-LdDNl6C5iJExcM0Yh0PwAIBb9PrSiCsWamF/JyEZawm3kFDnRoaq3LGE4bpyRao/fWeGKKyw7icx0YxrLFC5Cg==",
"version": "8.59.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.59.0.tgz",
"integrity": "sha512-/uejZt4dSere1bx12WLlPfv8GktzcaDtuJ7s42/HEZ5zGj9oxRaD4bj7qwSunXkf+pbAhFt2zjpHYUiT5lHf0Q==",
"dev": true,
"license": "MIT",
"dependencies": {
"@typescript-eslint/types": "8.59.1",
"@typescript-eslint/types": "8.59.0",
"eslint-visitor-keys": "^5.0.0"
},
"engines": {
@@ -6492,9 +6486,9 @@
"license": "ISC"
},
"node_modules/flow-bin": {
"version": "0.311.0",
"resolved": "https://registry.npmjs.org/flow-bin/-/flow-bin-0.311.0.tgz",
"integrity": "sha512-4lXxjhPdmkeizju3F0HDCMYGkoL7hiq0W9bAW4pQpQTi56op+QZrVyMENjbCGZc+KlFBLwWkur+EkyfPTsa6xw==",
"version": "0.309.0",
"resolved": "https://registry.npmjs.org/flow-bin/-/flow-bin-0.309.0.tgz",
"integrity": "sha512-/RH68gcCY8OHzcdSVTUCw+fhDSEYmNHoovfK0EcbB4rs1Xbc5HhxhHTvr7U+h55De4bDRlE52ghH23MRP625cQ==",
"license": "MIT",
"bin": {
"flow": "cli.js"
@@ -6657,9 +6651,9 @@
}
},
"node_modules/globals": {
"version": "17.5.0",
"resolved": "https://registry.npmjs.org/globals/-/globals-17.5.0.tgz",
"integrity": "sha512-qoV+HK2yFl/366t2/Cb3+xxPUo5BuMynomoDmiaZBIdbs+0pYbjfZU+twLhGKp4uCZ/+NbtpVepH5bGCxRyy2g==",
"version": "15.15.0",
"resolved": "https://registry.npmjs.org/globals/-/globals-15.15.0.tgz",
"integrity": "sha512-7ACyT3wmyp3I61S4fG682L0VA2RGD9otkqGJIwNUMF1SWUombIIk+af1unuDYgMm082aHYwD+mzJvv9Iu8dsgg==",
"dev": true,
"license": "MIT",
"engines": {
@@ -8976,9 +8970,9 @@
}
},
"node_modules/styled-components": {
"version": "6.4.1",
"resolved": "https://registry.npmjs.org/styled-components/-/styled-components-6.4.1.tgz",
"integrity": "sha512-ADu2dF53esUzzM4I0ewxhxFtsDd6v4V6dNkg3vG0iFKhnt06sJneTZnRvujAosZwW0XD58IKgGMQoqri4wHRqg==",
"version": "6.4.0",
"resolved": "https://registry.npmjs.org/styled-components/-/styled-components-6.4.0.tgz",
"integrity": "sha512-BL1EDFpt+q10eAeZB0q9ps6pSlPejaBQWBkiuM16pyoVTG4NhZrPrZK0cqNbrozxSsYwUsJ9SQYN6NyeKJYX9A==",
"license": "MIT",
"dependencies": {
"@emotion/is-prop-valid": "1.4.0",

View File

@@ -52,7 +52,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.0",
"dompurify": "^3.1.5",
"flow-bin": "^0.311.0",
"flow-bin": "^0.309.0",
"markdown-it": "^14.1.0",
"react": "^19.2.5",
"react-dom": "^19.2.5",
@@ -77,7 +77,7 @@
"eslint-plugin-prettier": "^5.5.5",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-unused-imports": "^4.4.1",
"globals": "^17.5.0",
"globals": "^15.15.0",
"parcel": "^2.16.4",
"prettier": "^3.8.1",
"process": "^0.11.10",

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,6 @@
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview",
"test": "vitest run",
"test:watch": "vitest",
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
"format": "prettier ./src --write",
@@ -41,12 +39,12 @@
"react": "^19.1.0",
"react-chartjs-2": "^5.3.0",
"react-dom": "^19.2.5",
"react-dropzone": "^15.0.0",
"react-dropzone": "^14.3.8",
"react-google-drive-picker": "^1.2.2",
"react-i18next": "^17.0.6",
"react-i18next": "^17.0.2",
"react-markdown": "^9.0.1",
"react-redux": "^9.2.0",
"react-router-dom": "^7.14.2",
"react-router-dom": "^7.14.1",
"react-syntax-highlighter": "^16.1.1",
"reactflow": "^11.11.4",
"rehype-katex": "^7.0.1",
@@ -60,7 +58,7 @@
"@types/react": "^19.2.14",
"@types/react-dom": "^19.2.3",
"@types/react-syntax-highlighter": "^15.5.13",
"@typescript-eslint/eslint-plugin": "^8.59.1",
"@typescript-eslint/eslint-plugin": "^8.58.2",
"@typescript-eslint/parser": "^8.46.3",
"@vitejs/plugin-react": "^6.0.1",
"eslint": "^9.39.1",
@@ -71,17 +69,15 @@
"eslint-plugin-promise": "^6.6.0",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-unused-imports": "^4.1.4",
"happy-dom": "^17.6.3",
"husky": "^9.1.7",
"lint-staged": "^16.4.0",
"postcss": "^8.5.12",
"postcss": "^8.4.49",
"prettier": "^3.5.3",
"prettier-plugin-tailwindcss": "^0.7.2",
"tailwindcss": "^4.2.2",
"tw-animate-css": "^1.4.0",
"typescript": "^6.0.3",
"vite": "^8.0.10",
"vite-plugin-svgr": "^4.3.0",
"vitest": "^3.2.4"
"typescript": "^5.8.3",
"vite": "^8.0.8",
"vite-plugin-svgr": "^4.3.0"
}
}

View File

@@ -10,7 +10,6 @@ import Spinner from './components/Spinner';
import UploadToast from './components/UploadToast';
import Conversation from './conversation/Conversation';
import { SharedConversation } from './conversation/SharedConversation';
import { EventStreamProvider } from './events/EventStreamProvider';
import { useDarkTheme, useMediaQuery } from './hooks';
import useDataInitializer from './hooks/useDataInitializer';
import useTokenAuth from './hooks/useTokenAuth';
@@ -18,7 +17,6 @@ import Navigation from './Navigation';
import PageNotFound from './PageNotFound';
import Setting from './settings';
import Notification from './components/Notification';
import ToolApprovalToast from './notifications/ToolApprovalToast';
function AuthWrapper({ children }: { children: React.ReactNode }) {
const { isAuthLoading } = useTokenAuth();
@@ -31,7 +29,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
</div>
);
}
return <EventStreamProvider>{children}</EventStreamProvider>;
return <>{children}</>;
}
function MainLayout() {
@@ -52,7 +50,6 @@ function MainLayout() {
<Outlet />
</div>
<UploadToast />
<ToolApprovalToast />
</div>
);
}
@@ -88,13 +85,6 @@ export default function App() {
}
>
<Route index element={<Conversation />} />
{/* One dynamic route (accepting "new" or a UUID) so the
/c/new → /c/<id> replace doesn't remount Conversation. */}
<Route path="/c/:conversationId" element={<Conversation />} />
<Route
path="/agents/:agentId/c/:conversationId"
element={<Conversation />}
/>
<Route path="/settings/*" element={<Setting />} />
<Route path="/agents/*" element={<Agents />} />
</Route>

View File

@@ -25,7 +25,6 @@ import UnPin from './assets/unpin.svg';
import Help from './components/Help';
import {
handleAbort,
loadConversation,
selectQueries,
setConversation,
updateConversationId,
@@ -51,7 +50,6 @@ import {
setSelectedAgent,
setSharedAgents,
} from './preferences/preferenceSlice';
import { AppDispatch } from './store';
import Upload from './upload/Upload';
interface NavigationProps {
@@ -60,7 +58,7 @@ interface NavigationProps {
}
export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const dispatch = useDispatch<AppDispatch>();
const dispatch = useDispatch();
const navigate = useNavigate();
const { t } = useTranslation();
@@ -184,7 +182,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
resetConversation();
dispatch(setSelectedAgent(agent));
if (isMobile || isTablet) setNavOpen(!navOpen);
navigate(agent.id ? `/agents/${agent.id}/c/new` : '/c/new');
navigate('/');
};
const handleTogglePin = (agent: Agent) => {
@@ -202,21 +200,20 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
try {
dispatch(setSelectedAgent(null));
// Pre-fetch to choose the route shape (owned-agent / shared / none).
const result = await dispatch(
loadConversation({ id: index, force: true }),
).unwrap();
// Stale: a newer load has already updated Redux; the URL is
// wherever that newer flow lands, leave it alone.
if (result.stale) return;
const data = result.data;
if (!data) {
navigate('/c/new');
const response = await conversationService.getConversation(index, token);
if (!response.ok) {
navigate('/');
return;
}
const data = await response.json();
if (!data) return;
dispatch(setConversation(data.queries));
dispatch(updateConversationId({ query: { conversationId: index } }));
if (!data.agent_id) {
navigate(`/c/${index}`);
navigate('/');
return;
}
@@ -227,7 +224,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
token,
);
if (!sharedResponse.ok) {
navigate(`/c/${index}`);
navigate('/');
return;
}
agent = await sharedResponse.json();
@@ -235,7 +232,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
} else {
const agentResponse = await userService.getAgent(data.agent_id, token);
if (!agentResponse.ok) {
navigate(`/c/${index}`);
navigate('/');
return;
}
agent = await agentResponse.json();
@@ -243,12 +240,12 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
navigate(`/agents/shared/${agent.shared_token}`);
} else {
await Promise.resolve(dispatch(setSelectedAgent(agent)));
navigate(`/agents/${data.agent_id}/c/${index}`);
navigate('/');
}
}
} catch (error) {
console.error('Error handling conversation click:', error);
navigate('/c/new');
navigate('/');
}
};
@@ -267,7 +264,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
if (queries && queries?.length > 0) {
resetConversation();
}
navigate('/c/new');
};
async function updateConversationName(updatedConversation: {
@@ -279,6 +275,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
.then((response) => response.json())
.then((data) => {
if (data) {
navigate('/');
fetchConversations();
}
})
@@ -373,7 +370,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
</button>
</div>
<NavLink
to={'/c/new'}
to={'/'}
onClick={() => {
if (isMobile || isTablet) {
setNavOpen(!navOpen);

View File

@@ -174,7 +174,7 @@ export default function AgentCard({
if (section === 'user') {
if (agent.status === 'published') {
dispatch(setSelectedAgent(agent));
navigate(agent.id ? `/agents/${agent.id}/c/new` : '/c/new');
navigate(`/`);
}
}
if (section === 'shared') {

View File

@@ -565,22 +565,8 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
setJsonSchemaText(jsonText);
setJsonSchemaValid(true);
}
// Backfill required fields so older agents (created before
// agent_type / prompt_id / models existed) don't fail
// ``isPublishable()`` and leave Save permanently disabled.
const normalized = {
...data,
agent_type: data.agent_type || 'classic',
prompt_id: data.prompt_id || 'default',
retriever: data.retriever || 'classic',
chunks: data.chunks || '2',
tools: data.tools || [],
sources: data.sources || [],
models: data.models || [],
default_model_id: data.default_model_id || '',
};
setAgent(normalized);
initialAgentRef.current = normalized;
setAgent(data);
initialAgentRef.current = data;
};
getAgent();
}

View File

@@ -1,18 +1,8 @@
import { useTranslation } from 'react-i18next';
import EditIcon from '../assets/edit.svg';
import AgentImage from '../components/AgentImage';
import { getToolDisplayName } from '../utils/toolUtils';
import { Agent } from './types';
export default function SharedAgentCard({
agent,
onEdit,
}: {
agent: Agent;
onEdit?: () => void;
}) {
const { t } = useTranslation();
export default function SharedAgentCard({ agent }: { agent: Agent }) {
// Check if shared metadata exists and has properties (type is 'any' so we validate it's a non-empty object)
const hasSharedMetadata =
agent.shared_metadata &&
@@ -21,14 +11,14 @@ export default function SharedAgentCard({
Object.keys(agent.shared_metadata).length > 0;
return (
<div className="border-border dark:border-border flex w-full max-w-[720px] flex-col rounded-3xl border p-6 shadow-xs sm:w-fit sm:min-w-[480px]">
<div className="flex items-start gap-3">
<div className="flex items-center gap-3">
<div className="flex h-12 w-12 items-center justify-center overflow-hidden rounded-full p-1">
<AgentImage
src={agent.image}
className="h-full w-full rounded-full object-contain"
/>
</div>
<div className="flex max-h-[92px] flex-1 flex-col gap-px">
<div className="flex max-h-[92px] w-[80%] flex-col gap-px">
<h2 className="text-foreground text-base font-semibold sm:text-lg">
{agent.name}
</h2>
@@ -36,17 +26,6 @@ export default function SharedAgentCard({
{agent.description}
</p>
</div>
{onEdit && (
<button
type="button"
onClick={onEdit}
className="border-border hover:bg-accent text-foreground flex shrink-0 items-center gap-1.5 rounded-full border px-3 py-1.5 text-sm font-medium transition-colors"
aria-label={t('agents.edit')}
>
<img src={EditIcon} alt="" className="h-3.5 w-3.5" />
{t('agents.edit')}
</button>
)}
</div>
{hasSharedMetadata && (
<div className="mt-4 flex items-center gap-8">

View File

@@ -813,11 +813,7 @@ function WorkflowBuilderInner() {
const response = await userService.getWorkflow(workflowId, token);
if (!response.ok) throw new Error('Failed to fetch workflow');
const responseData = await response.json();
const {
workflow,
nodes: apiNodes,
edges: apiEdges,
} = responseData.data;
const { workflow, nodes: apiNodes, edges: apiEdges } = responseData.data;
const nextWorkflowName = workflow.name;
const nextWorkflowDescription = workflow.description || '';
const mappedNodes = apiNodes.map((n: WorkflowNode) => {
@@ -1476,9 +1472,7 @@ function WorkflowBuilderInner() {
{t('agents.form.advanced.systemPromptOverride')}
</label>
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
{t(
'agents.form.advanced.systemPromptOverrideDescription',
)}
{t('agents.form.advanced.systemPromptOverrideDescription')}
</p>
</div>
<button

View File

@@ -1,5 +1,3 @@
import { withThrottle, type FetchLike } from './throttle';
export const baseURL =
import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
@@ -20,121 +18,112 @@ const getHeaders = (
return headers;
};
const createClient = (transport: FetchLike) => {
const request = (url: string, init: RequestInit): Promise<Response> =>
transport(`${baseURL}${url}`, init);
const apiClient = {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
return {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}),
};
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
};
const apiClient = createClient((url, init) => fetch(url, init));
// Throttled client for endpoints that fan out, are polled, or are commonly
// requested concurrently from multiple components. Shares a single concurrency
// budget and de-duplicates identical in-flight GETs.
export const throttledApiClient = createClient(
withThrottle((url, init) => fetch(url, init), { debugLabel: 'api' }),
);
if (import.meta.env.DEV && typeof window !== 'undefined') {
(window as unknown as Record<string, unknown>).__apiClient = apiClient;
(window as unknown as Record<string, unknown>).__throttledApiClient =
throttledApiClient;
(window as unknown as Record<string, unknown>).__baseURL = baseURL;
}
export default apiClient;

View File

@@ -28,6 +28,7 @@ const endpoints = {
UPDATE_PROMPT: '/api/update_prompt',
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
MESSAGE_ANALYTICS: '/api/get_message_analytics',
TOKEN_ANALYTICS: '/api/get_token_analytics',
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
@@ -42,11 +43,6 @@ const endpoints = {
DELETE_TOOL: '/api/delete_tool',
PARSE_SPEC: '/api/parse_spec',
SYNC_CONNECTOR: '/api/connectors/sync',
CONNECTOR_AUTH: (provider: string) =>
`/api/connectors/auth?provider=${provider}`,
CONNECTOR_FILES: '/api/connectors/files',
CONNECTOR_VALIDATE_SESSION: '/api/connectors/validate-session',
CONNECTOR_DISCONNECT: '/api/connectors/disconnect',
GET_CHUNKS: (
docId: string,
page: number,
@@ -63,7 +59,6 @@ const endpoints = {
UPDATE_CHUNK: '/api/update_chunk',
STORE_ATTACHMENT: '/api/store_attachment',
STT: '/api/stt',
TTS: '/api/tts',
LIVE_STT_START: '/api/stt/live/start',
LIVE_STT_CHUNK: '/api/stt/live/chunk',
LIVE_STT_FINISH: '/api/stt/live/finish',
@@ -72,6 +67,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
AGENT_FOLDERS: '/api/agents/folders/',
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,
@@ -95,7 +92,6 @@ const endpoints = {
FEEDBACK: '/api/feedback',
CONVERSATION: (id: string) => `/api/get_single_conversation?id=${id}`,
CONVERSATIONS: '/api/get_conversations',
MESSAGE_TAIL: (messageId: string) => `/api/messages/${messageId}/tail`,
SHARE_CONVERSATION: (isPromptable: boolean) =>
`/api/share?isPromptable=${isPromptable}`,
SHARED_CONVERSATION: (identifier: string) =>

View File

@@ -6,20 +6,18 @@ const conversationService = {
data: any,
token: string | null,
signal: AbortSignal,
headers: Record<string, string> = {},
): Promise<any> =>
apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, headers, signal),
apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, {}, signal),
answerStream: (
data: any,
token: string | null,
signal: AbortSignal,
headers: Record<string, string> = {},
): Promise<any> =>
apiClient.post(
endpoints.CONVERSATION.ANSWER_STREAMING,
data,
token,
headers,
{},
signal,
),
search: (data: any, token: string | null): Promise<any> =>
@@ -28,8 +26,6 @@ const conversationService = {
apiClient.post(endpoints.CONVERSATION.FEEDBACK, data, token, {}),
getConversation: (id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.CONVERSATION.CONVERSATION(id), token, {}),
tailMessage: (messageId: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.CONVERSATION.MESSAGE_TAIL(messageId), token, {}),
getConversations: (token: string | null): Promise<any> =>
apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}),
shareConversation: (

View File

@@ -1,12 +1,11 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient, { throttledApiClient } from '../client';
import apiClient from '../client';
import endpoints from '../endpoints';
const userService = {
getConfig: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.CONFIG, null),
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
getNewToken: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.NEW_TOKEN, null),
apiClient.get(endpoints.USER.NEW_TOKEN, null),
getDocs: (token: string | null): Promise<any> =>
apiClient.get(`${endpoints.USER.DOCS}`, token),
getDocsWithPagination: (query: string, token: string | null): Promise<any> =>
@@ -18,9 +17,9 @@ const userService = {
deleteAPIKey: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.DELETE_API_KEY, data, token),
getAgent: (id: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT(id), token),
apiClient.get(endpoints.USER.AGENT(id), token),
getAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENTS, token),
apiClient.get(endpoints.USER.AGENTS, token),
createAgent: (data: any, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.CREATE_AGENT, data, token),
updateAgent: (
@@ -32,19 +31,19 @@ const userService = {
deleteAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.DELETE_AGENT(id), token),
getPinnedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.PINNED_AGENTS, token),
apiClient.get(endpoints.USER.PINNED_AGENTS, token),
togglePinAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.TOGGLE_PIN_AGENT(id), {}, token),
getSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.SHARED_AGENT(id), token),
getSharedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.SHARED_AGENTS, token),
apiClient.get(endpoints.USER.SHARED_AGENTS, token),
shareAgent: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.SHARE_AGENT, data, token),
removeSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.REMOVE_SHARED_AGENT(id), token),
getTemplateAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
apiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
adoptAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.ADOPT_AGENT(id), {}, token),
getAgentWebhook: (id: string, token: string | null): Promise<any> =>
@@ -61,6 +60,8 @@ const userService = {
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
deletePath: (docPath: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
@@ -148,7 +149,7 @@ const userService = {
path?: string,
search?: string,
): Promise<any> =>
throttledApiClient.get(
apiClient.get(
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
token,
),
@@ -163,15 +164,17 @@ const userService = {
updateChunk: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
getDirectoryStructure: (docId: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
getMCPAuthStatus: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
apiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
syncConnector: (
docId: string,
provider: string,
@@ -188,50 +191,8 @@ const userService = {
token,
);
},
getConnectorAuthUrl: (provider: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.CONNECTOR_AUTH(provider), token),
getConnectorFiles: (
data: any,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
throttledApiClient.post(
endpoints.USER.CONNECTOR_FILES,
data,
token,
{},
signal,
),
validateConnectorSession: (
provider: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_VALIDATE_SESSION,
{
provider,
session_token: getSessionToken(provider),
},
token,
),
disconnectConnector: (
provider: string,
sessionToken: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_DISCONNECT,
{ provider, session_token: sessionToken },
token,
),
textToSpeech: (
text: string,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
apiClient.post(endpoints.USER.TTS, { text }, token, {}, signal),
getAgentFolders: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT_FOLDERS, token),
apiClient.get(endpoints.USER.AGENT_FOLDERS, token),
createAgentFolder: (
data: { name: string; parent_id?: string },
token: string | null,

View File

@@ -1,223 +0,0 @@
/**
* Transport-layer middleware factory for the frontend API layer.
*/
export type FetchLike = (
input: string,
init?: RequestInit,
) => Promise<Response>;
export interface ThrottleConfig {
maxConcurrentGlobal?: number;
maxConcurrentPerRoute?: number;
dedupe?: boolean;
dedupeKey?: (url: string, init?: RequestInit) => string | false;
debugLabel?: string;
}
const DEFAULT_MAX_CONCURRENT_GLOBAL = 8;
const DEFAULT_MAX_CONCURRENT_PER_ROUTE = 3;
type QueueItem = {
run: () => void;
signal?: AbortSignal;
onAbort: () => void;
};
function routeKey(method: string, url: string): string {
let pathname = url;
try {
pathname = new URL(url, 'http://_').pathname;
} catch {
pathname = url.split('?')[0];
}
return `${method.toUpperCase()} ${pathname}`;
}
function abortError(): DOMException {
return new DOMException('The operation was aborted.', 'AbortError');
}
interface ThrottleState {
perRouteQueues: Map<string, QueueItem[]>;
inflightPerRoute: Map<string, number>;
inflightGets: Map<string, Promise<Response>>;
inflightGlobal: number;
}
function createState(): ThrottleState {
return {
perRouteQueues: new Map(),
inflightPerRoute: new Map(),
inflightGets: new Map(),
inflightGlobal: 0,
};
}
export function withThrottle(
fetchLike: FetchLike,
config: ThrottleConfig = {},
): FetchLike & { __reset: () => void } {
const maxGlobal = config.maxConcurrentGlobal ?? DEFAULT_MAX_CONCURRENT_GLOBAL;
const maxPerRoute =
config.maxConcurrentPerRoute ?? DEFAULT_MAX_CONCURRENT_PER_ROUTE;
const dedupeEnabled = config.dedupe !== false;
const state = createState();
// Toggle in DevTools with: localStorage.setItem('debug:throttle', '1')
const isDebug = (): boolean => {
try {
return (
typeof localStorage !== 'undefined' &&
localStorage.getItem('debug:throttle') === '1'
);
} catch {
return false;
}
};
const log = (
event: string,
key: string,
extra?: Record<string, unknown>,
): void => {
if (!isDebug()) return;
const queued = state.perRouteQueues.get(key)?.length ?? 0;
const perRoute = state.inflightPerRoute.get(key) ?? 0;
const tag = config.debugLabel
? `[throttle:${config.debugLabel}]`
: '[throttle]';
console.debug(
`${tag} ${event} ${key} | inflight=${state.inflightGlobal}/${maxGlobal} route=${perRoute}/${maxPerRoute} queued=${queued}`,
extra ?? '',
);
};
const canDispatch = (key: string): boolean => {
const perRoute = state.inflightPerRoute.get(key) ?? 0;
return state.inflightGlobal < maxGlobal && perRoute < maxPerRoute;
};
const pumpQueues = (): void => {
for (const [key, queue] of state.perRouteQueues) {
while (queue.length > 0 && canDispatch(key)) {
const item = queue.shift()!;
item.signal?.removeEventListener('abort', item.onAbort);
item.run();
}
if (queue.length === 0) state.perRouteQueues.delete(key);
}
};
const enqueue = (key: string, item: QueueItem): void => {
let queue = state.perRouteQueues.get(key);
if (!queue) {
queue = [];
state.perRouteQueues.set(key, queue);
}
queue.push(item);
};
const acquireSlot = (key: string, signal?: AbortSignal): Promise<void> =>
new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(abortError());
return;
}
const item: QueueItem = {
signal,
run: () => {
state.inflightGlobal += 1;
state.inflightPerRoute.set(
key,
(state.inflightPerRoute.get(key) ?? 0) + 1,
);
resolve();
},
onAbort: () => {
const queue = state.perRouteQueues.get(key);
if (queue) {
const idx = queue.indexOf(item);
if (idx >= 0) queue.splice(idx, 1);
}
log('abort-queued', key);
reject(abortError());
},
};
const queued = state.perRouteQueues.get(key);
if ((!queued || queued.length === 0) && canDispatch(key)) {
item.run();
log('dispatch', key);
return;
}
signal?.addEventListener('abort', item.onAbort, { once: true });
enqueue(key, item);
log('queued', key);
});
const releaseSlot = (key: string): void => {
state.inflightGlobal = Math.max(0, state.inflightGlobal - 1);
const next = (state.inflightPerRoute.get(key) ?? 1) - 1;
if (next <= 0) state.inflightPerRoute.delete(key);
else state.inflightPerRoute.set(key, next);
log('release', key);
pumpQueues();
};
const wrapped = (async (url, init = {}) => {
const method = (init.method ?? 'GET').toUpperCase();
const signal = init.signal ?? undefined;
const key = routeKey(method, url);
// Dedupe is restricted to GETs without a caller-supplied AbortSignal:
// sharing a single underlying fetch across waiters means an abort by one
// caller would reject the others, which is not the contract callers expect.
const customKey = config.dedupeKey?.(url, init);
const dedupeAllowed =
dedupeEnabled &&
customKey !== false &&
method === 'GET' &&
!init.body &&
!signal;
const dedupeKey = typeof customKey === 'string' ? customKey : `GET ${url}`;
if (dedupeAllowed) {
const existing = state.inflightGets.get(dedupeKey);
if (existing) {
log('dedupe-hit', key, { dedupeKey });
return existing.then((r) => r.clone());
}
}
const run = async (): Promise<Response> => {
await acquireSlot(key, signal);
try {
return await fetchLike(url, init);
} finally {
releaseSlot(key);
}
};
if (dedupeAllowed) {
const promise = run();
state.inflightGets.set(dedupeKey, promise);
promise.finally(() => {
if (state.inflightGets.get(dedupeKey) === promise) {
state.inflightGets.delete(dedupeKey);
}
});
return promise.then((r) => r.clone());
}
return run();
}) as FetchLike & { __reset: () => void };
wrapped.__reset = () => {
state.perRouteQueues.clear();
state.inflightPerRoute.clear();
state.inflightGets.clear();
state.inflightGlobal = 0;
};
return wrapped;
}

Some files were not shown because too many files have changed in this diff Show More