mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 13:25:08 +00:00
Compare commits
6 Commits
fix-worker
...
feat/defau
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85e885520b | ||
|
|
16695205d5 | ||
|
|
764a23d641 | ||
|
|
0bbcbf4539 | ||
|
|
d041db77e1 | ||
|
|
1de82ca040 |
@@ -98,6 +98,7 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
user=self.user,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
self.attachments = attachments or []
|
||||
|
||||
356
application/agents/default_tools.py
Normal file
356
application/agents/default_tools.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""Default chat tools — config-free tools on by default in chats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fixed namespace — never regenerate; produced ids are persisted.
|
||||
_DEFAULT_TOOL_NAMESPACE = uuid.UUID("6b1d3f2a-9c84-4d17-bf6e-2a0c5e8d4471")
|
||||
|
||||
# Tool names whose storage tables FK ``tool_id`` to ``user_tools.id``;
|
||||
# a synthetic id has no row, so a write would FK-violate. Schema-rot
|
||||
# guard: ``tests.agents.test_default_tools.TestFkBoundToolsIsInSync``.
|
||||
_FK_BOUND_TOOLS = frozenset({"notes", "todo_list"})
|
||||
|
||||
# Tools that should NEVER appear in a headless run (scheduled or webhook).
|
||||
# ``scheduler`` only makes sense from an interactive chat — letting an LLM
|
||||
# call ``schedule_task`` from a scheduled run chains new schedules each fire,
|
||||
# bounded only by ``SCHEDULE_MAX_PER_USER`` (cost foot-gun, confusing UX).
|
||||
_HEADLESS_EXCLUDED_TOOLS = frozenset({"scheduler"})
|
||||
|
||||
# Agent-selectable builtins: hidden from the Add-Tool catalog (internal=True)
|
||||
# and exposed to the agent picker via the same synthetic-id machinery as
|
||||
# default tools. Names may overlap with DEFAULT_CHAT_TOOLS (e.g. ``scheduler``)
|
||||
# — both registries share ``_DEFAULT_TOOL_NAMESPACE`` so the same uuid5
|
||||
# resolves either way (the dual-flag row carries ``default`` AND ``builtin``).
|
||||
BUILTIN_AGENT_TOOLS: tuple = ("scheduler",)
|
||||
|
||||
_tool_cache: Dict[str, Optional[Any]] = {}
|
||||
_ids_cache: Dict[tuple, Dict[str, str]] = {}
|
||||
_loaded_cache: Dict[tuple, List[str]] = {}
|
||||
_builtin_ids_cache: Dict[tuple, Dict[str, str]] = {}
|
||||
_builtin_loaded_cache: Dict[tuple, List[str]] = {}
|
||||
|
||||
|
||||
def _load_tool(tool_name: str) -> Optional[Any]:
|
||||
"""Return a metadata-only instance of a tool, or None if it has no class."""
|
||||
# Imports just the named module (not the whole package) — avoids the
|
||||
# circular import via ``mcp_tool`` → ``application.api.user``.
|
||||
if tool_name in _tool_cache:
|
||||
return _tool_cache[tool_name]
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
instance: Optional[Any] = None
|
||||
try:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
except ModuleNotFoundError:
|
||||
_tool_cache[tool_name] = None
|
||||
return None
|
||||
for _, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
try:
|
||||
instance = obj({})
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"DEFAULT_CHAT_TOOLS entry %r failed to instantiate; skipping.",
|
||||
tool_name,
|
||||
)
|
||||
instance = None
|
||||
break
|
||||
_tool_cache[tool_name] = instance
|
||||
return instance
|
||||
|
||||
|
||||
def default_tool_id(tool_name: str) -> str:
|
||||
"""Return the deterministic synthetic id for a default tool name."""
|
||||
return str(uuid.uuid5(_DEFAULT_TOOL_NAMESPACE, tool_name))
|
||||
|
||||
|
||||
def default_tool_ids() -> Dict[str, str]:
|
||||
"""Map each configured default-tool name to its synthetic id (memoized)."""
|
||||
key = tuple(settings.DEFAULT_CHAT_TOOLS)
|
||||
cached = _ids_cache.get(key)
|
||||
if cached is None:
|
||||
cached = {name: default_tool_id(name) for name in key}
|
||||
_ids_cache[key] = cached
|
||||
return cached
|
||||
|
||||
|
||||
def is_default_tool_id(tool_id: Any) -> bool:
|
||||
"""Return True if ``tool_id`` is a synthetic default-tool id."""
|
||||
if not tool_id:
|
||||
return False
|
||||
return str(tool_id) in set(default_tool_ids().values())
|
||||
|
||||
|
||||
def default_tool_name_for_id(tool_id: Any) -> Optional[str]:
|
||||
"""Return the default-tool name for a synthetic id, or None."""
|
||||
target = str(tool_id) if tool_id else ""
|
||||
for name, synthetic_id in default_tool_ids().items():
|
||||
if synthetic_id == target:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def builtin_agent_tool_ids() -> Dict[str, str]:
|
||||
"""Map each agent-selectable builtin to its synthetic id (memoized)."""
|
||||
key = tuple(BUILTIN_AGENT_TOOLS)
|
||||
cached = _builtin_ids_cache.get(key)
|
||||
if cached is None:
|
||||
cached = {name: default_tool_id(name) for name in key}
|
||||
_builtin_ids_cache[key] = cached
|
||||
return cached
|
||||
|
||||
|
||||
def is_builtin_agent_tool_id(tool_id: Any) -> bool:
|
||||
"""Return True if ``tool_id`` is an agent-selectable builtin synthetic id."""
|
||||
if not tool_id:
|
||||
return False
|
||||
return str(tool_id) in set(builtin_agent_tool_ids().values())
|
||||
|
||||
|
||||
def builtin_agent_tool_name_for_id(tool_id: Any) -> Optional[str]:
|
||||
"""Return the builtin tool name for a synthetic id, or None."""
|
||||
target = str(tool_id) if tool_id else ""
|
||||
for name, synthetic_id in builtin_agent_tool_ids().items():
|
||||
if synthetic_id == target:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def synthesized_tool_name_for_id(tool_id: Any) -> Optional[str]:
|
||||
"""Return the tool name for any synthetic id (default or builtin), or None."""
|
||||
return default_tool_name_for_id(tool_id) or builtin_agent_tool_name_for_id(tool_id)
|
||||
|
||||
|
||||
def is_synthesized_tool_id(tool_id: Any) -> bool:
|
||||
"""Return True for any synthetic id (default chat or agent-builtin)."""
|
||||
return is_default_tool_id(tool_id) or is_builtin_agent_tool_id(tool_id)
|
||||
|
||||
|
||||
def loaded_default_tools() -> List[str]:
|
||||
"""Return configured default-tool names that resolve to a loaded tool."""
|
||||
# Silent + memoized — runs per request; the one-time skip notice
|
||||
# for unimplemented names lives in ``validate_default_chat_tools``.
|
||||
key = tuple(settings.DEFAULT_CHAT_TOOLS)
|
||||
cached = _loaded_cache.get(key)
|
||||
if cached is None:
|
||||
cached = [name for name in key if _load_tool(name) is not None]
|
||||
_loaded_cache[key] = cached
|
||||
return cached
|
||||
|
||||
|
||||
def loaded_builtin_agent_tools() -> List[str]:
|
||||
"""Return builtin agent-tool names that resolve to a loaded tool."""
|
||||
key = tuple(BUILTIN_AGENT_TOOLS)
|
||||
cached = _builtin_loaded_cache.get(key)
|
||||
if cached is None:
|
||||
cached = [name for name in key if _load_tool(name) is not None]
|
||||
_builtin_loaded_cache[key] = cached
|
||||
return cached
|
||||
|
||||
|
||||
def validate_default_chat_tools() -> List[str]:
|
||||
"""Validate ``DEFAULT_CHAT_TOOLS`` at startup; return the usable names."""
|
||||
skipped = [
|
||||
name for name in settings.DEFAULT_CHAT_TOOLS if _load_tool(name) is None
|
||||
]
|
||||
if skipped:
|
||||
logger.debug(
|
||||
"DEFAULT_CHAT_TOOLS entries with no loaded tool, skipped: %s. "
|
||||
"Each activates automatically once its tool exists.",
|
||||
", ".join(skipped),
|
||||
)
|
||||
usable = loaded_default_tools()
|
||||
for name in usable:
|
||||
if name in _FK_BOUND_TOOLS:
|
||||
raise ValueError(
|
||||
f"DEFAULT_CHAT_TOOLS entry {name!r} has a storage table "
|
||||
f"that foreign-keys tool_id to user_tools; a default tool "
|
||||
f"has a synthetic id with no user_tools row, so it would "
|
||||
f"fail at write time. It cannot be defaulted on."
|
||||
)
|
||||
requirements = _load_tool(name).get_config_requirements() or {}
|
||||
required = [
|
||||
key for key, spec in requirements.items()
|
||||
if isinstance(spec, dict) and spec.get("required")
|
||||
]
|
||||
if required:
|
||||
raise ValueError(
|
||||
f"DEFAULT_CHAT_TOOLS entry {name!r} requires config "
|
||||
f"fields {required}; only config-free tools may be "
|
||||
"defaulted on."
|
||||
)
|
||||
if usable:
|
||||
logger.info("Default chat tools active: %s", ", ".join(usable))
|
||||
return usable
|
||||
|
||||
|
||||
def _tool_display(tool_name: str) -> str:
|
||||
"""Return the human-readable display name from the tool docstring."""
|
||||
tool = _load_tool(tool_name)
|
||||
doc = (tool.__doc__ or "").strip() if tool else ""
|
||||
first_line = doc.split("\n", 1)[0].strip() if doc else ""
|
||||
return first_line or tool_name
|
||||
|
||||
|
||||
def _tool_description(tool_name: str) -> str:
|
||||
"""Return the tool description (docstring lines after the first)."""
|
||||
tool = _load_tool(tool_name)
|
||||
doc = (tool.__doc__ or "").strip() if tool else ""
|
||||
parts = doc.split("\n", 1)
|
||||
return parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
|
||||
def synthesize_default_tool(tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Build an in-memory ``user_tools``-shaped row for a default tool."""
|
||||
tool = _load_tool(tool_name)
|
||||
if tool is None:
|
||||
return None
|
||||
synthetic_id = default_tool_id(tool_name)
|
||||
return {
|
||||
"id": synthetic_id,
|
||||
"_id": synthetic_id,
|
||||
"name": tool_name,
|
||||
"display_name": _tool_display(tool_name),
|
||||
"custom_name": "",
|
||||
"description": _tool_description(tool_name),
|
||||
"config": {},
|
||||
"config_requirements": {},
|
||||
"actions": tool.get_actions_metadata() or [],
|
||||
"status": True,
|
||||
"default": True,
|
||||
}
|
||||
|
||||
|
||||
def synthesize_builtin_agent_tool(tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Build an in-memory ``user_tools``-shaped row for a builtin agent tool."""
|
||||
tool = _load_tool(tool_name)
|
||||
if tool is None:
|
||||
return None
|
||||
synthetic_id = default_tool_id(tool_name)
|
||||
return {
|
||||
"id": synthetic_id,
|
||||
"_id": synthetic_id,
|
||||
"name": tool_name,
|
||||
"display_name": _tool_display(tool_name),
|
||||
"custom_name": "",
|
||||
"description": _tool_description(tool_name),
|
||||
"config": {},
|
||||
"config_requirements": {},
|
||||
"actions": tool.get_actions_metadata() or [],
|
||||
"status": True,
|
||||
"default": False,
|
||||
"builtin": True,
|
||||
}
|
||||
|
||||
|
||||
def synthesize_tool_by_name(tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Synthesize the row for any default or builtin tool name."""
|
||||
if tool_name in BUILTIN_AGENT_TOOLS:
|
||||
return synthesize_builtin_agent_tool(tool_name)
|
||||
return synthesize_default_tool(tool_name)
|
||||
|
||||
|
||||
def disabled_default_tools(user_doc: Optional[Dict[str, Any]]) -> List[str]:
|
||||
"""Return the user's opt-out list from ``tool_preferences``."""
|
||||
if not isinstance(user_doc, dict):
|
||||
return []
|
||||
prefs = user_doc.get("tool_preferences") or {}
|
||||
if not isinstance(prefs, dict):
|
||||
return []
|
||||
disabled = prefs.get("disabled_default_tools") or []
|
||||
if not isinstance(disabled, list):
|
||||
return []
|
||||
return [str(name) for name in disabled]
|
||||
|
||||
|
||||
def synthesized_default_tools(
|
||||
user_doc: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
headless: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return synthesized default-tool rows for an agentless chat."""
|
||||
# Agent-bound chats must NOT call this — they resolve exactly
|
||||
# ``agents.tools``. Disabled defaults are dropped. ``headless=True``
|
||||
# additionally drops chat-only tools (e.g. ``scheduler``) so a scheduled
|
||||
# / webhook LLM can't re-schedule itself.
|
||||
disabled = set(disabled_default_tools(user_doc))
|
||||
rows: List[Dict[str, Any]] = []
|
||||
for name in loaded_default_tools():
|
||||
if name in disabled:
|
||||
continue
|
||||
if headless and name in _HEADLESS_EXCLUDED_TOOLS:
|
||||
continue
|
||||
row = synthesize_default_tool(name)
|
||||
if row is not None:
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def is_headless_excluded_tool(tool_name: Optional[str]) -> bool:
|
||||
"""Return True if ``tool_name`` must be hidden from headless runs."""
|
||||
return bool(tool_name) and tool_name in _HEADLESS_EXCLUDED_TOOLS
|
||||
|
||||
|
||||
def default_tools_for_management(
|
||||
user_doc: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return every loaded default tool with its on/off ``status``."""
|
||||
# Unlike ``synthesized_default_tools`` (chat toolset), this keeps
|
||||
# disabled tools so the management UI can render their toggle.
|
||||
disabled = set(disabled_default_tools(user_doc))
|
||||
rows: List[Dict[str, Any]] = []
|
||||
for name in loaded_default_tools():
|
||||
row = synthesize_default_tool(name)
|
||||
if row is None:
|
||||
continue
|
||||
row["status"] = name not in disabled
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def builtin_agent_tools_for_management() -> List[Dict[str, Any]]:
|
||||
"""Return every loaded agent-builtin tool for the agent picker (no per-user state)."""
|
||||
rows: List[Dict[str, Any]] = []
|
||||
for name in loaded_builtin_agent_tools():
|
||||
row = synthesize_builtin_agent_tool(name)
|
||||
if row is None:
|
||||
continue
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def resolve_tool_by_id(
|
||||
tool_id: Any,
|
||||
user: Optional[str],
|
||||
*,
|
||||
user_tools_repo: Any = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve a tool by id: default/builtin synthetic id, else user_tools row.
|
||||
|
||||
Dual-registered tools (e.g. ``scheduler``) get both flags on the resolved
|
||||
row so callers can branch on either path without losing the discriminator.
|
||||
"""
|
||||
default_name = default_tool_name_for_id(tool_id)
|
||||
builtin_name = builtin_agent_tool_name_for_id(tool_id)
|
||||
if default_name is not None and builtin_name is not None:
|
||||
row = synthesize_default_tool(default_name) or {}
|
||||
row["builtin"] = True
|
||||
return row or None
|
||||
if default_name is not None:
|
||||
return synthesize_default_tool(default_name)
|
||||
if builtin_name is not None:
|
||||
return synthesize_builtin_agent_tool(builtin_name)
|
||||
if user_tools_repo is None or not user:
|
||||
return None
|
||||
return user_tools_repo.get_any(str(tool_id), user)
|
||||
173
application/agents/headless_runner.py
Normal file
173
application/agents/headless_runner.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Shared headless agent runner used by webhooks and scheduled runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_owner(agent_config: Dict[str, Any]) -> Optional[str]:
|
||||
return agent_config.get("user_id") or agent_config.get("user")
|
||||
|
||||
|
||||
def _resolve_agent_id(agent_config: Dict[str, Any]) -> Optional[str]:
|
||||
raw = agent_config.get("id") or agent_config.get("_id")
|
||||
return str(raw) if raw else None
|
||||
|
||||
|
||||
def run_agent_headless(
|
||||
agent_config: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
tool_allowlist: Optional[Iterable[str]] = None,
|
||||
model_id_override: Optional[str] = None,
|
||||
endpoint: str = "headless",
|
||||
chat_history: Optional[List[Dict[str, Any]]] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run an agent with no live client; returns a structured outcome dict."""
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.utils import calculate_doc_token_budget
|
||||
|
||||
owner = _resolve_owner(agent_config)
|
||||
if not owner:
|
||||
raise ValueError("Agent config is missing user_id; cannot run headless.")
|
||||
decoded_token = {"sub": owner}
|
||||
|
||||
retriever_kind = agent_config.get("retriever", "classic")
|
||||
source_id = agent_config.get("source_id") or agent_config.get("source")
|
||||
source_active: Any = {}
|
||||
if source_id:
|
||||
with db_readonly() as conn:
|
||||
src_row = SourcesRepository(conn).get(str(source_id), owner)
|
||||
if src_row:
|
||||
source_active = str(src_row["id"])
|
||||
retriever_kind = src_row.get("retriever", retriever_kind)
|
||||
source = {"active_docs": source_active}
|
||||
chunks = int(agent_config.get("chunks", 2) or 2)
|
||||
prompt_id = agent_config.get("prompt_id", "default")
|
||||
user_api_key = agent_config.get("key")
|
||||
agent_id = _resolve_agent_id(agent_config)
|
||||
agent_type = agent_config.get("agent_type", "classic")
|
||||
json_schema = agent_config.get("json_schema")
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
candidate_model = model_id_override or agent_config.get("default_model_id") or ""
|
||||
if candidate_model and validate_model_id(candidate_model, user_id=owner):
|
||||
model_id = candidate_model
|
||||
else:
|
||||
model_id = get_default_model_id()
|
||||
if candidate_model:
|
||||
logger.warning(
|
||||
"Agent %s references unknown model_id %r; falling back to %r",
|
||||
agent_id, candidate_model, model_id,
|
||||
)
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id, user_id=owner)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=model_id, user_id=owner)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_kind,
|
||||
source=source,
|
||||
chat_history=chat_history or [],
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
doc_token_limit=doc_token_limit,
|
||||
model_id=model_id,
|
||||
user_api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
retrieved_docs: List[Dict[str, Any]] = []
|
||||
try:
|
||||
docs = retriever.search(query)
|
||||
if docs:
|
||||
retrieved_docs = docs
|
||||
except Exception as exc:
|
||||
logger.warning("Headless retrieve failed: %s", exc)
|
||||
|
||||
tool_executor = ToolExecutor(
|
||||
user_api_key=user_api_key,
|
||||
user=owner,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
headless=True,
|
||||
tool_allowlist=list(tool_allowlist or []),
|
||||
)
|
||||
if conversation_id:
|
||||
tool_executor.conversation_id = str(conversation_id)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint=endpoint,
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=model_id,
|
||||
api_key=system_api_key,
|
||||
agent_id=agent_id,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=chat_history or [],
|
||||
retrieved_docs=retrieved_docs,
|
||||
decoded_token=decoded_token,
|
||||
attachments=[],
|
||||
json_schema=json_schema,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
if conversation_id:
|
||||
agent.conversation_id = str(conversation_id)
|
||||
|
||||
answer_full = ""
|
||||
thought = ""
|
||||
sources_log: List[Dict[str, Any]] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
for event in agent.gen(query=query):
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
if "answer" in event:
|
||||
answer_full += str(event["answer"])
|
||||
elif "sources" in event:
|
||||
sources_log.extend(event["sources"])
|
||||
elif "tool_calls" in event:
|
||||
tool_calls.extend(event["tool_calls"])
|
||||
elif "thought" in event:
|
||||
thought += str(event["thought"])
|
||||
|
||||
denied = list(getattr(tool_executor, "headless_denials", []))
|
||||
error_type = "tool_not_allowed" if denied and not answer_full.strip() else None
|
||||
|
||||
# Use the LLM accumulator (gen_token_usage / stream_token_usage decorators);
|
||||
# current_token_count is a context-size sentinel, not a usage tally.
|
||||
llm_usage = getattr(getattr(agent, "llm", None), "token_usage", None) or {}
|
||||
prompt_tokens = int(llm_usage.get("prompt_tokens", 0) or 0)
|
||||
generated_tokens = int(llm_usage.get("generated_tokens", 0) or 0)
|
||||
|
||||
return {
|
||||
"answer": answer_full,
|
||||
"thought": thought,
|
||||
"sources": sources_log,
|
||||
"tool_calls": tool_calls,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
"denied": denied,
|
||||
"error_type": error_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
131
application/agents/scheduler_utils.py
Normal file
131
application/agents/scheduler_utils.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Cron/tz computations for the scheduler (shared by dispatcher, routes, and tool)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from croniter import croniter
|
||||
|
||||
|
||||
_DELAY_RE = re.compile(r"^\s*(\d+)\s*(s|m|h|d)\s*$", re.IGNORECASE)
|
||||
_DELAY_MULTIPLIERS = {"s": 1, "m": 60, "h": 3600, "d": 86_400}
|
||||
|
||||
|
||||
class ScheduleValidationError(ValueError):
|
||||
"""Raised when a schedule's cron, run_at, or delay is invalid."""
|
||||
|
||||
|
||||
def resolve_timezone(tz_name: Optional[str]) -> ZoneInfo:
|
||||
"""Return a ``ZoneInfo`` for ``tz_name`` (default UTC)."""
|
||||
name = (tz_name or "UTC").strip() or "UTC"
|
||||
try:
|
||||
return ZoneInfo(name)
|
||||
except ZoneInfoNotFoundError as exc:
|
||||
raise ScheduleValidationError(f"Unknown timezone: {name}") from exc
|
||||
|
||||
|
||||
def parse_cron(expression: str) -> None:
|
||||
"""Validate a 5-field cron expression; raise on bad input."""
|
||||
# croniter defers some malformed inputs until get_next, so force one here.
|
||||
if not expression or not isinstance(expression, str):
|
||||
raise ScheduleValidationError("Cron expression is required.")
|
||||
fields = expression.strip().split()
|
||||
if len(fields) != 5:
|
||||
raise ScheduleValidationError("Cron expression must have 5 fields.")
|
||||
try:
|
||||
itr = croniter(expression, datetime.now(timezone.utc))
|
||||
itr.get_next(datetime)
|
||||
except (ValueError, KeyError) as exc:
|
||||
raise ScheduleValidationError(f"Invalid cron expression: {exc}") from exc
|
||||
|
||||
|
||||
_CRON_INTERVAL_WINDOW = 64
|
||||
|
||||
|
||||
def cron_interval_seconds(expression: str, tz_name: Optional[str]) -> int:
|
||||
"""Return the smallest gap between ticks in a rolling window (enforces SCHEDULE_MIN_INTERVAL).
|
||||
|
||||
Walks _CRON_INTERVAL_WINDOW ticks because bursty expressions like
|
||||
``* 9 * * *`` have tiny within-burst gaps and huge between-burst gaps;
|
||||
sampling only two adjacent ticks would miss the small gap.
|
||||
"""
|
||||
parse_cron(expression)
|
||||
tz = resolve_timezone(tz_name)
|
||||
anchor_local = datetime.now(timezone.utc).astimezone(tz)
|
||||
itr = croniter(expression, anchor_local)
|
||||
prev = itr.get_next(datetime)
|
||||
smallest: Optional[int] = None
|
||||
for _ in range(_CRON_INTERVAL_WINDOW - 1):
|
||||
nxt = itr.get_next(datetime)
|
||||
gap = int((nxt - prev).total_seconds())
|
||||
if gap > 0 and (smallest is None or gap < smallest):
|
||||
smallest = gap
|
||||
prev = nxt
|
||||
return smallest if smallest is not None else 0
|
||||
|
||||
|
||||
def next_cron_run(
|
||||
expression: str,
|
||||
tz_name: Optional[str],
|
||||
after: Optional[datetime] = None,
|
||||
) -> datetime:
|
||||
"""Return the next fire time strictly after ``after`` (UTC, tz-aware).
|
||||
|
||||
Evaluates the cadence in the schedule's IANA tz so DST boundaries land on
|
||||
the intended local clock-time (e.g. 9 AM Warsaw stays 9 AM across the jump).
|
||||
"""
|
||||
parse_cron(expression)
|
||||
tz = resolve_timezone(tz_name)
|
||||
anchor_utc = after if after is not None else datetime.now(timezone.utc)
|
||||
if anchor_utc.tzinfo is None:
|
||||
anchor_utc = anchor_utc.replace(tzinfo=timezone.utc)
|
||||
anchor_local = anchor_utc.astimezone(tz)
|
||||
itr = croniter(expression, anchor_local)
|
||||
nxt_local = itr.get_next(datetime)
|
||||
return nxt_local.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def parse_delay(delay: str) -> timedelta:
|
||||
"""Parse a duration like ``30m`` / ``2h`` / ``1d`` into a timedelta."""
|
||||
if not isinstance(delay, str):
|
||||
raise ScheduleValidationError("delay must be a string like '30m' or '2h'.")
|
||||
match = _DELAY_RE.match(delay)
|
||||
if not match:
|
||||
raise ScheduleValidationError(
|
||||
"delay must look like '30s', '15m', '2h', or '1d'."
|
||||
)
|
||||
amount, unit = int(match.group(1)), match.group(2).lower()
|
||||
if amount <= 0:
|
||||
raise ScheduleValidationError("delay must be positive.")
|
||||
return timedelta(seconds=amount * _DELAY_MULTIPLIERS[unit])
|
||||
|
||||
|
||||
def parse_run_at(run_at: str, tz_name: Optional[str] = None) -> datetime:
|
||||
"""Parse an ISO 8601 timestamp; naive values resolve in ``tz_name``.
|
||||
|
||||
Naive values inside the DST "fall back" hour resolve to the earlier instance
|
||||
(zoneinfo default fold=0); pass an explicit offset to select the later one.
|
||||
"""
|
||||
if not isinstance(run_at, str) or not run_at.strip():
|
||||
raise ScheduleValidationError("run_at must be an ISO 8601 string.")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(run_at.strip().replace("Z", "+00:00"))
|
||||
except ValueError as exc:
|
||||
raise ScheduleValidationError(f"Invalid run_at: {exc}") from exc
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=resolve_timezone(tz_name))
|
||||
return parsed.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def clamp_once_horizon(run_at: datetime, max_horizon_seconds: int) -> None:
|
||||
"""Raise when ``run_at`` is in the past or beyond the once-task horizon."""
|
||||
now = datetime.now(timezone.utc)
|
||||
if run_at <= now:
|
||||
raise ScheduleValidationError("run_at is in the past.")
|
||||
if max_horizon_seconds > 0 and run_at - now > timedelta(seconds=max_horizon_seconds):
|
||||
raise ScheduleValidationError(
|
||||
"run_at is beyond the maximum allowed scheduling horizon."
|
||||
)
|
||||
@@ -3,6 +3,11 @@ import uuid
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from application.agents.default_tools import (
|
||||
is_headless_excluded_tool,
|
||||
resolve_tool_by_id,
|
||||
synthesized_default_tools,
|
||||
)
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.security.encryption import decrypt_credentials
|
||||
@@ -12,6 +17,7 @@ from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,10 +119,22 @@ class ToolExecutor:
|
||||
user_api_key: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
*,
|
||||
headless: bool = False,
|
||||
tool_allowlist: Optional[List[str]] = None,
|
||||
):
|
||||
self.user_api_key = user_api_key
|
||||
self.user = user
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = agent_id
|
||||
# Headless mode (scheduled / webhook): no human to resolve a pause,
|
||||
# so check_pause returns headless_denied sentinels instead.
|
||||
self.headless = bool(headless)
|
||||
# Tool-instance ids pre-authorized for headless approval-gated execution.
|
||||
self.tool_allowlist: set = (
|
||||
{str(x) for x in tool_allowlist} if tool_allowlist else set()
|
||||
)
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
@@ -124,6 +142,8 @@ class ToolExecutor:
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
# Filled by the LLMHandler.handle_tool_calls headless loop.
|
||||
self.headless_denials: List[Dict] = []
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
|
||||
"""Load tool configs from DB based on user context.
|
||||
@@ -140,29 +160,54 @@ class ToolExecutor:
|
||||
return tools
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||
"""Resolve an agent's toolset — exactly ``agents.tools``, no defaults."""
|
||||
# Per-operation session: the answer pipeline spans a long-lived
|
||||
# generator; wrapping it in a single connection would pin a PG
|
||||
# conn for the whole stream. Open, fetch, close.
|
||||
with db_readonly() as conn:
|
||||
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
if not tool_ids:
|
||||
return {}
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
owner = (
|
||||
(agent_data.get("user_id") or agent_data.get("user"))
|
||||
if agent_data
|
||||
else None
|
||||
)
|
||||
tools: List[Dict] = []
|
||||
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
|
||||
for tid in tool_ids:
|
||||
row = None
|
||||
if owner:
|
||||
row = tools_repo.get_any(str(tid), owner)
|
||||
if row is not None:
|
||||
tools.append(row)
|
||||
return {str(tool["id"]): tool for tool in tools} if tools else {}
|
||||
row = resolve_tool_by_id(tid, owner, user_tools_repo=tools_repo)
|
||||
if row is None:
|
||||
continue
|
||||
# Headless runs (scheduled / webhook) drop chat-only tools
|
||||
# like ``scheduler`` so a fire-time LLM can't chain schedules.
|
||||
if self.headless and is_headless_excluded_tool(row.get("name")):
|
||||
continue
|
||||
tools.append(row)
|
||||
return {str(tool["id"]): tool for tool in tools}
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
|
||||
"""Resolve an agentless chat's toolset: explicit user tools plus defaults."""
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user)
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
user_doc = (
|
||||
UsersRepository(conn).get(user) if self.agent_id is None else None
|
||||
)
|
||||
# Headless agentless runs (e.g. scheduled fire) drop chat-only
|
||||
# tools (``scheduler``) from explicit user_tools too.
|
||||
filtered_user_tools = [
|
||||
t for t in user_tools
|
||||
if not (self.headless and is_headless_excluded_tool(t.get("name")))
|
||||
]
|
||||
# Index keys (ints) and synthetic uuid5 keys can't collide.
|
||||
tools: Dict[str, Dict] = {
|
||||
str(i): tool for i, tool in enumerate(filtered_user_tools)
|
||||
}
|
||||
if self.agent_id is None:
|
||||
for default_row in synthesized_default_tools(
|
||||
user_doc, headless=self.headless,
|
||||
):
|
||||
tools[str(default_row["id"])] = default_row
|
||||
return tools
|
||||
|
||||
def merge_client_tools(
|
||||
self, tools_dict: Dict, client_tools: List[Dict]
|
||||
@@ -300,9 +345,11 @@ class ToolExecutor:
|
||||
def check_pause(
|
||||
self, tools_dict: Dict, call, llm_class_name: str
|
||||
) -> Optional[Dict]:
|
||||
"""Check if a tool call requires pausing for approval or client execution.
|
||||
"""Return a pending-action dict (approval / client / headless_denied) or None.
|
||||
|
||||
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||
In headless mode the dict's pause_type is ``headless_denied`` so the
|
||||
upstream loop synthesizes a tool result instead of pausing (nothing can
|
||||
resume a scheduled / webhook run).
|
||||
"""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
@@ -313,9 +360,26 @@ class ToolExecutor:
|
||||
return None # Will be handled as error by execute()
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
arguments = call_args if isinstance(call_args, dict) else {}
|
||||
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
if self.headless:
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": arguments,
|
||||
"pause_type": "headless_denied",
|
||||
"deny_reason": (
|
||||
"Client-side tools cannot run in headless / scheduled runs."
|
||||
),
|
||||
"error_type": "tool_not_allowed",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
@@ -323,7 +387,7 @@ class ToolExecutor:
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"arguments": arguments,
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
@@ -340,6 +404,27 @@ class ToolExecutor:
|
||||
)
|
||||
|
||||
if action_data.get("require_approval"):
|
||||
if self.headless:
|
||||
tool_row_id = str(tool_data.get("id") or tool_id)
|
||||
if tool_row_id in self.tool_allowlist:
|
||||
# Pre-authorized for headless execution — fall through.
|
||||
return None
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": arguments,
|
||||
"pause_type": "headless_denied",
|
||||
"deny_reason": (
|
||||
"This tool requires approval and is not in the run's "
|
||||
"tool_allowlist."
|
||||
),
|
||||
"error_type": "tool_not_allowed",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
@@ -347,7 +432,7 @@ class ToolExecutor:
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"arguments": arguments,
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
@@ -623,6 +708,13 @@ class ToolExecutor:
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
if self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
if tool_data["name"] == "scheduler":
|
||||
# Agent-bound: stamp schedules.agent_id. Agentless: the tool
|
||||
# falls back to ``origin_conversation_id`` as the schedule's
|
||||
# conversation home.
|
||||
tool_config["agent_id"] = (
|
||||
str(self.agent_id) if self.agent_id else None
|
||||
)
|
||||
if tool_data["name"] == "mcp_tool":
|
||||
tool_config["query_mode"] = True
|
||||
|
||||
|
||||
342
application/agents/tools/scheduler.py
Normal file
342
application/agents/tools/scheduler.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Scheduler tool: one-time agent tasks in agent-bound or agentless chats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.agents.scheduler_utils import (
|
||||
ScheduleValidationError,
|
||||
clamp_once_horizon,
|
||||
parse_delay,
|
||||
parse_run_at,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
from .base import Tool
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerTool(Tool):
|
||||
"""
|
||||
Scheduling
|
||||
Schedules a one-time task for the agent to run at a chosen time or delay.
|
||||
"""
|
||||
|
||||
# internal=True keeps scheduler out of /api/available_tools and the
|
||||
# agentless Add-Tool modal; tool_manager.load_tool still lazy-loads it
|
||||
# per-user at execute time (same as memory/notes/todo_list).
|
||||
internal: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_config: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> None:
|
||||
cfg = tool_config or {}
|
||||
self.user_id: Optional[str] = user_id
|
||||
self.agent_id: Optional[str] = cfg.get("agent_id")
|
||||
self.conversation_id: Optional[str] = cfg.get("conversation_id")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Dispatch on the LLM-supplied action name."""
|
||||
if not self.user_id:
|
||||
return "Error: SchedulerTool requires a valid user_id."
|
||||
# Agent-bound: agent_id must look like a UUID. Agentless: agent_id is
|
||||
# absent; an originating conversation is then mandatory (the schedule's
|
||||
# conversation home, used for history + output append).
|
||||
if self.agent_id and not looks_like_uuid(str(self.agent_id)):
|
||||
return "Error: SchedulerTool received an invalid agent_id."
|
||||
if not self.agent_id and not self.conversation_id:
|
||||
return (
|
||||
"Error: SchedulerTool requires an agent_id or a "
|
||||
"conversation_id (no conversation home)."
|
||||
)
|
||||
if action_name == "schedule_task":
|
||||
return self._schedule_task(
|
||||
instruction=kwargs.get("instruction", ""),
|
||||
delay=kwargs.get("delay"),
|
||||
run_at=kwargs.get("run_at"),
|
||||
tz=kwargs.get("timezone"),
|
||||
)
|
||||
if action_name == "list_scheduled_tasks":
|
||||
return self._list_scheduled_tasks()
|
||||
if action_name == "cancel_scheduled_task":
|
||||
return self._cancel_scheduled_task(kwargs.get("task_id", ""))
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Action schemas for the LLM tool catalogue."""
|
||||
return [
|
||||
{
|
||||
"name": "schedule_task",
|
||||
"description": (
|
||||
"Schedule a one-time task. Provide either a `delay` "
|
||||
"(e.g. '30m', '2h', '1d') from now, or a `run_at` ISO-8601 "
|
||||
"absolute time. Optionally pass an IANA `timezone` to resolve "
|
||||
"naive run_at values. The instruction is the task that will "
|
||||
"execute at fire time (including delivery, e.g. 'send to my "
|
||||
"Telegram'). For recurring schedules in an agent chat, point "
|
||||
"the user to the agent's Schedules tab."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"instruction": {
|
||||
"type": "string",
|
||||
"description": "What the agent should do at fire time.",
|
||||
},
|
||||
"delay": {
|
||||
"type": "string",
|
||||
"description": "Duration like '30m', '2h', '1d'.",
|
||||
},
|
||||
"run_at": {
|
||||
"type": "string",
|
||||
"description": "Absolute ISO 8601 timestamp.",
|
||||
},
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"IANA timezone (e.g. Europe/Warsaw) for naive run_at."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["instruction"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "list_scheduled_tasks",
|
||||
"description": (
|
||||
"List pending one-time tasks for the current chat. "
|
||||
"Agent-bound chats scope to user+agent; agentless chats "
|
||||
"scope to user+originating conversation."
|
||||
),
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "cancel_scheduled_task",
|
||||
"description": "Cancel a pending one-time task by its task_id.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "The schedule id returned by schedule_task.",
|
||||
},
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def _schedule_task(
|
||||
self,
|
||||
instruction: str,
|
||||
delay: Optional[str],
|
||||
run_at: Optional[str],
|
||||
tz: Optional[str],
|
||||
) -> str:
|
||||
if not instruction or not isinstance(instruction, str):
|
||||
return "Error: instruction is required."
|
||||
if not delay and not run_at:
|
||||
return "Error: provide either `delay` or `run_at`."
|
||||
if delay and run_at:
|
||||
return "Error: provide only one of `delay` or `run_at`."
|
||||
|
||||
try:
|
||||
if delay:
|
||||
fire = datetime.now(timezone.utc) + parse_delay(delay)
|
||||
else:
|
||||
fire = parse_run_at(run_at, tz)
|
||||
clamp_once_horizon(fire, settings.SCHEDULE_ONCE_MAX_HORIZON)
|
||||
except ScheduleValidationError as exc:
|
||||
return f"Error: {exc}"
|
||||
|
||||
with db_readonly() as conn:
|
||||
count = SchedulesRepository(conn).count_active_for_user(self.user_id)
|
||||
if (
|
||||
settings.SCHEDULE_MAX_PER_USER > 0
|
||||
and count >= settings.SCHEDULE_MAX_PER_USER
|
||||
):
|
||||
return (
|
||||
"Error: you have reached the maximum number of active schedules."
|
||||
)
|
||||
|
||||
# Chat-created tasks default to the user's non-approval tools (for the
|
||||
# agent's toolset when agent-bound, or the user's defaults+user_tools
|
||||
# when agentless).
|
||||
allowlist = _safe_default_allowlist(self.agent_id, self.user_id)
|
||||
|
||||
auto_name = _name_from_instruction(instruction)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
created = SchedulesRepository(conn).create(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
trigger_type="once",
|
||||
instruction=instruction.strip(),
|
||||
name=auto_name,
|
||||
run_at=fire,
|
||||
next_run_at=fire,
|
||||
timezone=tz or "UTC",
|
||||
tool_allowlist=allowlist,
|
||||
origin_conversation_id=self.conversation_id,
|
||||
created_via="chat",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("schedule_task create failed: %s", exc)
|
||||
return "Error: failed to create scheduled task."
|
||||
return json.dumps(
|
||||
{
|
||||
"task_id": str(created["id"]),
|
||||
"resolved_run_at": _iso_utc(fire),
|
||||
"timezone": tz or "UTC",
|
||||
"instruction": instruction.strip(),
|
||||
"name": auto_name,
|
||||
}
|
||||
)
|
||||
|
||||
def _list_scheduled_tasks(self) -> str:
|
||||
"""Pending one-time tasks for this user, oldest fire first.
|
||||
|
||||
Agent-bound chats scope to user+agent. Agentless chats scope to user+
|
||||
origin_conversation_id so a user only sees tasks created from this chat.
|
||||
"""
|
||||
with db_readonly() as conn:
|
||||
repo = SchedulesRepository(conn)
|
||||
if self.agent_id:
|
||||
rows = repo.list_for_agent(
|
||||
self.agent_id,
|
||||
self.user_id,
|
||||
statuses=["active"],
|
||||
trigger_type="once",
|
||||
)
|
||||
else:
|
||||
rows = repo.list_for_conversation(
|
||||
self.user_id,
|
||||
self.conversation_id,
|
||||
statuses=["active"],
|
||||
trigger_type="once",
|
||||
)
|
||||
# Values arrive as ISO strings (coerce_pg_native); string sentinel keeps types uniform.
|
||||
rows.sort(key=lambda r: r.get("next_run_at") or "9999-12-31T23:59:59Z")
|
||||
items = [
|
||||
{
|
||||
"task_id": str(r["id"]),
|
||||
"instruction": r.get("instruction"),
|
||||
"name": r.get("name"),
|
||||
"resolved_run_at": _iso_utc(r.get("next_run_at")),
|
||||
"timezone": r.get("timezone"),
|
||||
"status": r.get("status"),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
return json.dumps({"tasks": items})
|
||||
|
||||
def _cancel_scheduled_task(self, task_id: str) -> str:
|
||||
if not task_id or not looks_like_uuid(str(task_id)):
|
||||
return "Error: task_id must be a valid id."
|
||||
with db_session() as conn:
|
||||
repo = SchedulesRepository(conn)
|
||||
# Agentless: scope cancel to user + originating conversation so a
|
||||
# user can only cancel tasks they created in the current chat.
|
||||
if not self.agent_id:
|
||||
row = repo.get(task_id, self.user_id)
|
||||
if row is None or row.get("agent_id") is not None or (
|
||||
str(row.get("origin_conversation_id") or "")
|
||||
!= str(self.conversation_id or "")
|
||||
):
|
||||
return (
|
||||
"Error: scheduled task not found or already terminal."
|
||||
)
|
||||
ok = repo.cancel(task_id, self.user_id)
|
||||
if not ok:
|
||||
return "Error: scheduled task not found or already terminal."
|
||||
return json.dumps({"task_id": str(task_id), "status": "cancelled"})
|
||||
|
||||
|
||||
def _name_from_instruction(instruction: str, *, max_len: int = 80) -> str:
|
||||
"""Compact display name derived from the instruction's first line."""
|
||||
first_line = instruction.strip().split("\n", 1)[0]
|
||||
if len(first_line) <= max_len:
|
||||
return first_line
|
||||
return first_line[: max_len - 1] + "…"
|
||||
|
||||
|
||||
def _iso_utc(value: Any) -> Optional[str]:
|
||||
"""Render a datetime (or ISO string) as RFC3339 UTC; ``None`` passes through."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
value = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
return value
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _safe_default_allowlist(
|
||||
agent_id: Optional[str], user_id: str,
|
||||
) -> List[str]:
|
||||
"""Return ids of available tools whose actions are all non-approval.
|
||||
|
||||
Agent-bound: the agent's ``agents.tools`` entries.
|
||||
Agentless: the user's active ``user_tools`` rows plus synthesized default
|
||||
chat tools (resolved against ``settings.DEFAULT_CHAT_TOOLS`` and the
|
||||
user's ``tool_preferences.disabled_default_tools`` opt-outs).
|
||||
"""
|
||||
from application.agents.default_tools import (
|
||||
resolve_tool_by_id,
|
||||
synthesized_default_tools,
|
||||
)
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
def _is_safe(row: Dict[str, Any]) -> bool:
|
||||
actions = row.get("actions") or []
|
||||
return not any(a.get("require_approval") for a in actions)
|
||||
|
||||
safe_ids: List[str] = []
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
if agent_id:
|
||||
agent = AgentsRepository(conn).get(agent_id, user_id)
|
||||
tool_ids = (agent or {}).get("tools") or []
|
||||
for raw_id in tool_ids:
|
||||
tool_id = str(raw_id)
|
||||
row = resolve_tool_by_id(
|
||||
tool_id, user_id, user_tools_repo=tools_repo,
|
||||
)
|
||||
if not row or not _is_safe(row):
|
||||
continue
|
||||
safe_ids.append(tool_id)
|
||||
else:
|
||||
# Agentless: explicit user_tools (active=true) + synthesized
|
||||
# defaults respecting the user's opt-out preferences.
|
||||
user_doc = UsersRepository(conn).get(user_id)
|
||||
for row in tools_repo.list_active_for_user(user_id):
|
||||
if not _is_safe(row):
|
||||
continue
|
||||
safe_ids.append(str(row["id"]))
|
||||
for default_row in synthesized_default_tools(user_doc):
|
||||
if not _is_safe(default_row):
|
||||
continue
|
||||
safe_ids.append(str(default_row["id"]))
|
||||
except Exception: # pragma: no cover — best-effort fallback
|
||||
logger.exception("scheduler: default allowlist build failed")
|
||||
return []
|
||||
return safe_ids
|
||||
@@ -28,7 +28,10 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
if (
|
||||
tool_name in {"mcp_tool", "notes", "memory", "todo_list", "scheduler"}
|
||||
and user_id
|
||||
):
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +39,10 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
|
||||
if (
|
||||
tool_name in {"mcp_tool", "memory", "todo_list", "notes", "scheduler"}
|
||||
and user_id
|
||||
):
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
83
application/alembic/versions/0009_tool_preferences.py
Normal file
83
application/alembic/versions/0009_tool_preferences.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""0009 default chat tools — users.tool_preferences + memories.tool_id.
|
||||
|
||||
Adds ``users.tool_preferences`` JSONB and drops the
|
||||
``memories.tool_id`` FK to ``user_tools`` (synthetic default-tool ids
|
||||
have no ``user_tools`` row). Delete-cascade for real tools is kept via
|
||||
an AFTER DELETE trigger on ``user_tools``. Idempotent both ways.
|
||||
|
||||
Revision ID: 0009_tool_preferences
|
||||
Revises: 0008_ingest_progress_status
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0009_tool_preferences"
|
||||
down_revision: Union[str, None] = "0008_ingest_progress_status"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS tool_preferences JSONB
|
||||
NOT NULL DEFAULT '{}'::jsonb;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"ALTER TABLE memories DROP CONSTRAINT IF EXISTS memories_tool_id_fkey;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION cleanup_tool_memories() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
DELETE FROM memories WHERE tool_id = OLD.id;
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
# DROP-then-CREATE — no CREATE OR REPLACE TRIGGER for this signature.
|
||||
op.execute(
|
||||
"DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER user_tools_cleanup_memories "
|
||||
"AFTER DELETE ON user_tools "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_tool_memories();"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
|
||||
)
|
||||
op.execute("DROP FUNCTION IF EXISTS cleanup_tool_memories();")
|
||||
# DESTRUCTIVE: restoring the FK requires every memories.tool_id to
|
||||
# reference a real user_tools row. Any memory written by a built-in
|
||||
# default tool (synthetic uuid5 id, no user_tools row) is permanently
|
||||
# DELETED here so the constraint can be re-created. Downgrading 0009
|
||||
# therefore loses all built-in-memory-tool data — by necessity, since
|
||||
# the restored schema cannot represent it.
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM memories
|
||||
WHERE tool_id IS NOT NULL
|
||||
AND tool_id NOT IN (SELECT id FROM user_tools);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE memories
|
||||
ADD CONSTRAINT memories_tool_id_fkey
|
||||
FOREIGN KEY (tool_id) REFERENCES user_tools(id) ON DELETE CASCADE;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE users DROP COLUMN IF EXISTS tool_preferences;")
|
||||
147
application/alembic/versions/0010_schedules.py
Normal file
147
application/alembic/versions/0010_schedules.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""0010 scheduler — schedules + schedule_runs tables.
|
||||
|
||||
Revision ID: 0010_schedules
|
||||
Revises: 0009_tool_preferences
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0010_schedules"
|
||||
down_revision: Union[str, None] = "0009_tool_preferences"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE schedules (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
|
||||
trigger_type TEXT NOT NULL,
|
||||
name TEXT,
|
||||
instruction TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
cron TEXT,
|
||||
run_at TIMESTAMPTZ,
|
||||
timezone TEXT NOT NULL DEFAULT 'UTC',
|
||||
next_run_at TIMESTAMPTZ,
|
||||
last_run_at TIMESTAMPTZ,
|
||||
end_at TIMESTAMPTZ,
|
||||
tool_allowlist JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
model_id TEXT,
|
||||
token_budget INTEGER,
|
||||
origin_conversation_id UUID REFERENCES conversations(id) ON DELETE SET NULL,
|
||||
created_via TEXT NOT NULL DEFAULT 'ui',
|
||||
consecutive_failure_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT schedules_trigger_type_chk
|
||||
CHECK (trigger_type IN ('once', 'recurring')),
|
||||
CONSTRAINT schedules_status_chk
|
||||
CHECK (status IN ('active', 'paused', 'completed', 'cancelled')),
|
||||
CONSTRAINT schedules_created_via_chk
|
||||
CHECK (created_via IN ('chat', 'ui')),
|
||||
CONSTRAINT schedules_recurring_cron_chk
|
||||
CHECK (trigger_type <> 'recurring' OR cron IS NOT NULL),
|
||||
CONSTRAINT schedules_once_run_at_chk
|
||||
CHECK (trigger_type <> 'once' OR run_at IS NOT NULL)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX schedules_user_idx ON schedules (user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX schedules_agent_idx ON schedules (agent_id);"
|
||||
)
|
||||
# Dispatcher hot path: status='active' AND next_run_at <= now().
|
||||
op.execute(
|
||||
"CREATE INDEX schedules_due_idx "
|
||||
"ON schedules (status, next_run_at) "
|
||||
"WHERE status = 'active';"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER schedules_set_updated_at "
|
||||
"BEFORE UPDATE ON schedules "
|
||||
"FOR EACH ROW EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE schedule_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
schedule_id UUID NOT NULL REFERENCES schedules(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
scheduled_for TIMESTAMPTZ NOT NULL,
|
||||
trigger_source TEXT NOT NULL DEFAULT 'cron',
|
||||
started_at TIMESTAMPTZ,
|
||||
finished_at TIMESTAMPTZ,
|
||||
output TEXT,
|
||||
output_truncated BOOLEAN NOT NULL DEFAULT false,
|
||||
error TEXT,
|
||||
error_type TEXT,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
conversation_id UUID,
|
||||
message_id UUID,
|
||||
celery_task_id TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT schedule_runs_status_chk
|
||||
CHECK (status IN (
|
||||
'pending', 'running', 'success', 'failed', 'skipped', 'timeout'
|
||||
)),
|
||||
CONSTRAINT schedule_runs_trigger_source_chk
|
||||
CHECK (trigger_source IN ('cron', 'manual')),
|
||||
CONSTRAINT schedule_runs_error_type_chk
|
||||
CHECK (error_type IS NULL OR error_type IN (
|
||||
'auth_expired', 'tool_not_allowed', 'budget_exceeded',
|
||||
'timeout', 'agent_error', 'internal', 'missed', 'overlap'
|
||||
))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Dedup primitive: racing dispatchers hit ON CONFLICT on this index.
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX schedule_runs_dedup_uidx "
|
||||
"ON schedule_runs (schedule_id, scheduled_for);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX schedule_runs_schedule_recent_idx "
|
||||
"ON schedule_runs (schedule_id, scheduled_for DESC);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX schedule_runs_user_idx ON schedule_runs (user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX schedule_runs_running_idx "
|
||||
"ON schedule_runs (status, started_at) "
|
||||
"WHERE status = 'running';"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER schedule_runs_set_updated_at "
|
||||
"BEFORE UPDATE ON schedule_runs "
|
||||
"FOR EACH ROW EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop triggers explicitly (grep-able) before CASCADE-dropping the tables.
|
||||
op.execute(
|
||||
"DROP TRIGGER IF EXISTS schedule_runs_set_updated_at ON schedule_runs;"
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS schedule_runs CASCADE;")
|
||||
op.execute(
|
||||
"DROP TRIGGER IF EXISTS schedules_set_updated_at ON schedules;"
|
||||
)
|
||||
op.execute("DROP TABLE IF EXISTS schedules CASCADE;")
|
||||
@@ -0,0 +1,53 @@
|
||||
"""0011 scheduler — make schedules.agent_id / schedule_runs.agent_id nullable.
|
||||
|
||||
Agentless schedules (created from agentless chats via the dual-registered
|
||||
``scheduler`` default chat tool) carry ``agent_id IS NULL``. Existing FK +
|
||||
``ON DELETE CASCADE`` semantics on ``agents(id)`` are unaffected — Postgres
|
||||
only cascades when the parent row is deleted, NULL rows aren't matched.
|
||||
|
||||
Revision ID: 0011_schedules_nullable_agent
|
||||
Revises: 0010_schedules
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0011_schedules_nullable_agent"
|
||||
down_revision: Union[str, None] = "0010_schedules"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("ALTER TABLE schedules ALTER COLUMN agent_id DROP NOT NULL;")
|
||||
op.execute("ALTER TABLE schedule_runs ALTER COLUMN agent_id DROP NOT NULL;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Destructive otherwise: agentless rows have agent_id IS NULL by design,
|
||||
# so restoring NOT NULL must fail loudly if any exist.
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
sched_nulls INTEGER;
|
||||
run_nulls INTEGER;
|
||||
BEGIN
|
||||
SELECT count(*) INTO sched_nulls
|
||||
FROM schedules WHERE agent_id IS NULL;
|
||||
SELECT count(*) INTO run_nulls
|
||||
FROM schedule_runs WHERE agent_id IS NULL;
|
||||
IF sched_nulls > 0 OR run_nulls > 0 THEN
|
||||
RAISE EXCEPTION
|
||||
'Cannot downgrade 0011: agentless rows present '
|
||||
'(schedules=%, schedule_runs=%). '
|
||||
'Delete or reassign them before retrying.',
|
||||
sched_nulls, run_nulls;
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE schedule_runs ALTER COLUMN agent_id SET NOT NULL;")
|
||||
op.execute("ALTER TABLE schedules ALTER COLUMN agent_id SET NOT NULL;")
|
||||
@@ -325,6 +325,17 @@ class BaseAnswerResource:
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# The reservation above may create the conversation row (first turn in
|
||||
# a new chat). Propagate that fresh id to the tool_executor so tools
|
||||
# that need a conversation home (e.g. ``scheduler`` in agentless chats)
|
||||
# see it on the very first call instead of waiting for the next turn.
|
||||
if conversation_id and getattr(agent, "tool_executor", None):
|
||||
try:
|
||||
agent.tool_executor.conversation_id = str(conversation_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.conversation_id post-reserve",
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
|
||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.default_tools import synthesized_default_tools
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
@@ -25,6 +26,7 @@ from application.storage.db.repositories.attachments import AttachmentsRepositor
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
@@ -293,7 +295,7 @@ class StreamProcessor:
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
"""Pick model_id with agent authority on agent-bound chats."""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
@@ -302,6 +304,20 @@ class StreamProcessor:
|
||||
caller_user_id = self.initial_user_id
|
||||
owner_user_id = self.agent_config.get("user_id") or caller_user_id
|
||||
|
||||
# Agent-bound: agent's default_model_id wins, body's model_id is dropped.
|
||||
agent_bound = self._agent_data is not None
|
||||
if agent_bound:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner_user_id
|
||||
):
|
||||
self.model_id = agent_default_model
|
||||
self.model_user_id = owner_user_id
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
return
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model, user_id=caller_user_id):
|
||||
registry = ModelRegistry.get_instance()
|
||||
@@ -321,15 +337,8 @@ class StreamProcessor:
|
||||
self.model_id = requested_model
|
||||
self.model_user_id = caller_user_id
|
||||
else:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner_user_id
|
||||
):
|
||||
self.model_id = agent_default_model
|
||||
self.model_user_id = owner_user_id
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control."""
|
||||
@@ -385,6 +394,7 @@ class StreamProcessor:
|
||||
raise
|
||||
|
||||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||
"""Resolve agent metadata + the unioned source set for the given key."""
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
@@ -395,36 +405,66 @@ class StreamProcessor:
|
||||
data: Dict[str, Any] = dict(agent)
|
||||
data["user"] = agent.get("user_id")
|
||||
|
||||
# Resolve the primary source row (if any) for retriever/chunks.
|
||||
source_id = agent.get("source_id")
|
||||
if source_id:
|
||||
source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
|
||||
# Active sources = primary ∪ extras, primary first, deduplicated.
|
||||
# ``_configure_source`` ignores an empty ``data["sources"]``,
|
||||
# so the primary must appear in the union too — not only in
|
||||
# the legacy ``data["source"]`` slot.
|
||||
sources_list: list = []
|
||||
seen: set = set()
|
||||
owner = agent.get("user_id")
|
||||
primary_id = agent.get("source_id")
|
||||
# ``sources`` row may have NULL ``retriever``/``chunks`` —
|
||||
# fall back to the agent's value (``dict.get`` returns None
|
||||
# even when the key exists with value None).
|
||||
if primary_id:
|
||||
source_doc = sources_repo.get(str(primary_id), owner)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["id"])
|
||||
data["retriever"] = source_doc.get(
|
||||
"retriever", data.get("retriever")
|
||||
sid = str(source_doc["id"])
|
||||
data["source"] = sid
|
||||
src_retriever = source_doc.get("retriever")
|
||||
if src_retriever:
|
||||
data["retriever"] = src_retriever
|
||||
src_chunks = source_doc.get("chunks")
|
||||
if src_chunks is not None:
|
||||
data["chunks"] = src_chunks
|
||||
sources_list.append(
|
||||
{
|
||||
"id": sid,
|
||||
"retriever": src_retriever or "classic",
|
||||
"chunks": (
|
||||
src_chunks if src_chunks is not None
|
||||
else data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
seen.add(sid)
|
||||
else:
|
||||
data["source"] = None
|
||||
else:
|
||||
data["source"] = None
|
||||
|
||||
sources_list = []
|
||||
extra = agent.get("extra_source_ids") or []
|
||||
if extra:
|
||||
for sid in extra:
|
||||
source_doc = sources_repo.get(str(sid), agent.get("user_id"))
|
||||
if source_doc:
|
||||
sources_list.append(
|
||||
{
|
||||
"id": str(source_doc["id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get(
|
||||
"chunks", data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
for sid_raw in agent.get("extra_source_ids") or []:
|
||||
if not sid_raw:
|
||||
continue
|
||||
source_doc = sources_repo.get(str(sid_raw), owner)
|
||||
if not source_doc:
|
||||
continue
|
||||
sid = str(source_doc["id"])
|
||||
if sid in seen:
|
||||
continue
|
||||
src_retriever = source_doc.get("retriever")
|
||||
src_chunks = source_doc.get("chunks")
|
||||
sources_list.append(
|
||||
{
|
||||
"id": sid,
|
||||
"retriever": src_retriever or "classic",
|
||||
"chunks": (
|
||||
src_chunks if src_chunks is not None
|
||||
else data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
seen.add(sid)
|
||||
data["sources"] = sources_list
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
return data
|
||||
@@ -589,7 +629,7 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Assemble retriever config with precedence: request > agent > default."""
|
||||
"""Assemble retriever config; agent's values are authoritative when bound."""
|
||||
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
|
||||
# None for built-ins. Without ``user_id`` here, the doc budget
|
||||
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
|
||||
@@ -598,12 +638,11 @@ class StreamProcessor:
|
||||
model_id=self.model_id, user_id=self.model_user_id
|
||||
)
|
||||
|
||||
# Start with defaults
|
||||
retriever_name = "classic"
|
||||
chunks = 2
|
||||
|
||||
# Layer agent-level config (if present)
|
||||
if self._agent_data:
|
||||
if self._agent_data is not None:
|
||||
# Agent-bound: agent wins, body's retriever/chunks are dropped.
|
||||
if self._agent_data.get("retriever"):
|
||||
retriever_name = self._agent_data["retriever"]
|
||||
if self._agent_data.get("chunks") is not None:
|
||||
@@ -614,18 +653,17 @@ class StreamProcessor:
|
||||
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
# Explicit request values win over agent config
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
else:
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": retriever_name,
|
||||
@@ -633,7 +671,7 @@ class StreamProcessor:
|
||||
"doc_token_limit": doc_token_limit,
|
||||
}
|
||||
|
||||
# isNoneDoc without an API key forces no retrieval
|
||||
# isNoneDoc without an API key forces no retrieval (agentless only)
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
@@ -708,17 +746,26 @@ class StreamProcessor:
|
||||
|
||||
try:
|
||||
user_id = self.initial_user_id or "local"
|
||||
agentless = self.agent_id is None
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
|
||||
user_doc = (
|
||||
UsersRepository(conn).get(user_id) if agentless else None
|
||||
)
|
||||
|
||||
if not user_tools:
|
||||
default_docs = (
|
||||
synthesized_default_tools(user_doc) if agentless else []
|
||||
)
|
||||
tool_docs = list(user_tools) + default_docs
|
||||
if not tool_docs:
|
||||
return None
|
||||
|
||||
tools_data = {}
|
||||
|
||||
for tool_doc in user_tools:
|
||||
for tool_doc in tool_docs:
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_id = str(tool_doc.get("_id"))
|
||||
tool_id = str(tool_doc.get("_id") or tool_doc.get("id"))
|
||||
is_default = bool(tool_doc.get("default"))
|
||||
|
||||
if filtering_enabled:
|
||||
required_actions_by_name = required_tool_actions.get(
|
||||
@@ -731,11 +778,18 @@ class StreamProcessor:
|
||||
if not required_actions:
|
||||
continue
|
||||
else:
|
||||
# No template names a default tool, so running its
|
||||
# actions blind would only inject noise.
|
||||
if is_default:
|
||||
continue
|
||||
required_actions = None
|
||||
|
||||
tool_data = self._fetch_tool_data(tool_doc, required_actions)
|
||||
if tool_data:
|
||||
tools_data[tool_name] = tool_data
|
||||
# Defaults reachable by synthetic id only — the name
|
||||
# key stays bound to an explicit row of the same name.
|
||||
if not is_default:
|
||||
tools_data[tool_name] = tool_data
|
||||
tools_data[tool_id] = tool_data
|
||||
|
||||
return tools_data if tools_data else None
|
||||
@@ -982,6 +1036,7 @@ class StreamProcessor:
|
||||
user_api_key=user_api_key,
|
||||
user=self.initial_user_id,
|
||||
decoded_token=self.decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
tool_executor.conversation_id = conversation_id
|
||||
# Restore client tools so they stay available for subsequent LLM calls
|
||||
@@ -1130,6 +1185,7 @@ class StreamProcessor:
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
user=user,
|
||||
decoded_token=self.decoded_token,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
tool_executor.conversation_id = self.conversation_id
|
||||
# Pass client-side tools so they get merged in get_tools()
|
||||
@@ -1137,7 +1193,6 @@ class StreamProcessor:
|
||||
if client_tools:
|
||||
tool_executor.client_tools = client_tools
|
||||
|
||||
# Base agent kwargs
|
||||
agent_kwargs = {
|
||||
"endpoint": "stream",
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
|
||||
@@ -83,13 +83,15 @@ def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their display details.
|
||||
|
||||
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
|
||||
lists are supported — each id is looked up via ``get_any``, which
|
||||
resolves to whichever column matches). Unknown ids are silently
|
||||
Accepts Postgres UUIDs, legacy Mongo ObjectId strings, or the
|
||||
synthetic ids of default chat tools / agent-selectable builtins
|
||||
(mixed lists are supported). Synthetic ids are resolved in memory;
|
||||
real ids are looked up via ``get_any``. Unknown ids are silently
|
||||
skipped.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
|
||||
tool_ids: List of tool IDs (UUIDs, legacy ObjectId strings, or
|
||||
synthetic default-tool / builtin ids).
|
||||
|
||||
Returns:
|
||||
List of tool details with ``id``, ``name``, and ``display_name``.
|
||||
@@ -97,19 +99,37 @@ def resolve_tool_details(tool_ids):
|
||||
if not tool_ids:
|
||||
return []
|
||||
|
||||
from application.agents.default_tools import (
|
||||
is_synthesized_tool_id,
|
||||
synthesize_tool_by_name,
|
||||
synthesized_tool_name_for_id,
|
||||
)
|
||||
|
||||
uuid_ids: list[str] = []
|
||||
legacy_ids: list[str] = []
|
||||
default_details: list[dict] = []
|
||||
for tid in tool_ids:
|
||||
if not tid:
|
||||
continue
|
||||
tid_str = str(tid)
|
||||
if is_synthesized_tool_id(tid_str):
|
||||
synth = synthesize_tool_by_name(synthesized_tool_name_for_id(tid_str))
|
||||
if synth is not None:
|
||||
default_details.append(
|
||||
{
|
||||
"id": tid_str,
|
||||
"name": synth.get("name", ""),
|
||||
"display_name": synth.get("display_name", ""),
|
||||
}
|
||||
)
|
||||
continue
|
||||
if looks_like_uuid(tid_str):
|
||||
uuid_ids.append(tid_str)
|
||||
else:
|
||||
legacy_ids.append(tid_str)
|
||||
|
||||
if not uuid_ids and not legacy_ids:
|
||||
return []
|
||||
return default_details
|
||||
|
||||
rows: list[dict] = []
|
||||
with db_readonly() as conn:
|
||||
@@ -132,7 +152,7 @@ def resolve_tool_details(tool_ids):
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
|
||||
return [
|
||||
return default_details + [
|
||||
{
|
||||
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
|
||||
"name": tool.get("name", "") or "",
|
||||
|
||||
@@ -4,7 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Connection
|
||||
|
||||
@@ -16,6 +17,9 @@ from application.storage.db.repositories.reconciliation import (
|
||||
)
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -46,6 +50,7 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
"tool_calls_failed": 0,
|
||||
"ingests_stalled": 0,
|
||||
"idempotency_pending_failed": 0,
|
||||
"schedule_runs_failed": 0,
|
||||
}
|
||||
|
||||
with engine.begin() as conn:
|
||||
@@ -169,9 +174,101 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
},
|
||||
)
|
||||
|
||||
# Q6: scheduler runs stuck in 'running' past the soft-time-limit window.
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.core.settings import settings as _settings
|
||||
|
||||
stuck_age = max(
|
||||
15, int(_settings.SCHEDULE_RUN_TIMEOUT // 60) + 5,
|
||||
)
|
||||
with engine.begin() as conn:
|
||||
runs_repo = ScheduleRunsRepository(conn)
|
||||
schedules_repo = SchedulesRepository(conn)
|
||||
for run in runs_repo.list_stuck_running(age_minutes=stuck_age):
|
||||
runs_repo.update(
|
||||
run["id"],
|
||||
{
|
||||
"status": "timeout",
|
||||
"finished_at": datetime.now(timezone.utc),
|
||||
"error_type": "timeout",
|
||||
"error": (
|
||||
"reconciler: schedule_run stuck in 'running' past "
|
||||
f"{stuck_age} min"
|
||||
),
|
||||
},
|
||||
)
|
||||
schedules_repo.bump_failure_count(str(run["schedule_id"]))
|
||||
_terminal_flip_once_schedule(
|
||||
schedules_repo, str(run["schedule_id"]),
|
||||
)
|
||||
summary["schedule_runs_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_schedule_run_timeout",
|
||||
user_id=run.get("user_id"),
|
||||
detail={
|
||||
"run_id": str(run["id"]),
|
||||
"schedule_id": str(run["schedule_id"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Q7: scheduler runs orphaned in 'pending' — dispatcher committed but
|
||||
# apply_async failed (broker outage / crash mid-dispatch).
|
||||
with engine.begin() as conn:
|
||||
runs_repo = ScheduleRunsRepository(conn)
|
||||
schedules_repo = SchedulesRepository(conn)
|
||||
for run in runs_repo.list_stuck_pending(age_minutes=stuck_age):
|
||||
runs_repo.update(
|
||||
run["id"],
|
||||
{
|
||||
"status": "failed",
|
||||
"finished_at": datetime.now(timezone.utc),
|
||||
"error_type": "internal",
|
||||
"error": (
|
||||
"reconciler: schedule_run stuck in 'pending' past "
|
||||
f"{stuck_age} min (worker_never_started)"
|
||||
),
|
||||
},
|
||||
)
|
||||
schedules_repo.bump_failure_count(str(run["schedule_id"]))
|
||||
_terminal_flip_once_schedule(
|
||||
schedules_repo, str(run["schedule_id"]),
|
||||
)
|
||||
summary["schedule_runs_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_schedule_run_pending",
|
||||
user_id=run.get("user_id"),
|
||||
detail={
|
||||
"run_id": str(run["id"]),
|
||||
"schedule_id": str(run["schedule_id"]),
|
||||
},
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def _terminal_flip_once_schedule(
|
||||
schedules_repo: "SchedulesRepository", schedule_id: str,
|
||||
) -> None:
|
||||
"""Flip a once-schedule to 'completed' after its run terminates.
|
||||
|
||||
Recurring schedules keep firing; once-schedules would otherwise read
|
||||
'active forever' since next_run_at is already NULL.
|
||||
"""
|
||||
schedule = schedules_repo.get_internal(schedule_id)
|
||||
if schedule is None or schedule.get("trigger_type") != "once":
|
||||
return
|
||||
if schedule.get("status") in {"completed", "cancelled"}:
|
||||
return
|
||||
schedules_repo.update_internal(
|
||||
schedule_id, {"status": "completed", "next_run_at": None},
|
||||
)
|
||||
|
||||
|
||||
def _emit_alert(
|
||||
conn: Connection,
|
||||
*,
|
||||
|
||||
@@ -11,6 +11,7 @@ from .attachments import attachments_ns
|
||||
from .conversations import conversations_ns
|
||||
from .models import models_ns
|
||||
from .prompts import prompts_ns
|
||||
from .schedules import schedules_ns
|
||||
from .sharing import sharing_ns
|
||||
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
|
||||
from .tools import tools_mcp_ns, tools_ns
|
||||
@@ -40,6 +41,9 @@ api.add_namespace(agents_folders_ns)
|
||||
# Prompts
|
||||
api.add_namespace(prompts_ns)
|
||||
|
||||
# Schedules
|
||||
api.add_namespace(schedules_ns)
|
||||
|
||||
# Sharing
|
||||
api.add_namespace(sharing_ns)
|
||||
|
||||
|
||||
186
application/api/user/scheduler_dispatcher.py
Normal file
186
application/api/user/scheduler_dispatcher.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Schedule dispatcher: poll Postgres, claim due rows under FOR UPDATE SKIP LOCKED,
|
||||
advance next_run_at atomically with the run claim, then enqueue.
|
||||
|
||||
Per-schedule IANA tz semantics (croniter+zoneinfo) outside Celery's app-wide tz,
|
||||
plus Postgres-native dedup avoid Redis visibility_timeout double-fires.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.agents.scheduler_utils import next_cron_run
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_dt(value: Any) -> Optional[datetime]:
|
||||
"""Accept a datetime / ISO string / None and return a tz-aware UTC dt."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.astimezone(timezone.utc) if value.tzinfo else (
|
||||
value.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
return None
|
||||
return parsed.astimezone(timezone.utc) if parsed.tzinfo else (
|
||||
parsed.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _compute_next(
|
||||
schedule: Dict[str, Any],
|
||||
*,
|
||||
after: datetime,
|
||||
) -> Optional[datetime]:
|
||||
"""Next next_run_at for a recurring schedule, or None when past end_at."""
|
||||
cron = schedule.get("cron")
|
||||
if not cron:
|
||||
return None
|
||||
end_at = _normalize_dt(schedule.get("end_at"))
|
||||
candidate = next_cron_run(cron, schedule.get("timezone"), after=after)
|
||||
if end_at is not None and candidate > end_at:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
|
||||
def dispatch_due_runs() -> Dict[str, int]:
|
||||
"""One dispatcher tick; returns counts for schedule_syncs-style logging."""
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"enqueued": 0, "skipped": 0, "advanced": 0}
|
||||
|
||||
from application.api.user.tasks import execute_scheduled_run
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
grace = timedelta(seconds=max(0, settings.SCHEDULE_MISFIRE_GRACE))
|
||||
engine = get_engine()
|
||||
counts = {"enqueued": 0, "skipped": 0, "advanced": 0}
|
||||
enqueue_args: List[str] = []
|
||||
|
||||
with engine.begin() as conn:
|
||||
schedules_repo = SchedulesRepository(conn)
|
||||
runs_repo = ScheduleRunsRepository(conn)
|
||||
for schedule in schedules_repo.list_due():
|
||||
scheduled_for = _normalize_dt(schedule.get("next_run_at"))
|
||||
if scheduled_for is None:
|
||||
continue
|
||||
|
||||
trigger_type = schedule.get("trigger_type")
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
agent_id = str(agent_id_raw) if agent_id_raw else None
|
||||
|
||||
# Misfire grace applies to recurring only — once-tasks fire late, not vanish.
|
||||
if (
|
||||
trigger_type == "recurring"
|
||||
and grace > timedelta(0)
|
||||
and (now - scheduled_for) > grace
|
||||
):
|
||||
runs_repo.record_skipped(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
error_type="missed",
|
||||
error="misfire grace exceeded",
|
||||
)
|
||||
counts["skipped"] += 1
|
||||
nxt = _compute_next(schedule, after=now)
|
||||
if nxt is None:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"status": "completed", "next_run_at": None,
|
||||
"last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
counts["advanced"] += 1
|
||||
continue
|
||||
|
||||
# Overlap guard: never enqueue while a previous run is active.
|
||||
if runs_repo.has_active_run(str(schedule["id"])):
|
||||
runs_repo.record_skipped(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
error_type="overlap",
|
||||
error="previous run still active",
|
||||
)
|
||||
counts["skipped"] += 1
|
||||
if trigger_type == "recurring":
|
||||
nxt = _compute_next(schedule, after=scheduled_for)
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
else:
|
||||
# Once: null next_run_at so we don't re-pick; the in-flight
|
||||
# run will terminal-flip the schedule when it finishes.
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": None, "last_run_at": now},
|
||||
)
|
||||
continue
|
||||
|
||||
# Dedup primitive: two racing dispatchers see exactly one row.
|
||||
run = runs_repo.record_pending(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
trigger_source="cron",
|
||||
)
|
||||
if run is None:
|
||||
counts["skipped"] += 1
|
||||
else:
|
||||
enqueue_args.append(str(run["id"]))
|
||||
counts["enqueued"] += 1
|
||||
|
||||
# Advance: recurring picks next tick, once nulls next_run_at
|
||||
# (worker terminal-flips status on completion).
|
||||
if trigger_type == "recurring":
|
||||
nxt = _compute_next(schedule, after=scheduled_for)
|
||||
if nxt is None:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"status": "completed", "next_run_at": None,
|
||||
"last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": None, "last_run_at": now},
|
||||
)
|
||||
counts["advanced"] += 1
|
||||
|
||||
# Enqueue after commit so the worker sees the schedule_runs row on pick-up.
|
||||
for run_id in enqueue_args:
|
||||
try:
|
||||
execute_scheduled_run.apply_async(args=[run_id], queue="docsgpt")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"dispatcher: failed to enqueue execute_scheduled_run for %s",
|
||||
run_id,
|
||||
)
|
||||
return counts
|
||||
433
application/api/user/scheduler_worker.py
Normal file
433
application/api/user/scheduler_worker.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Body of ``execute_scheduled_run`` — runs a single agent execution.
|
||||
|
||||
Not a DURABLE_TASK: agent runs have side effects (messages, CRM writes)
|
||||
and blind auto-retry would double them. Failures after agent.gen starts
|
||||
are terminal and recorded; only the pre-start load is retry-safe.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.agents.headless_runner import run_agent_headless
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cap output verbatim in the run log; beyond the cap we keep the head and stamp output_truncated.
|
||||
_OUTPUT_CAP_CHARS = 24_000
|
||||
|
||||
|
||||
def _agent_config_for_schedule(schedule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve the agent row (agent-bound) or build an ephemeral classic config.
|
||||
|
||||
For agentless schedules (``agent_id IS NULL``), the worker constructs an
|
||||
in-memory agent shape carrying just enough fields for ``run_agent_headless``:
|
||||
classic agent type, system-default retriever/chunks/prompt, no source, and
|
||||
the optional ``model_id`` override. The runtime toolset is rebuilt by
|
||||
``ToolExecutor`` at fire time (current ``user_tools`` + non-disabled,
|
||||
non-headless-excluded defaults), so a snapshot here would be dead code.
|
||||
"""
|
||||
if schedule.get("agent_id"):
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
row = conn.execute(
|
||||
sql_text("SELECT * FROM agents WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": str(schedule["agent_id"])},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
return _ephemeral_agent_for_agentless(schedule)
|
||||
|
||||
|
||||
def _ephemeral_agent_for_agentless(
|
||||
schedule: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Build an agent-shaped config for a schedule with no parent agent."""
|
||||
# ``agent_config["tools"]`` is intentionally omitted: ``run_agent_headless``
|
||||
# never reads it. The runtime toolset is rebuilt by
|
||||
# ``ToolExecutor._get_user_tools(owner)`` at fire time — same dereference
|
||||
# the agent-bound path uses, so a tool added/disabled after creation is
|
||||
# reflected. Headless mode there filters chat-only tools (``scheduler``).
|
||||
user_id = schedule.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
return {
|
||||
"id": None,
|
||||
"user_id": user_id,
|
||||
"agent_type": "classic",
|
||||
"retriever": "classic",
|
||||
"chunks": 2,
|
||||
"prompt_id": "default",
|
||||
"source_id": None,
|
||||
"default_model_id": schedule.get("model_id") or "",
|
||||
}
|
||||
|
||||
|
||||
def _load_chat_history(schedule: Dict[str, Any]) -> list:
|
||||
"""Originating conversation history (one-time only; recurring has none)."""
|
||||
origin = schedule.get("origin_conversation_id")
|
||||
if not origin or schedule.get("trigger_type") != "once":
|
||||
return []
|
||||
user_id = schedule.get("user_id")
|
||||
if not user_id:
|
||||
return []
|
||||
try:
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
|
||||
if conv is None:
|
||||
return []
|
||||
messages = ConversationsRepository(conn).get_messages(str(conv["id"]))
|
||||
except Exception:
|
||||
logger.exception("scheduler: failed loading chat history")
|
||||
return []
|
||||
history: list = []
|
||||
for msg in messages:
|
||||
if msg.get("prompt") and msg.get("response"):
|
||||
history.append({
|
||||
"prompt": msg["prompt"],
|
||||
"response": msg["response"],
|
||||
})
|
||||
return history
|
||||
|
||||
|
||||
def _publish_run_event(
|
||||
event_type: str, run: Dict[str, Any], schedule: Dict[str, Any], **extra: Any
|
||||
) -> None:
|
||||
"""Best-effort SSE publish for a scheduler run state transition."""
|
||||
user_id = run.get("user_id") or schedule.get("user_id")
|
||||
if not user_id:
|
||||
return
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
payload = {
|
||||
"run_id": str(run["id"]),
|
||||
"schedule_id": str(schedule["id"]),
|
||||
"agent_id": str(agent_id_raw) if agent_id_raw else None,
|
||||
"trigger_type": schedule.get("trigger_type"),
|
||||
"status": run.get("status"),
|
||||
**extra,
|
||||
}
|
||||
try:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
event_type,
|
||||
payload,
|
||||
scope={"kind": "schedule", "id": str(schedule["id"])},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduler: SSE publish failed event=%s run=%s",
|
||||
event_type, run.get("id"),
|
||||
)
|
||||
|
||||
|
||||
def _publish_message_appended(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
message: Dict[str, Any],
|
||||
schedule_id: str,
|
||||
run_id: str,
|
||||
) -> None:
|
||||
"""SSE message-appended event for a one-time run's chat turn."""
|
||||
try:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
"schedule.message.appended",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": str(message["id"]),
|
||||
"schedule_id": str(schedule_id),
|
||||
"run_id": str(run_id),
|
||||
"position": int(message.get("position", 0)),
|
||||
},
|
||||
scope={"kind": "conversation", "id": str(conversation_id)},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduler: message.appended publish failed run=%s", run_id,
|
||||
)
|
||||
|
||||
|
||||
def _append_one_time_turn(
|
||||
schedule: Dict[str, Any],
|
||||
run: Dict[str, Any],
|
||||
outcome: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Insert an assistant turn in the originating conversation (once only)."""
|
||||
origin = schedule.get("origin_conversation_id")
|
||||
if not origin:
|
||||
return None
|
||||
engine = get_engine()
|
||||
user_id = schedule.get("user_id")
|
||||
metadata = {
|
||||
"scheduled": True,
|
||||
"schedule_id": str(schedule["id"]),
|
||||
"run_id": str(run["id"]),
|
||||
"scheduled_run_at": (
|
||||
run.get("scheduled_for")
|
||||
if isinstance(run.get("scheduled_for"), str)
|
||||
else None
|
||||
),
|
||||
}
|
||||
with engine.begin() as conn:
|
||||
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
|
||||
if conv is None:
|
||||
return None
|
||||
message = ConversationsRepository(conn).append_message(
|
||||
str(conv["id"]),
|
||||
{
|
||||
"prompt": schedule.get("instruction") or "",
|
||||
"response": outcome.get("answer") or "",
|
||||
"thought": outcome.get("thought") or "",
|
||||
"sources": outcome.get("sources") or [],
|
||||
"tool_calls": outcome.get("tool_calls") or [],
|
||||
"model_id": outcome.get("model_id"),
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def execute_scheduled_run_body(run_id: str, celery_task_id: Optional[str]) -> Dict[str, Any]:
|
||||
"""Execute one scheduled run by id; returns a result dict for tracing."""
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"status": "skipped", "reason": "POSTGRES_URI not set"}
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
run = ScheduleRunsRepository(conn).get_internal(run_id)
|
||||
if run is None:
|
||||
return {"status": "skipped", "reason": "run not found"}
|
||||
schedule = SchedulesRepository(conn).get_internal(str(run["schedule_id"]))
|
||||
if schedule is None:
|
||||
return {"status": "skipped", "reason": "schedule not found"}
|
||||
|
||||
# Refuse non-runnable terminal states; manual run-now bypasses.
|
||||
if run.get("status") != "pending":
|
||||
return {"status": "skipped", "reason": f"run status={run.get('status')}"}
|
||||
if schedule.get("status") in {"cancelled", "completed"} and run.get(
|
||||
"trigger_source"
|
||||
) != "manual":
|
||||
with engine.begin() as conn:
|
||||
ScheduleRunsRepository(conn).update(
|
||||
run_id,
|
||||
{
|
||||
"status": "skipped",
|
||||
"finished_at": datetime.now(timezone.utc),
|
||||
"error_type": "internal",
|
||||
"error": "schedule no longer active",
|
||||
},
|
||||
)
|
||||
return {"status": "skipped", "reason": "schedule terminal"}
|
||||
|
||||
agent_config = _agent_config_for_schedule(schedule)
|
||||
if agent_config is None:
|
||||
with engine.begin() as conn:
|
||||
updated = ScheduleRunsRepository(conn).update(
|
||||
run_id,
|
||||
{
|
||||
"status": "failed",
|
||||
"finished_at": datetime.now(timezone.utc),
|
||||
"error_type": "internal",
|
||||
"error": "agent missing",
|
||||
},
|
||||
)
|
||||
SchedulesRepository(conn).bump_failure_count(str(schedule["id"]))
|
||||
_publish_run_event("schedule.run.failed", updated or run, schedule,
|
||||
error="agent missing")
|
||||
return {"status": "failed", "reason": "agent missing"}
|
||||
|
||||
with engine.begin() as conn:
|
||||
if not ScheduleRunsRepository(conn).mark_running(run_id, celery_task_id):
|
||||
return {"status": "skipped", "reason": "lost race to mark_running"}
|
||||
|
||||
started = datetime.now(timezone.utc)
|
||||
instruction = schedule.get("instruction") or ""
|
||||
allowlist = schedule.get("tool_allowlist") or []
|
||||
chat_history = _load_chat_history(schedule)
|
||||
outcome: Dict[str, Any]
|
||||
error_type: Optional[str] = None
|
||||
error_text: Optional[str] = None
|
||||
timed_out = False
|
||||
try:
|
||||
outcome = run_agent_headless(
|
||||
agent_config,
|
||||
instruction,
|
||||
tool_allowlist=allowlist,
|
||||
model_id_override=schedule.get("model_id"),
|
||||
endpoint="schedule",
|
||||
conversation_id=schedule.get("origin_conversation_id"),
|
||||
chat_history=chat_history,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
timed_out = True
|
||||
outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
|
||||
error_type = "timeout"
|
||||
error_text = "run exceeded soft time limit"
|
||||
except Exception as exc:
|
||||
outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
|
||||
error_type = "agent_error"
|
||||
error_text = str(exc)
|
||||
logger.exception("scheduler: agent run failed run=%s", run_id)
|
||||
|
||||
finished = datetime.now(timezone.utc)
|
||||
|
||||
# Headless denial with no usable output → tool_not_allowed.
|
||||
if (
|
||||
error_type is None
|
||||
and (outcome.get("denied") or [])
|
||||
and not (outcome.get("answer") or "").strip()
|
||||
):
|
||||
error_type = "tool_not_allowed"
|
||||
error_text = "headless allowlist blocked required tool"
|
||||
|
||||
prompt_tokens = int(outcome.get("prompt_tokens", 0) or 0)
|
||||
generated_tokens = int(outcome.get("generated_tokens", 0) or 0)
|
||||
used_tokens = prompt_tokens + generated_tokens
|
||||
if (
|
||||
schedule.get("token_budget") is not None
|
||||
and int(schedule["token_budget"]) > 0
|
||||
and used_tokens > int(schedule["token_budget"])
|
||||
):
|
||||
error_type = "budget_exceeded"
|
||||
error_text = (
|
||||
f"used {used_tokens} tokens exceeds budget "
|
||||
f"{schedule['token_budget']}"
|
||||
)
|
||||
|
||||
answer = outcome.get("answer") or ""
|
||||
truncated = False
|
||||
if len(answer) > _OUTPUT_CAP_CHARS:
|
||||
answer = answer[:_OUTPUT_CAP_CHARS]
|
||||
truncated = True
|
||||
|
||||
new_status = (
|
||||
"timeout" if timed_out else ("failed" if error_type else "success")
|
||||
)
|
||||
|
||||
with engine.begin() as conn:
|
||||
update_fields: Dict[str, Any] = {
|
||||
"status": new_status,
|
||||
"started_at": started,
|
||||
"finished_at": finished,
|
||||
"output": answer or None,
|
||||
"output_truncated": truncated,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
}
|
||||
if error_type:
|
||||
update_fields["error_type"] = error_type
|
||||
update_fields["error"] = error_text
|
||||
updated_run = ScheduleRunsRepository(conn).update(run_id, update_fields)
|
||||
if used_tokens > 0:
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
try:
|
||||
TokenUsageRepository(conn).insert(
|
||||
user_id=schedule.get("user_id"),
|
||||
api_key=None,
|
||||
prompt_tokens=prompt_tokens,
|
||||
generated_tokens=generated_tokens,
|
||||
timestamp=finished,
|
||||
agent_id=str(agent_id_raw) if agent_id_raw else None,
|
||||
source="schedule",
|
||||
request_id=str(run_id),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduler: token_usage insert failed run=%s", run_id,
|
||||
)
|
||||
schedules_repo = SchedulesRepository(conn)
|
||||
autopaused = False
|
||||
if new_status == "success":
|
||||
schedules_repo.reset_failure_count(str(schedule["id"]))
|
||||
elif new_status in ("failed", "timeout"):
|
||||
count = schedules_repo.bump_failure_count(str(schedule["id"]))
|
||||
if (
|
||||
settings.SCHEDULE_AUTOPAUSE_FAILURES > 0
|
||||
and count >= settings.SCHEDULE_AUTOPAUSE_FAILURES
|
||||
and schedule.get("trigger_type") == "recurring"
|
||||
):
|
||||
autopaused = schedules_repo.autopause(str(schedule["id"]))
|
||||
# Once: terminal-flip on cron-fired runs only; manual runs on a
|
||||
# still-active once-schedule leave the future cadence intact.
|
||||
if (
|
||||
schedule.get("trigger_type") == "once"
|
||||
and run.get("trigger_source") != "manual"
|
||||
and schedule.get("status") == "active"
|
||||
):
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"status": "completed", "next_run_at": None},
|
||||
)
|
||||
|
||||
appended: Optional[Dict[str, Any]] = None
|
||||
if (
|
||||
schedule.get("trigger_type") == "once"
|
||||
and new_status == "success"
|
||||
and schedule.get("origin_conversation_id")
|
||||
):
|
||||
try:
|
||||
appended = _append_one_time_turn(schedule, updated_run or run, outcome)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduler: append turn failed run=%s", run_id,
|
||||
)
|
||||
if appended is not None:
|
||||
with engine.begin() as conn:
|
||||
ScheduleRunsRepository(conn).update(
|
||||
run_id,
|
||||
{
|
||||
"conversation_id": str(appended["conversation_id"]),
|
||||
"message_id": str(appended["id"]),
|
||||
},
|
||||
)
|
||||
_publish_message_appended(
|
||||
schedule.get("user_id"),
|
||||
str(appended["conversation_id"]),
|
||||
appended,
|
||||
str(schedule["id"]),
|
||||
run_id,
|
||||
)
|
||||
|
||||
if new_status == "success":
|
||||
_publish_run_event("schedule.run.completed", updated_run or run, schedule)
|
||||
else:
|
||||
_publish_run_event(
|
||||
"schedule.run.failed",
|
||||
updated_run or run,
|
||||
schedule,
|
||||
error_type=error_type,
|
||||
error=error_text,
|
||||
)
|
||||
|
||||
if autopaused:
|
||||
_publish_run_event(
|
||||
"schedule.autopaused",
|
||||
updated_run or run,
|
||||
schedule,
|
||||
consecutive_failure_count=settings.SCHEDULE_AUTOPAUSE_FAILURES,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": new_status,
|
||||
"run_id": run_id,
|
||||
"error_type": error_type,
|
||||
}
|
||||
5
application/api/user/schedules/__init__.py
Normal file
5
application/api/user/schedules/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Schedules module."""
|
||||
|
||||
from .routes import schedules_ns
|
||||
|
||||
__all__ = ["schedules_ns"]
|
||||
550
application/api/user/schedules/routes.py
Normal file
550
application/api/user/schedules/routes.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Schedules REST API (owner-scoped via request.decoded_token)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.agents.scheduler_utils import (
|
||||
ScheduleValidationError,
|
||||
clamp_once_horizon,
|
||||
cron_interval_seconds,
|
||||
next_cron_run,
|
||||
parse_cron,
|
||||
parse_run_at,
|
||||
resolve_timezone,
|
||||
)
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
schedules_ns = Namespace(
|
||||
"schedules", description="Agent schedule management", path="/api",
|
||||
)
|
||||
|
||||
|
||||
def _ok(data: Any, status: int = 200):
|
||||
return make_response(jsonify(data), status)
|
||||
|
||||
|
||||
def _err(message: str, status: int = 400):
|
||||
return make_response(jsonify({"success": False, "message": message}), status)
|
||||
|
||||
|
||||
def _format_schedule(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Render a schedule row for the API (id-as-string + ISO timestamps)."""
|
||||
if not row:
|
||||
return {}
|
||||
out = dict(row)
|
||||
for key in (
|
||||
"id", "agent_id", "origin_conversation_id",
|
||||
):
|
||||
if out.get(key) is not None:
|
||||
out[key] = str(out[key])
|
||||
out.pop("_id", None) # drop dual-id legacy mirror
|
||||
return out
|
||||
|
||||
|
||||
def _format_run(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Render a schedule_run row for the API."""
|
||||
if not row:
|
||||
return {}
|
||||
out = dict(row)
|
||||
for key in (
|
||||
"id", "schedule_id", "agent_id", "conversation_id", "message_id",
|
||||
):
|
||||
if out.get(key) is not None:
|
||||
out[key] = str(out[key])
|
||||
out.pop("_id", None)
|
||||
return out
|
||||
|
||||
|
||||
def _agent_owned(agent_id: str, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
if not looks_like_uuid(str(agent_id)):
|
||||
return None
|
||||
with db_readonly() as conn:
|
||||
return AgentsRepository(conn).get_any(agent_id, user_id)
|
||||
|
||||
|
||||
def _user_id() -> Optional[str]:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
if not decoded:
|
||||
return None
|
||||
return decoded.get("sub")
|
||||
|
||||
|
||||
@schedules_ns.route("/agents/<string:agent_id>/schedules")
|
||||
class AgentSchedules(Resource):
|
||||
@api.doc(description="List schedules for an agent (recurring + one-time).")
|
||||
def get(self, agent_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
agent = _agent_owned(agent_id, user_id)
|
||||
if agent is None:
|
||||
return _err("agent not found", 404)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
rows = SchedulesRepository(conn).list_for_agent(
|
||||
str(agent["id"]), user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
current_app.logger.error("list schedules failed: %s", exc, exc_info=True)
|
||||
return _err("internal error", 500)
|
||||
return _ok({"schedules": [_format_schedule(r) for r in rows]})
|
||||
|
||||
create_model = api.model(
|
||||
"ScheduleCreate",
|
||||
{
|
||||
"instruction": fields.String(required=True),
|
||||
"trigger_type": fields.String(
|
||||
required=False,
|
||||
description="'recurring' (default) or 'once'",
|
||||
),
|
||||
"cron": fields.String(
|
||||
required=False,
|
||||
description="Required when trigger_type == 'recurring'",
|
||||
),
|
||||
"run_at": fields.String(
|
||||
required=False,
|
||||
description="ISO 8601 — required when trigger_type == 'once'",
|
||||
),
|
||||
"timezone": fields.String(required=False),
|
||||
"name": fields.String(required=False),
|
||||
"end_at": fields.String(required=False, description="ISO 8601"),
|
||||
"tool_allowlist": fields.List(fields.String, required=False),
|
||||
"model_id": fields.String(required=False),
|
||||
"token_budget": fields.Integer(required=False),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_model)
|
||||
@api.doc(description="Create a schedule (recurring or one-time) for an agent.")
|
||||
def post(self, agent_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
agent = _agent_owned(agent_id, user_id)
|
||||
if agent is None:
|
||||
return _err("agent not found", 404)
|
||||
data = request.get_json(silent=True) or {}
|
||||
instruction = (data.get("instruction") or "").strip()
|
||||
tz_name = (data.get("timezone") or "UTC").strip() or "UTC"
|
||||
trigger_type = (data.get("trigger_type") or "recurring").strip().lower()
|
||||
if trigger_type not in ("recurring", "once"):
|
||||
return _err("trigger_type must be 'recurring' or 'once'")
|
||||
if not instruction:
|
||||
return _err("instruction is required")
|
||||
try:
|
||||
resolve_timezone(tz_name)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
token_budget = data.get("token_budget")
|
||||
if token_budget is not None:
|
||||
try:
|
||||
token_budget = int(token_budget)
|
||||
if token_budget < 0:
|
||||
raise ValueError
|
||||
except (TypeError, ValueError):
|
||||
return _err("token_budget must be a non-negative integer")
|
||||
with db_readonly() as conn:
|
||||
count = SchedulesRepository(conn).count_active_for_user(user_id)
|
||||
if (
|
||||
settings.SCHEDULE_MAX_PER_USER > 0
|
||||
and count >= settings.SCHEDULE_MAX_PER_USER
|
||||
):
|
||||
return _err("max schedules per user reached", 429)
|
||||
|
||||
if trigger_type == "once":
|
||||
run_at_raw = (data.get("run_at") or "").strip()
|
||||
if not run_at_raw:
|
||||
return _err("run_at is required for trigger_type 'once'")
|
||||
try:
|
||||
fire = parse_run_at(run_at_raw, tz_name)
|
||||
clamp_once_horizon(
|
||||
fire, settings.SCHEDULE_ONCE_MAX_HORIZON,
|
||||
)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
try:
|
||||
with db_session() as conn:
|
||||
created = SchedulesRepository(conn).create(
|
||||
user_id=user_id,
|
||||
agent_id=str(agent["id"]),
|
||||
trigger_type="once",
|
||||
instruction=instruction,
|
||||
run_at=fire,
|
||||
next_run_at=fire,
|
||||
timezone=tz_name,
|
||||
name=(data.get("name") or "").strip() or None,
|
||||
tool_allowlist=data.get("tool_allowlist") or [],
|
||||
model_id=(data.get("model_id") or None),
|
||||
token_budget=token_budget,
|
||||
created_via="ui",
|
||||
)
|
||||
except Exception as exc:
|
||||
current_app.logger.error(
|
||||
"create one-time schedule failed: %s", exc, exc_info=True,
|
||||
)
|
||||
return _err("internal error", 500)
|
||||
return _ok({"schedule": _format_schedule(created)}, status=201)
|
||||
|
||||
cron = (data.get("cron") or "").strip()
|
||||
if not cron:
|
||||
return _err("cron is required")
|
||||
try:
|
||||
parse_cron(cron)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
|
||||
if min_interval > 0:
|
||||
try:
|
||||
cadence = cron_interval_seconds(cron, tz_name)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
if cadence < min_interval:
|
||||
return _err(
|
||||
"cadence below minimum interval "
|
||||
f"({cadence}s < {min_interval}s)",
|
||||
)
|
||||
end_at = None
|
||||
if data.get("end_at"):
|
||||
try:
|
||||
end_at = datetime.fromisoformat(
|
||||
str(data["end_at"]).replace("Z", "+00:00"),
|
||||
)
|
||||
except ValueError:
|
||||
return _err("invalid end_at")
|
||||
try:
|
||||
next_run = next_cron_run(cron, tz_name, after=datetime.now(timezone.utc))
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
if end_at is not None and next_run > end_at:
|
||||
return _err("end_at is before the first cron tick")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
created = SchedulesRepository(conn).create(
|
||||
user_id=user_id,
|
||||
agent_id=str(agent["id"]),
|
||||
trigger_type="recurring",
|
||||
instruction=instruction,
|
||||
cron=cron,
|
||||
timezone=tz_name,
|
||||
next_run_at=next_run,
|
||||
end_at=end_at,
|
||||
name=(data.get("name") or "").strip() or None,
|
||||
tool_allowlist=data.get("tool_allowlist") or [],
|
||||
model_id=(data.get("model_id") or None),
|
||||
token_budget=token_budget,
|
||||
created_via="ui",
|
||||
)
|
||||
except Exception as exc:
|
||||
current_app.logger.error(
|
||||
"create schedule failed: %s", exc, exc_info=True,
|
||||
)
|
||||
return _err("internal error", 500)
|
||||
return _ok({"schedule": _format_schedule(created)}, status=201)
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>")
|
||||
class ScheduleResource(Resource):
|
||||
@api.doc(description="Get schedule by id.")
|
||||
def get(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
with db_readonly() as conn:
|
||||
row = SchedulesRepository(conn).get(schedule_id, user_id)
|
||||
if row is None:
|
||||
return _err("schedule not found", 404)
|
||||
return _ok({"schedule": _format_schedule(row)})
|
||||
|
||||
@api.doc(description="Edit a schedule's editable fields.")
|
||||
def put(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
data = request.get_json(silent=True) or {}
|
||||
fields_in: Dict[str, Any] = {}
|
||||
if "instruction" in data:
|
||||
inst = (data["instruction"] or "").strip()
|
||||
if not inst:
|
||||
return _err("instruction must not be empty")
|
||||
fields_in["instruction"] = inst
|
||||
if "cron" in data:
|
||||
cron = (data["cron"] or "").strip()
|
||||
try:
|
||||
parse_cron(cron)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
fields_in["cron"] = cron
|
||||
if "timezone" in data:
|
||||
tz_name = (data["timezone"] or "UTC").strip() or "UTC"
|
||||
try:
|
||||
resolve_timezone(tz_name)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
fields_in["timezone"] = tz_name
|
||||
if "tool_allowlist" in data:
|
||||
fields_in["tool_allowlist"] = data["tool_allowlist"] or []
|
||||
if "name" in data:
|
||||
fields_in["name"] = (data["name"] or "").strip() or None
|
||||
if "model_id" in data:
|
||||
fields_in["model_id"] = (data["model_id"] or None)
|
||||
if "token_budget" in data:
|
||||
tb = data["token_budget"]
|
||||
if tb is not None:
|
||||
try:
|
||||
tb = int(tb)
|
||||
if tb < 0:
|
||||
raise ValueError
|
||||
except (TypeError, ValueError):
|
||||
return _err("token_budget must be a non-negative integer")
|
||||
fields_in["token_budget"] = tb
|
||||
if "end_at" in data:
|
||||
if data["end_at"]:
|
||||
try:
|
||||
fields_in["end_at"] = datetime.fromisoformat(
|
||||
str(data["end_at"]).replace("Z", "+00:00"),
|
||||
)
|
||||
except ValueError:
|
||||
return _err("invalid end_at")
|
||||
else:
|
||||
fields_in["end_at"] = None
|
||||
# Recompute next_run_at when cron/tz changes.
|
||||
with db_session() as conn:
|
||||
existing = SchedulesRepository(conn).get(schedule_id, user_id)
|
||||
if existing is None:
|
||||
return _err("schedule not found", 404)
|
||||
if (
|
||||
("cron" in fields_in or "timezone" in fields_in)
|
||||
and existing.get("trigger_type") == "recurring"
|
||||
):
|
||||
cron_eff = fields_in.get("cron") or existing.get("cron")
|
||||
tz_eff = fields_in.get("timezone") or existing.get("timezone")
|
||||
if cron_eff:
|
||||
min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
|
||||
if min_interval > 0:
|
||||
try:
|
||||
cadence = cron_interval_seconds(cron_eff, tz_eff)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
if cadence < min_interval:
|
||||
return _err(
|
||||
"cadence below minimum interval "
|
||||
f"({cadence}s < {min_interval}s)",
|
||||
)
|
||||
try:
|
||||
fields_in["next_run_at"] = next_cron_run(
|
||||
cron_eff, tz_eff, after=datetime.now(timezone.utc),
|
||||
)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
updated = SchedulesRepository(conn).update(
|
||||
schedule_id, user_id, fields_in,
|
||||
)
|
||||
return _ok({"schedule": _format_schedule(updated or {})})
|
||||
|
||||
@api.doc(description="Pause / resume a schedule.")
|
||||
def patch(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
data = request.get_json(silent=True) or {}
|
||||
action = (data.get("action") or "").lower().strip()
|
||||
if action not in {"pause", "resume"}:
|
||||
return _err("action must be 'pause' or 'resume'")
|
||||
with db_session() as conn:
|
||||
existing = SchedulesRepository(conn).get(schedule_id, user_id)
|
||||
if existing is None:
|
||||
return _err("schedule not found", 404)
|
||||
if existing.get("status") in ("cancelled", "completed"):
|
||||
return _err("schedule is terminal", 409)
|
||||
if action == "pause":
|
||||
fields_in: Dict[str, Any] = {"status": "paused", "next_run_at": None}
|
||||
else:
|
||||
# Resume: recurring recomputes from now; once honours run_at if still future.
|
||||
fields_in = {"status": "active"}
|
||||
if existing.get("trigger_type") == "recurring":
|
||||
try:
|
||||
fields_in["next_run_at"] = next_cron_run(
|
||||
existing["cron"],
|
||||
existing["timezone"],
|
||||
after=datetime.now(timezone.utc),
|
||||
)
|
||||
except ScheduleValidationError as exc:
|
||||
return _err(str(exc))
|
||||
else:
|
||||
new_run_at = data.get("run_at")
|
||||
if new_run_at:
|
||||
try:
|
||||
run_at_dt = datetime.fromisoformat(
|
||||
str(new_run_at).replace("Z", "+00:00"),
|
||||
)
|
||||
except ValueError:
|
||||
return _err("invalid run_at")
|
||||
if run_at_dt <= datetime.now(timezone.utc):
|
||||
return _err(
|
||||
"run_at must be in the future to resume", 409,
|
||||
)
|
||||
fields_in["next_run_at"] = run_at_dt
|
||||
fields_in["run_at"] = run_at_dt
|
||||
else:
|
||||
run_at = existing.get("run_at")
|
||||
if run_at:
|
||||
if isinstance(run_at, str):
|
||||
try:
|
||||
run_at_dt = datetime.fromisoformat(
|
||||
run_at.replace("Z", "+00:00"),
|
||||
)
|
||||
except ValueError:
|
||||
return _err("schedule run_at is invalid")
|
||||
else:
|
||||
run_at_dt = run_at
|
||||
if run_at_dt <= datetime.now(timezone.utc):
|
||||
return _err(
|
||||
"the once schedule has elapsed; recreate "
|
||||
"it or supply a new run_at",
|
||||
409,
|
||||
)
|
||||
fields_in["next_run_at"] = run_at_dt
|
||||
updated = SchedulesRepository(conn).update(
|
||||
schedule_id, user_id, fields_in,
|
||||
)
|
||||
if action == "resume":
|
||||
SchedulesRepository(conn).reset_failure_count(schedule_id)
|
||||
return _ok({"schedule": _format_schedule(updated or {})})
|
||||
|
||||
@api.doc(description="Cancel / delete a schedule.")
|
||||
def delete(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
with db_session() as conn:
|
||||
ok = SchedulesRepository(conn).delete(schedule_id, user_id)
|
||||
if not ok:
|
||||
return _err("schedule not found", 404)
|
||||
return _ok({"success": True})
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/run")
|
||||
class ScheduleRunNow(Resource):
|
||||
@api.doc(description="Run a schedule immediately (trigger_source='manual').")
|
||||
def post(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
# FOR UPDATE serializes concurrent Run-Now POSTs (timestamp-unique
|
||||
# scheduled_for values would otherwise sneak past the unique index).
|
||||
with db_session() as conn:
|
||||
schedule = SchedulesRepository(conn).get_for_update(
|
||||
schedule_id, user_id,
|
||||
)
|
||||
if schedule is None:
|
||||
return _err("schedule not found", 404)
|
||||
if schedule.get("status") == "cancelled":
|
||||
return _err("schedule is cancelled", 409)
|
||||
if ScheduleRunsRepository(conn).has_active_run(schedule_id):
|
||||
return _err("a run is already in flight", 409)
|
||||
scheduled_for = datetime.now(timezone.utc)
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
schedule_id,
|
||||
user_id,
|
||||
str(agent_id_raw) if agent_id_raw else None,
|
||||
scheduled_for,
|
||||
trigger_source="manual",
|
||||
)
|
||||
if run is None:
|
||||
return _err("could not claim run (concurrent dispatch)", 409)
|
||||
# Import inside the handler to avoid a circular tasks <-> routes import.
|
||||
try:
|
||||
from application.api.user.tasks import execute_scheduled_run
|
||||
execute_scheduled_run.apply_async(args=[str(run["id"])], queue="docsgpt")
|
||||
except Exception as exc:
|
||||
current_app.logger.error(
|
||||
"run-now enqueue failed: %s", exc, exc_info=True,
|
||||
)
|
||||
return _err("enqueue failed", 500)
|
||||
return _ok({"run": _format_run(run)}, status=202)
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/runs")
|
||||
class ScheduleRunList(Resource):
|
||||
@api.doc(
|
||||
description="Paginated run log for a schedule.",
|
||||
params={"limit": "Page size (default 50)", "offset": "Page offset"},
|
||||
)
|
||||
def get(self, schedule_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id):
|
||||
return _err("invalid schedule id", 400)
|
||||
try:
|
||||
limit = max(1, min(int(request.args.get("limit", 50)), 200))
|
||||
except (TypeError, ValueError):
|
||||
limit = 50
|
||||
try:
|
||||
offset = max(0, int(request.args.get("offset", 0)))
|
||||
except (TypeError, ValueError):
|
||||
offset = 0
|
||||
with db_readonly() as conn:
|
||||
schedule = SchedulesRepository(conn).get(schedule_id, user_id)
|
||||
if schedule is None:
|
||||
return _err("schedule not found", 404)
|
||||
rows = ScheduleRunsRepository(conn).list_runs(
|
||||
schedule_id, user_id, limit=limit, offset=offset,
|
||||
)
|
||||
return _ok(
|
||||
{
|
||||
"runs": [_format_run(r) for r in rows],
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/runs/<string:run_id>")
|
||||
class ScheduleRunDetail(Resource):
|
||||
@api.doc(description="Full output / error for a single run.")
|
||||
def get(self, schedule_id, run_id):
|
||||
user_id = _user_id()
|
||||
if not user_id:
|
||||
return _err("unauthorized", 401)
|
||||
if not looks_like_uuid(schedule_id) or not looks_like_uuid(run_id):
|
||||
return _err("invalid id", 400)
|
||||
with db_readonly() as conn:
|
||||
schedule = SchedulesRepository(conn).get(schedule_id, user_id)
|
||||
if schedule is None:
|
||||
return _err("schedule not found", 404)
|
||||
run = ScheduleRunsRepository(conn).get(run_id, user_id)
|
||||
if run is None or str(run.get("schedule_id")) != str(
|
||||
schedule["id"]
|
||||
):
|
||||
return _err("run not found", 404)
|
||||
return _ok({"run": _format_run(run)})
|
||||
@@ -204,8 +204,64 @@ def ingest_connector_task(
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def dispatch_scheduled_runs(self):
|
||||
"""Beat-driven scheduler poller (body in scheduler_dispatcher)."""
|
||||
from application.api.user.scheduler_dispatcher import dispatch_due_runs
|
||||
|
||||
return dispatch_due_runs()
|
||||
|
||||
|
||||
@celery.task(
|
||||
bind=True,
|
||||
acks_late=True,
|
||||
# Not DURABLE_TASK: agent runs have side effects; blind retry would double them.
|
||||
autoretry_for=(),
|
||||
max_retries=0,
|
||||
)
|
||||
def execute_scheduled_run(self, run_id):
|
||||
"""Execute one scheduled run; soft-time-limit honors SCHEDULE_RUN_TIMEOUT."""
|
||||
from application.api.user.scheduler_worker import execute_scheduled_run_body
|
||||
|
||||
return execute_scheduled_run_body(run_id, getattr(self.request, "id", None))
|
||||
|
||||
|
||||
# Bind runtime soft-time-limit so the prefork worker can raise mid-agent.
|
||||
try:
|
||||
from application.core.settings import settings as _scheduler_settings
|
||||
execute_scheduled_run.soft_time_limit = max(
|
||||
30, int(_scheduler_settings.SCHEDULE_RUN_TIMEOUT),
|
||||
)
|
||||
execute_scheduled_run.time_limit = (
|
||||
execute_scheduled_run.soft_time_limit + 60
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_schedule_runs(self):
|
||||
"""Trim ``schedule_runs`` per ``SCHEDULE_RUN_OUTPUT_RETENTION_DAYS``."""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
|
||||
ttl_days = settings.SCHEDULE_RUN_OUTPUT_RETENTION_DAYS
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = ScheduleRunsRepository(conn).cleanup_older_than(ttl_days)
|
||||
return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.on_after_configure.connect
|
||||
def setup_periodic_tasks(sender, **kwargs):
|
||||
from application.core.settings import settings
|
||||
|
||||
sender.add_periodic_task(
|
||||
timedelta(days=1),
|
||||
schedule_syncs.s("daily"),
|
||||
@@ -251,6 +307,22 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_orphan_memories.s(),
|
||||
name="cleanup-orphan-memories",
|
||||
)
|
||||
# Scheduler dispatcher and run-log trim.
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=max(15, settings.SCHEDULE_DISPATCHER_INTERVAL)),
|
||||
dispatch_scheduled_runs.s(),
|
||||
name="dispatch-scheduled-runs",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_schedule_runs.s(),
|
||||
name="cleanup-schedule-runs",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -339,6 +411,29 @@ def cleanup_message_events(self):
|
||||
return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_orphan_memories(self):
|
||||
"""Sweep orphan memories left by the 0009 FK-to-trigger orphan window.
|
||||
|
||||
A ``memories`` INSERT for a real ``tool_id`` racing a ``user_tools``
|
||||
DELETE leaves a permanent orphan the dropped FK would have rejected.
|
||||
Default-tool synthetic ids are preserved (legitimate built-in data).
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.agents.default_tools import default_tool_ids
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.memories import MemoriesRepository
|
||||
|
||||
keep_tool_ids = list(default_tool_ids().values())
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = MemoriesRepository(conn).delete_orphans(keep_tool_ids)
|
||||
return {"deleted": deleted}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
@@ -3,6 +3,15 @@
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.default_tools import (
|
||||
builtin_agent_tools_for_management,
|
||||
BUILTIN_AGENT_TOOLS,
|
||||
default_tool_name_for_id,
|
||||
default_tools_for_management,
|
||||
is_builtin_agent_tool_id,
|
||||
is_default_tool_id,
|
||||
is_synthesized_tool_id,
|
||||
)
|
||||
from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
@@ -11,6 +20,7 @@ from application.security.encryption import decrypt_credentials, encrypt_credent
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
@@ -208,6 +218,7 @@ class GetTools(Resource):
|
||||
user = decoded_token.get("sub")
|
||||
with db_readonly() as conn:
|
||||
rows = UserToolsRepository(conn).list_for_user(user)
|
||||
user_doc = UsersRepository(conn).get(user)
|
||||
user_tools = []
|
||||
for row in rows:
|
||||
tool_copy = _row_to_api(row)
|
||||
@@ -227,6 +238,29 @@ class GetTools(Resource):
|
||||
tool_copy["config"].pop("encrypted_credentials", None)
|
||||
|
||||
user_tools.append(tool_copy)
|
||||
|
||||
# ``scheduler`` is dual-registered (default chat tool + agent-
|
||||
# selectable builtin) and resolves to the same synthetic uuid5 id.
|
||||
# Surface a single row with both flags so the frontend can show it
|
||||
# in the management page (toggle) and the agent picker.
|
||||
seen_ids: set = set()
|
||||
for default_row in default_tools_for_management(user_doc):
|
||||
default_copy = _row_to_api(default_row)
|
||||
default_copy["default"] = True
|
||||
if default_copy.get("name") in BUILTIN_AGENT_TOOLS:
|
||||
default_copy["builtin"] = True
|
||||
seen_ids.add(str(default_copy["id"]))
|
||||
user_tools.append(default_copy)
|
||||
# Builtins (e.g. scheduler) hidden from Add-Tool catalog, visible
|
||||
# to the agent picker. Skip ones already added via the default
|
||||
# path — both registries share ``_DEFAULT_TOOL_NAMESPACE``.
|
||||
for builtin_row in builtin_agent_tools_for_management():
|
||||
builtin_copy = _row_to_api(builtin_row)
|
||||
if str(builtin_copy["id"]) in seen_ids:
|
||||
continue
|
||||
builtin_copy["builtin"] = True
|
||||
builtin_copy["default"] = False
|
||||
user_tools.append(builtin_copy)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -367,6 +401,46 @@ class UpdateTool(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
# Default-tool branch first: a dual-registered tool (e.g. ``scheduler``)
|
||||
# matches BOTH ``is_default_tool_id`` and ``is_builtin_agent_tool_id``.
|
||||
# The toggle in Tools settings is the per-user opt-out for the
|
||||
# agentless default — it must reach the ``set_default_tool_enabled``
|
||||
# path, not the builtin "not editable" reject.
|
||||
if is_default_tool_id(data["id"]):
|
||||
if "status" not in data:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default tools are not editable; "
|
||||
"only their on/off status can be changed.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_name = default_tool_name_for_id(data["id"])
|
||||
try:
|
||||
with db_session() as conn:
|
||||
UsersRepository(conn).set_default_tool_enabled(
|
||||
user, tool_name, bool(data["status"])
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating default tool: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
if is_builtin_agent_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in agent tools are not editable; "
|
||||
"add them to an agent via the agent picker.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
update_data: dict = {}
|
||||
for key in ("name", "displayName", "customName", "description", "actions"):
|
||||
@@ -471,6 +545,17 @@ class UpdateToolConfig(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default and built-in tools are config-free "
|
||||
"and cannot be configured.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
@@ -550,6 +635,16 @@ class UpdateToolActions(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default and built-in tools' actions are not editable.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
@@ -595,6 +690,27 @@ class UpdateToolStatus(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
# Default branch first so a dual-registered id (e.g. ``scheduler``)
|
||||
# writes the per-user opt-out instead of being rejected as a
|
||||
# not-editable builtin (both predicates match the same uuid5).
|
||||
if is_default_tool_id(data["id"]):
|
||||
tool_name = default_tool_name_for_id(data["id"])
|
||||
with db_session() as conn:
|
||||
UsersRepository(conn).set_default_tool_enabled(
|
||||
user, tool_name, bool(data["status"])
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
if is_builtin_agent_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in agent tools have no per-user "
|
||||
"toggle; add them to an agent via the agent picker.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
@@ -633,6 +749,16 @@ class DeleteTool(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in tools cannot be deleted; disable them instead.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
|
||||
@@ -48,6 +48,12 @@ ensure_database_ready(
|
||||
logger=logging.getLogger("application.app"),
|
||||
)
|
||||
|
||||
from application.agents.default_tools import ( # noqa: E402
|
||||
validate_default_chat_tools,
|
||||
)
|
||||
|
||||
validate_default_chat_tools()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
|
||||
@@ -189,6 +189,11 @@ class Settings(BaseSettings):
|
||||
# Tool pre-fetch settings
|
||||
ENABLE_TOOL_PREFETCH: bool = True
|
||||
|
||||
# Config-free tools on by default in agentless chats. ``scheduler`` is
|
||||
# dual-registered (also in ``BUILTIN_AGENT_TOOLS``) so the same synthetic id
|
||||
# resolves whether reached via defaults or the agent picker.
|
||||
DEFAULT_CHAT_TOOLS: list = ["memory", "read_webpage", "scheduler"]
|
||||
|
||||
# Conversation Compression Settings
|
||||
ENABLE_CONVERSATION_COMPRESSION: bool = True
|
||||
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
|
||||
@@ -232,6 +237,16 @@ class Settings(BaseSettings):
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
|
||||
# Scheduler (see scheduler.md).
|
||||
SCHEDULE_DISPATCHER_INTERVAL: int = 30
|
||||
SCHEDULE_MIN_INTERVAL: int = 900
|
||||
SCHEDULE_MAX_PER_USER: int = 50
|
||||
SCHEDULE_RUN_TIMEOUT: int = 600
|
||||
SCHEDULE_MISFIRE_GRACE: int = 60
|
||||
SCHEDULE_AUTOPAUSE_FAILURES: int = 3
|
||||
SCHEDULE_ONCE_MAX_HORIZON: int = 31_536_000
|
||||
SCHEDULE_RUN_OUTPUT_RETENTION_DAYS: int = 90
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
@@ -850,6 +850,79 @@ class LLMHandler(ABC):
|
||||
tools_dict, call, llm_class
|
||||
)
|
||||
if pause_info:
|
||||
# Headless (scheduled / webhook): synthesize a denial tool message
|
||||
# so the LLM finishes gracefully instead of stalling on a pause
|
||||
# nobody will resolve, then journal so the reconciler sees it.
|
||||
if pause_info.get("pause_type") == "headless_denied":
|
||||
deny_reason = pause_info.get(
|
||||
"deny_reason", "Tool blocked in headless mode."
|
||||
)
|
||||
args_str = (
|
||||
json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else (call.arguments or "{}")
|
||||
)
|
||||
tool_call_obj = {
|
||||
"id": pause_info["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.name,
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
if getattr(call, "thought_signature", None):
|
||||
tool_call_obj["thought_signature"] = call.thought_signature
|
||||
updated_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call_obj],
|
||||
})
|
||||
denial_call = ToolCall(
|
||||
id=pause_info["call_id"],
|
||||
name=call.name,
|
||||
arguments=call.arguments,
|
||||
)
|
||||
updated_messages.append(
|
||||
self.create_tool_message(
|
||||
denial_call,
|
||||
f"Tool denied (headless): {deny_reason}",
|
||||
)
|
||||
)
|
||||
if hasattr(agent.tool_executor, "headless_denials"):
|
||||
agent.tool_executor.headless_denials.append(pause_info)
|
||||
from application.agents.tool_executor import (
|
||||
_mark_failed,
|
||||
_record_proposed,
|
||||
)
|
||||
|
||||
_record_proposed(
|
||||
pause_info["call_id"],
|
||||
pause_info["tool_name"],
|
||||
pause_info["action_name"],
|
||||
pause_info.get("arguments") or {},
|
||||
tool_id=pause_info.get("tool_id"),
|
||||
)
|
||||
_mark_failed(
|
||||
pause_info["call_id"],
|
||||
f"headless: {deny_reason}",
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pause_info["tool_name"],
|
||||
"call_id": pause_info["call_id"],
|
||||
"action_name": pause_info.get(
|
||||
"llm_name", pause_info["name"]
|
||||
),
|
||||
"arguments": pause_info["arguments"],
|
||||
"status": "denied",
|
||||
"error": deny_reason,
|
||||
"error_type": pause_info.get(
|
||||
"error_type", "tool_not_allowed"
|
||||
),
|
||||
},
|
||||
}
|
||||
continue
|
||||
# Yield pause event so the client knows this tool is waiting
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
|
||||
@@ -7,6 +7,7 @@ beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.3
|
||||
celery-redbeat==2.3.3
|
||||
croniter==6.2.2
|
||||
cryptography==46.0.7
|
||||
dataclasses-json==0.6.7
|
||||
defusedxml==0.7.1
|
||||
|
||||
@@ -47,6 +47,7 @@ users_table = Table(
|
||||
nullable=False,
|
||||
server_default='{"pinned": [], "shared_with_me": []}',
|
||||
),
|
||||
Column("tool_preferences", JSONB, nullable=False, server_default="{}"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
@@ -254,7 +255,8 @@ memories_table = Table(
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
# No FK since 0009 — delete-cascade preserved by trigger.
|
||||
Column("tool_id", UUID(as_uuid=True)),
|
||||
Column("path", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
@@ -598,3 +600,74 @@ workflow_runs_table = Table(
|
||||
Column("ended_at", DateTime(timezone=True)),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
|
||||
# --- Scheduler (migration 0010) ---------------------------------------------
|
||||
|
||||
schedules_table = Table(
|
||||
"schedules",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
# Nullable as of 0011: agentless chats create one-time schedules whose
|
||||
# run is built ephemerally at fire time from system defaults.
|
||||
Column(
|
||||
"agent_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("agents.id", ondelete="CASCADE"),
|
||||
),
|
||||
Column("trigger_type", Text, nullable=False),
|
||||
Column("name", Text),
|
||||
Column("instruction", Text, nullable=False),
|
||||
Column("status", Text, nullable=False, server_default="active"),
|
||||
Column("cron", Text),
|
||||
Column("run_at", DateTime(timezone=True)),
|
||||
Column("timezone", Text, nullable=False, server_default="UTC"),
|
||||
Column("next_run_at", DateTime(timezone=True)),
|
||||
Column("last_run_at", DateTime(timezone=True)),
|
||||
Column("end_at", DateTime(timezone=True)),
|
||||
Column("tool_allowlist", JSONB, nullable=False, server_default="[]"),
|
||||
Column("model_id", Text),
|
||||
Column("token_budget", Integer),
|
||||
Column(
|
||||
"origin_conversation_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversations.id", ondelete="SET NULL"),
|
||||
),
|
||||
Column("created_via", Text, nullable=False, server_default="ui"),
|
||||
Column("consecutive_failure_count", Integer, nullable=False, server_default="0"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
schedule_runs_table = Table(
|
||||
"schedule_runs",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column(
|
||||
"schedule_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("schedules.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
Column("user_id", Text, nullable=False),
|
||||
# Nullable as of 0011 (mirrors ``schedules.agent_id``).
|
||||
Column("agent_id", UUID(as_uuid=True)),
|
||||
Column("status", Text, nullable=False, server_default="pending"),
|
||||
Column("scheduled_for", DateTime(timezone=True), nullable=False),
|
||||
Column("trigger_source", Text, nullable=False, server_default="cron"),
|
||||
Column("started_at", DateTime(timezone=True)),
|
||||
Column("finished_at", DateTime(timezone=True)),
|
||||
Column("output", Text),
|
||||
Column("output_truncated", Boolean, nullable=False, server_default="false"),
|
||||
Column("error", Text),
|
||||
Column("error_type", Text),
|
||||
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("conversation_id", UUID(as_uuid=True)),
|
||||
Column("message_id", UUID(as_uuid=True)),
|
||||
Column("celery_task_id", Text),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("schedule_id", "scheduled_for", name="schedule_runs_dedup_uidx"),
|
||||
)
|
||||
|
||||
@@ -86,6 +86,22 @@ class MemoriesRepository:
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_orphans(self, keep_tool_ids: Optional[list[str]] = None) -> int:
|
||||
"""Delete memories whose tool_id has no user_tools row, except keep_tool_ids."""
|
||||
keep = [str(tid) for tid in (keep_tool_ids or [])]
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM memories
|
||||
WHERE tool_id IS NOT NULL
|
||||
AND tool_id NOT IN (SELECT id FROM user_tools)
|
||||
AND NOT (tool_id = ANY(CAST(:keep AS uuid[])))
|
||||
"""
|
||||
),
|
||||
{"keep": keep},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
|
||||
278
application/storage/db/repositories/schedule_runs.py
Normal file
278
application/storage/db/repositories/schedule_runs.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Repository for ``schedule_runs`` (record_pending is the dedup primitive)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_ALLOWED_UPDATES = frozenset(
|
||||
{
|
||||
"status", "started_at", "finished_at", "output", "output_truncated",
|
||||
"error", "error_type", "prompt_tokens", "generated_tokens",
|
||||
"conversation_id", "message_id", "celery_task_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScheduleRunsRepository:
|
||||
"""CRUD + dedup insert + reconciliation sweep for ``schedule_runs``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def record_pending(
|
||||
self,
|
||||
schedule_id: str,
|
||||
user_id: str,
|
||||
agent_id: Optional[str],
|
||||
scheduled_for: datetime,
|
||||
*,
|
||||
trigger_source: str = "cron",
|
||||
) -> Optional[dict]:
|
||||
"""Insert a ``pending`` row; ``None`` on conflict (already claimed)."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO schedule_runs (
|
||||
schedule_id, user_id, agent_id, scheduled_for,
|
||||
trigger_source, status
|
||||
) VALUES (
|
||||
CAST(:schedule_id AS uuid),
|
||||
:user_id,
|
||||
CAST(:agent_id AS uuid),
|
||||
:scheduled_for,
|
||||
:trigger_source,
|
||||
'pending'
|
||||
)
|
||||
ON CONFLICT (schedule_id, scheduled_for) DO NOTHING
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"schedule_id": str(schedule_id),
|
||||
"user_id": user_id,
|
||||
"agent_id": str(agent_id) if agent_id else None,
|
||||
"scheduled_for": scheduled_for,
|
||||
"trigger_source": trigger_source,
|
||||
},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def record_skipped(
|
||||
self,
|
||||
schedule_id: str,
|
||||
user_id: str,
|
||||
agent_id: Optional[str],
|
||||
scheduled_for: datetime,
|
||||
*,
|
||||
error_type: str,
|
||||
error: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Write a terminal ``skipped`` row; returns ``None`` on conflict."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO schedule_runs (
|
||||
schedule_id, user_id, agent_id, scheduled_for,
|
||||
trigger_source, status, started_at, finished_at,
|
||||
error, error_type
|
||||
) VALUES (
|
||||
CAST(:schedule_id AS uuid),
|
||||
:user_id,
|
||||
CAST(:agent_id AS uuid),
|
||||
:scheduled_for,
|
||||
'cron',
|
||||
'skipped',
|
||||
now(),
|
||||
now(),
|
||||
:error,
|
||||
:error_type
|
||||
)
|
||||
ON CONFLICT (schedule_id, scheduled_for) DO NOTHING
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"schedule_id": str(schedule_id),
|
||||
"user_id": user_id,
|
||||
"agent_id": str(agent_id) if agent_id else None,
|
||||
"scheduled_for": scheduled_for,
|
||||
"error": error,
|
||||
"error_type": error_type,
|
||||
},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get(self, run_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch an owned run row."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedule_runs "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": str(run_id), "user_id": user_id},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_internal(self, run_id: str) -> Optional[dict]:
|
||||
"""Fetch a run row with no ownership scoping (worker-only)."""
|
||||
row = self._conn.execute(
|
||||
text("SELECT * FROM schedule_runs WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": str(run_id)},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def has_active_run(self, schedule_id: str) -> bool:
|
||||
"""True iff a ``pending``/``running`` run exists for the schedule."""
|
||||
scalar = self._conn.execute(
|
||||
text(
|
||||
"SELECT 1 FROM schedule_runs "
|
||||
"WHERE schedule_id = CAST(:id AS uuid) "
|
||||
"AND status IN ('pending', 'running') "
|
||||
"LIMIT 1"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
).first()
|
||||
return scalar is not None
|
||||
|
||||
def list_runs(
|
||||
self,
|
||||
schedule_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict]:
|
||||
"""Paginated newest-first run log for an owned schedule."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT * FROM schedule_runs
|
||||
WHERE schedule_id = CAST(:id AS uuid) AND user_id = :user_id
|
||||
ORDER BY scheduled_for DESC
|
||||
LIMIT :limit OFFSET :offset
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": str(schedule_id),
|
||||
"user_id": user_id,
|
||||
"limit": int(limit),
|
||||
"offset": int(offset),
|
||||
},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def update(self, run_id: str, fields: dict) -> Optional[dict]:
|
||||
"""Apply a whitelisted partial update to a run row."""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
|
||||
if not filtered:
|
||||
return self.get_internal(run_id)
|
||||
set_parts: list[str] = []
|
||||
params: dict[str, Any] = {"id": str(run_id)}
|
||||
for key, val in filtered.items():
|
||||
if key in ("conversation_id", "message_id"):
|
||||
set_parts.append(f"{key} = CAST(:{key} AS uuid)")
|
||||
params[key] = str(val) if val else None
|
||||
else:
|
||||
set_parts.append(f"{key} = :{key}")
|
||||
params[key] = val
|
||||
sql = (
|
||||
"UPDATE schedule_runs SET " + ", ".join(set_parts) +
|
||||
" WHERE id = CAST(:id AS uuid) RETURNING *"
|
||||
)
|
||||
row = self._conn.execute(text(sql), params).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def mark_running(self, run_id: str, celery_task_id: Optional[str]) -> bool:
|
||||
"""Flip ``pending`` → ``running`` and stamp ``started_at``."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE schedule_runs
|
||||
SET status = 'running',
|
||||
started_at = now(),
|
||||
celery_task_id = :celery_task_id
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
AND status = 'pending'
|
||||
"""
|
||||
),
|
||||
{"id": str(run_id), "celery_task_id": celery_task_id},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
def list_stuck_running(
|
||||
self, *, age_minutes: int = 15, limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""Lock ``running`` rows past the soft-time-limit envelope."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT * FROM schedule_runs
|
||||
WHERE status = 'running'
|
||||
AND started_at IS NOT NULL
|
||||
AND started_at < now() - make_interval(mins => :age)
|
||||
ORDER BY started_at ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
"""
|
||||
),
|
||||
{"age": int(age_minutes), "limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def list_stuck_pending(
|
||||
self, *, age_minutes: int = 15, limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""Lock 'pending' rows whose worker never picked them up (created_at-based)."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT * FROM schedule_runs
|
||||
WHERE status = 'pending'
|
||||
AND started_at IS NULL
|
||||
AND created_at < now() - make_interval(mins => :age)
|
||||
ORDER BY created_at ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
"""
|
||||
),
|
||||
{"age": int(age_minutes), "limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def cleanup_older_than(
|
||||
self,
|
||||
ttl_days: int,
|
||||
*,
|
||||
keep_recent_per_schedule: int = 50,
|
||||
) -> int:
|
||||
"""Trim run rows older than ``ttl_days``, keeping the recent log slice."""
|
||||
if ttl_days <= 0:
|
||||
raise ValueError("ttl_days must be positive")
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM schedule_runs
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY schedule_id
|
||||
ORDER BY scheduled_for DESC
|
||||
) AS rn,
|
||||
created_at
|
||||
FROM schedule_runs
|
||||
) ranked
|
||||
WHERE ranked.rn > :keep
|
||||
AND ranked.created_at < now() - make_interval(days => :ttl)
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"keep": int(keep_recent_per_schedule), "ttl": int(ttl_days)},
|
||||
)
|
||||
return int(result.rowcount or 0)
|
||||
352
application/storage/db/repositories/schedules.py
Normal file
352
application/storage/db/repositories/schedules.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Repository for the ``schedules`` table (CRUD + dispatcher claim query)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_ALLOWED_UPDATES = frozenset(
|
||||
{
|
||||
"name", "instruction", "status", "cron", "run_at", "timezone",
|
||||
"next_run_at", "last_run_at", "end_at", "tool_allowlist",
|
||||
"model_id", "token_budget", "consecutive_failure_count",
|
||||
"origin_conversation_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SchedulesRepository:
|
||||
"""CRUD + dispatcher hot path for ``schedules``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: Optional[str],
|
||||
trigger_type: str,
|
||||
instruction: str,
|
||||
*,
|
||||
cron: Optional[str] = None,
|
||||
run_at: Optional[datetime] = None,
|
||||
timezone: str = "UTC",
|
||||
next_run_at: Optional[datetime] = None,
|
||||
end_at: Optional[datetime] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_allowlist: Optional[Iterable[str]] = None,
|
||||
model_id: Optional[str] = None,
|
||||
token_budget: Optional[int] = None,
|
||||
origin_conversation_id: Optional[str] = None,
|
||||
created_via: str = "ui",
|
||||
status: str = "active",
|
||||
) -> dict:
|
||||
"""Insert a new schedule and return the populated row."""
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"agent_id": str(agent_id) if agent_id else None,
|
||||
"trigger_type": trigger_type,
|
||||
"instruction": instruction,
|
||||
"cron": cron,
|
||||
"run_at": run_at,
|
||||
"tz": timezone,
|
||||
"next_run_at": next_run_at,
|
||||
"end_at": end_at,
|
||||
"name": name,
|
||||
"allowlist": json.dumps(list(tool_allowlist or [])),
|
||||
"model_id": model_id,
|
||||
"token_budget": int(token_budget) if token_budget is not None else None,
|
||||
"origin_conversation_id": (
|
||||
str(origin_conversation_id) if origin_conversation_id else None
|
||||
),
|
||||
"created_via": created_via,
|
||||
"status": status,
|
||||
}
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO schedules (
|
||||
user_id, agent_id, trigger_type, instruction, status,
|
||||
cron, run_at, timezone, next_run_at, end_at, name,
|
||||
tool_allowlist, model_id, token_budget,
|
||||
origin_conversation_id, created_via
|
||||
) VALUES (
|
||||
:user_id,
|
||||
CAST(:agent_id AS uuid),
|
||||
:trigger_type,
|
||||
:instruction,
|
||||
:status,
|
||||
:cron,
|
||||
:run_at,
|
||||
:tz,
|
||||
:next_run_at,
|
||||
:end_at,
|
||||
:name,
|
||||
CAST(:allowlist AS jsonb),
|
||||
:model_id,
|
||||
:token_budget,
|
||||
CAST(:origin_conversation_id AS uuid),
|
||||
:created_via
|
||||
) RETURNING *
|
||||
"""
|
||||
),
|
||||
params,
|
||||
).fetchone()
|
||||
return row_to_dict(row)
|
||||
|
||||
def get(self, schedule_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch an owned schedule (None when missing or owned by another)."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_internal(self, schedule_id: str) -> Optional[dict]:
|
||||
"""Fetch a schedule with no ownership scoping (worker-only)."""
|
||||
row = self._conn.execute(
|
||||
text("SELECT * FROM schedules WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": str(schedule_id)},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_for_update(
|
||||
self, schedule_id: str, user_id: str,
|
||||
) -> Optional[dict]:
|
||||
"""Owned fetch with FOR UPDATE; closes the Run-Now TOCTOU."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"FOR UPDATE"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
statuses: Optional[Iterable[str]] = None,
|
||||
trigger_type: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""Owned schedules for an agent, newest-created first."""
|
||||
sql = (
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE agent_id = CAST(:agent_id AS uuid) AND user_id = :user_id"
|
||||
)
|
||||
params: dict[str, Any] = {"agent_id": str(agent_id), "user_id": user_id}
|
||||
if statuses is not None:
|
||||
status_list = [str(s) for s in statuses]
|
||||
if not status_list:
|
||||
return []
|
||||
placeholders = ", ".join(f":s{i}" for i, _ in enumerate(status_list))
|
||||
sql += f" AND status IN ({placeholders})"
|
||||
for i, s in enumerate(status_list):
|
||||
params[f"s{i}"] = s
|
||||
if trigger_type:
|
||||
sql += " AND trigger_type = :trigger_type"
|
||||
params["trigger_type"] = trigger_type
|
||||
sql += " ORDER BY created_at DESC"
|
||||
rows = self._conn.execute(text(sql), params).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def list_for_conversation(
|
||||
self,
|
||||
user_id: str,
|
||||
origin_conversation_id: str,
|
||||
*,
|
||||
statuses: Optional[Iterable[str]] = None,
|
||||
trigger_type: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""Owned agentless schedules anchored to an originating conversation."""
|
||||
sql = (
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE user_id = :user_id "
|
||||
"AND agent_id IS NULL "
|
||||
"AND origin_conversation_id = CAST(:conv AS uuid)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"conv": str(origin_conversation_id),
|
||||
}
|
||||
if statuses is not None:
|
||||
status_list = [str(s) for s in statuses]
|
||||
if not status_list:
|
||||
return []
|
||||
placeholders = ", ".join(f":s{i}" for i, _ in enumerate(status_list))
|
||||
sql += f" AND status IN ({placeholders})"
|
||||
for i, s in enumerate(status_list):
|
||||
params[f"s{i}"] = s
|
||||
if trigger_type:
|
||||
sql += " AND trigger_type = :trigger_type"
|
||||
params["trigger_type"] = trigger_type
|
||||
sql += " ORDER BY created_at DESC"
|
||||
rows = self._conn.execute(text(sql), params).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def list_for_user(self, user_id: str, *, limit: int = 200) -> list[dict]:
|
||||
"""Owned schedules across all agents — admin / debugging path."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules WHERE user_id = :user_id "
|
||||
"ORDER BY created_at DESC LIMIT :limit"
|
||||
),
|
||||
{"user_id": user_id, "limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def count_active_for_user(self, user_id: str) -> int:
|
||||
"""Active+paused schedules for quota enforcement."""
|
||||
scalar = self._conn.execute(
|
||||
text(
|
||||
"SELECT COUNT(*) FROM schedules "
|
||||
"WHERE user_id = :user_id AND status IN ('active', 'paused')"
|
||||
),
|
||||
{"user_id": user_id},
|
||||
).scalar()
|
||||
return int(scalar or 0)
|
||||
|
||||
def list_due(self, *, limit: int = 100) -> list[dict]:
|
||||
"""Lock and return schedules with ``next_run_at <= now()``."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT * FROM schedules
|
||||
WHERE status = 'active'
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= now()
|
||||
AND (end_at IS NULL OR next_run_at <= end_at)
|
||||
ORDER BY next_run_at ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
"""
|
||||
),
|
||||
{"limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def update(
|
||||
self,
|
||||
schedule_id: str,
|
||||
user_id: str,
|
||||
fields: dict,
|
||||
) -> Optional[dict]:
|
||||
"""Apply a whitelisted partial update; return the new row or None."""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
|
||||
if not filtered:
|
||||
return self.get(schedule_id, user_id)
|
||||
set_parts: list[str] = []
|
||||
params: dict[str, Any] = {"id": str(schedule_id), "user_id": user_id}
|
||||
for key, val in filtered.items():
|
||||
if key == "tool_allowlist":
|
||||
set_parts.append("tool_allowlist = CAST(:tool_allowlist AS jsonb)")
|
||||
params["tool_allowlist"] = json.dumps(list(val or []))
|
||||
elif key == "origin_conversation_id":
|
||||
set_parts.append(
|
||||
"origin_conversation_id = CAST(:origin_conversation_id AS uuid)"
|
||||
)
|
||||
params["origin_conversation_id"] = str(val) if val else None
|
||||
else:
|
||||
set_parts.append(f"{key} = :{key}")
|
||||
params[key] = val
|
||||
sql = (
|
||||
"UPDATE schedules SET " + ", ".join(set_parts) +
|
||||
" WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"RETURNING *"
|
||||
)
|
||||
row = self._conn.execute(text(sql), params).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def update_internal(self, schedule_id: str, fields: dict) -> None:
|
||||
"""Apply a whitelisted partial update from a worker context."""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
|
||||
if not filtered:
|
||||
return
|
||||
set_parts: list[str] = []
|
||||
params: dict[str, Any] = {"id": str(schedule_id)}
|
||||
for key, val in filtered.items():
|
||||
if key == "tool_allowlist":
|
||||
set_parts.append("tool_allowlist = CAST(:tool_allowlist AS jsonb)")
|
||||
params["tool_allowlist"] = json.dumps(list(val or []))
|
||||
elif key == "origin_conversation_id":
|
||||
set_parts.append(
|
||||
"origin_conversation_id = CAST(:origin_conversation_id AS uuid)"
|
||||
)
|
||||
params["origin_conversation_id"] = str(val) if val else None
|
||||
else:
|
||||
set_parts.append(f"{key} = :{key}")
|
||||
params[key] = val
|
||||
sql = (
|
||||
"UPDATE schedules SET " + ", ".join(set_parts) +
|
||||
" WHERE id = CAST(:id AS uuid)"
|
||||
)
|
||||
self._conn.execute(text(sql), params)
|
||||
|
||||
def cancel(self, schedule_id: str, user_id: str) -> bool:
|
||||
"""Soft-cancel — flips ``status`` to ``cancelled`` and clears ``next_run_at``."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET status = 'cancelled', next_run_at = NULL "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"AND status NOT IN ('cancelled', 'completed')"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
def delete(self, schedule_id: str, user_id: str) -> bool:
|
||||
"""Hard-delete an owned schedule and its runs (FK cascade)."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
def bump_failure_count(self, schedule_id: str) -> int:
|
||||
"""Increment ``consecutive_failure_count`` and return the new value."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules "
|
||||
"SET consecutive_failure_count = consecutive_failure_count + 1 "
|
||||
"WHERE id = CAST(:id AS uuid) "
|
||||
"RETURNING consecutive_failure_count"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
).fetchone()
|
||||
return int(row[0]) if row is not None else 0
|
||||
|
||||
def reset_failure_count(self, schedule_id: str) -> None:
|
||||
"""Reset the failure counter to 0 after a successful run."""
|
||||
self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET consecutive_failure_count = 0 "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
)
|
||||
|
||||
def autopause(self, schedule_id: str) -> bool:
|
||||
"""Flip an active schedule to ``paused`` after repeated failures."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET status = 'paused', next_run_at = NULL "
|
||||
"WHERE id = CAST(:id AS uuid) AND status = 'active'"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
@@ -175,6 +175,67 @@ class UsersRepository:
|
||||
{"user_id": user_id, "agent_id": agent_id},
|
||||
)
|
||||
|
||||
def set_default_tool_enabled(
|
||||
self, user_id: str, tool_name: str, enabled: bool
|
||||
) -> None:
|
||||
"""Toggle a default chat tool in ``tool_preferences`` (idempotent)."""
|
||||
self.upsert(user_id)
|
||||
if enabled:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE users
|
||||
SET tool_preferences = jsonb_set(
|
||||
COALESCE(tool_preferences, '{}'::jsonb),
|
||||
'{disabled_default_tools}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(
|
||||
tool_preferences->'disabled_default_tools',
|
||||
'[]'::jsonb
|
||||
)
|
||||
) AS elem
|
||||
WHERE (elem #>> '{}') != :tool_name
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "tool_name": tool_name},
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE users
|
||||
SET tool_preferences = jsonb_set(
|
||||
COALESCE(tool_preferences, '{}'::jsonb),
|
||||
'{disabled_default_tools}',
|
||||
CASE
|
||||
WHEN COALESCE(
|
||||
tool_preferences->'disabled_default_tools',
|
||||
'[]'::jsonb
|
||||
) @> to_jsonb(CAST(:tool_name AS text))
|
||||
THEN tool_preferences->'disabled_default_tools'
|
||||
ELSE
|
||||
COALESCE(
|
||||
tool_preferences->'disabled_default_tools',
|
||||
'[]'::jsonb
|
||||
) || to_jsonb(CAST(:tool_name AS text))
|
||||
END
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "tool_name": tool_name},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -16,9 +16,6 @@ from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.parser.chunking import Chunker
|
||||
@@ -34,7 +31,6 @@ from application.parser.remote.remote_creator import (
|
||||
normalize_remote_data,
|
||||
)
|
||||
from application.parser.schema.base import Document
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
@@ -391,146 +387,6 @@ def upload_index(full_path, file_data):
|
||||
file.close()
|
||||
|
||||
|
||||
def run_agent_logic(agent_config, input_data):
|
||||
try:
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.utils import calculate_doc_token_budget
|
||||
|
||||
retriever = agent_config.get("retriever", "classic")
|
||||
# agent_config is a PG row dict: ``source_id`` is a UUID, and the
|
||||
# retriever/chunks live on the source row. Resolve source row for
|
||||
# its retriever/chunks if the agent points at one.
|
||||
source_id = agent_config.get("source_id") or agent_config.get("source")
|
||||
source_active = {}
|
||||
if source_id:
|
||||
with db_readonly() as conn:
|
||||
src_row = SourcesRepository(conn).get(
|
||||
str(source_id),
|
||||
agent_config.get("user_id") or agent_config.get("user"),
|
||||
)
|
||||
if src_row:
|
||||
source_active = str(src_row["id"])
|
||||
retriever = src_row.get("retriever", retriever)
|
||||
source = {"active_docs": source_active}
|
||||
chunks = int(agent_config.get("chunks", 2) or 2)
|
||||
prompt_id = agent_config.get("prompt_id", "default")
|
||||
user_api_key = agent_config["key"]
|
||||
agent_id = (
|
||||
str(agent_config.get("id"))
|
||||
if agent_config.get("id")
|
||||
else (str(agent_config.get("_id")) if agent_config.get("_id") else None)
|
||||
)
|
||||
agent_type = agent_config.get("agent_type", "classic")
|
||||
owner = agent_config.get("user_id") or agent_config.get("user")
|
||||
decoded_token = {"sub": owner}
|
||||
json_schema = agent_config.get("json_schema")
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
# Determine model_id: check agent's default_model_id, fallback to system default
|
||||
agent_default_model = agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner
|
||||
):
|
||||
model_id = agent_default_model
|
||||
else:
|
||||
model_id = get_default_model_id()
|
||||
if agent_default_model:
|
||||
# Stored model_id no longer resolves in the registry. Log so
|
||||
# operators can detect bad YAML edits before users complain;
|
||||
# behavior matches the historical silent fallback.
|
||||
logging.warning(
|
||||
"Agent %s references unknown model_id %r; falling back to %r",
|
||||
agent_id,
|
||||
agent_default_model,
|
||||
model_id,
|
||||
)
|
||||
|
||||
# Get provider and API key for the selected model
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id, user_id=owner)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
# Calculate proper doc_token_limit based on model's context window
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
model_id=model_id, user_id=owner
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
doc_token_limit=doc_token_limit,
|
||||
model_id=model_id,
|
||||
user_api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
# Pre-fetch documents using the retriever
|
||||
retrieved_docs = []
|
||||
try:
|
||||
docs = retriever.search(input_data)
|
||||
if docs:
|
||||
retrieved_docs = docs
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to retrieve documents: {e}")
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint="webhook",
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=model_id,
|
||||
api_key=system_api_key,
|
||||
agent_id=agent_id,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=[],
|
||||
retrieved_docs=retrieved_docs,
|
||||
decoded_token=decoded_token,
|
||||
attachments=[],
|
||||
json_schema=json_schema,
|
||||
)
|
||||
answer = agent.gen(query=input_data)
|
||||
response_full = ""
|
||||
thought = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
elif "sources" in line:
|
||||
source_log_docs.extend(line["sources"])
|
||||
elif "tool_calls" in line:
|
||||
tool_calls.extend(line["tool_calls"])
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
result = {
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
}
|
||||
# Per-activity summary fields (answer_length, thought_length,
|
||||
# source_count, tool_call_count) now ride on the inner
|
||||
# ``activity_finished`` event emitted by ``log_activity`` around
|
||||
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Define the main function for ingesting and processing documents.
|
||||
|
||||
|
||||
@@ -1685,7 +1541,21 @@ def agent_webhook_worker(self, agent_id, payload):
|
||||
raise
|
||||
self.update_state(state="PROGRESS", meta={"current": 50})
|
||||
try:
|
||||
result = run_agent_logic(agent_config, input_data)
|
||||
# Shared headless path with the scheduler; approval-gated tools auto-deny.
|
||||
from application.agents.headless_runner import run_agent_headless
|
||||
|
||||
outcome = run_agent_headless(
|
||||
agent_config,
|
||||
input_data,
|
||||
tool_allowlist=_webhook_tool_allowlist(agent_config),
|
||||
endpoint="webhook",
|
||||
)
|
||||
result = {
|
||||
"answer": outcome.get("answer", ""),
|
||||
"sources": outcome.get("sources", []),
|
||||
"tool_calls": outcome.get("tool_calls", []),
|
||||
"thought": outcome.get("thought", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error running agent logic: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -1698,6 +1568,11 @@ def agent_webhook_worker(self, agent_id, payload):
|
||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||
|
||||
|
||||
def _webhook_tool_allowlist(agent_config):
|
||||
"""Deny-all on approval-gated tools for webhooks (per-agent opt-in is TBD)."""
|
||||
return []
|
||||
|
||||
|
||||
def ingest_connector(
|
||||
self,
|
||||
job_name: str,
|
||||
|
||||
44
frontend/package-lock.json
generated
44
frontend/package-lock.json
generated
@@ -18,6 +18,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^4.2.1",
|
||||
"i18next": "^26.0.4",
|
||||
"i18next-browser-languagedetector": "^8.2.1",
|
||||
"lodash": "^4.18.1",
|
||||
@@ -27,6 +28,7 @@
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "^19.1.0",
|
||||
"react-chartjs-2": "^5.3.0",
|
||||
"react-day-picker": "^10.0.1",
|
||||
"react-dom": "^19.2.5",
|
||||
"react-dropzone": "^15.0.0",
|
||||
"react-google-drive-picker": "^1.2.2",
|
||||
@@ -419,6 +421,12 @@
|
||||
"integrity": "sha512-lB59uJoaGIfOOL9knQqQRfhl9g7x8/wqFkp13zTdkRu1huG9kg6IJs1O8hqj9rs6h7orGxHJUKb+mX3rPbWGhA==",
|
||||
"license": "Apache-2.0"
|
||||
},
|
||||
"node_modules/@date-fns/tz": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/@date-fns/tz/-/tz-1.5.0.tgz",
|
||||
"integrity": "sha512-lwYN/vDPeNRULcepoE/LO2Pgx+7/RV+S9ARfbc9lr2DtGkOD7pAiruHvbR1RX3Qyf6ja47EWJDMsNK5vK08DJg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@emnapi/core": {
|
||||
"version": "1.10.0",
|
||||
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.10.0.tgz",
|
||||
@@ -6515,6 +6523,16 @@
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/date-fns": {
|
||||
"version": "4.2.1",
|
||||
"resolved": "https://registry.npmjs.org/date-fns/-/date-fns-4.2.1.tgz",
|
||||
"integrity": "sha512-37RhSdxaG1suen6VDCza6rNrQfooyQh57HFVPwQGEq2QWliVLzPQZ8Oa017weOu+HZCnzI7N3Pf/wyoBKfEqrA==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/kossnocorp"
|
||||
}
|
||||
},
|
||||
"node_modules/dayjs": {
|
||||
"version": "1.11.19",
|
||||
"resolved": "https://registry.npmjs.org/dayjs/-/dayjs-1.11.19.tgz",
|
||||
@@ -11371,6 +11389,32 @@
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-day-picker": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/react-day-picker/-/react-day-picker-10.0.1.tgz",
|
||||
"integrity": "sha512-eNh6BlwcYInWaJtRv18mXQ06Ys/H6rdTZAnTaSdOYJuTpwP1JMCHNd1FDRadA+gbeinq+psdULN5Xnowy9mV8w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@date-fns/tz": "^1.4.1",
|
||||
"date-fns": "^4.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"type": "individual",
|
||||
"url": "https://github.com/sponsors/gpbl"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": ">=16.8.0",
|
||||
"react": ">=16.8.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-dom": {
|
||||
"version": "19.2.5",
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.5.tgz",
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^4.2.1",
|
||||
"i18next": "^26.0.4",
|
||||
"i18next-browser-languagedetector": "^8.2.1",
|
||||
"lodash": "^4.18.1",
|
||||
@@ -40,6 +41,7 @@
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "^19.1.0",
|
||||
"react-chartjs-2": "^5.3.0",
|
||||
"react-day-picker": "^10.0.1",
|
||||
"react-dom": "^19.2.5",
|
||||
"react-dropzone": "^15.0.0",
|
||||
"react-google-drive-picker": "^1.2.2",
|
||||
|
||||
1
frontend/public/toolIcons/tool_scheduler.svg
Normal file
1
frontend/public/toolIcons/tool_scheduler.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="m612-292 56-56-148-148v-184h-80v216l172 172ZM480-80q-83 0-156-31.5T197-197q-54-54-85.5-127T80-480q0-83 31.5-156T197-763q54-54 127-85.5T480-880q83 0 156 31.5T763-763q54 54 85.5 127T880-480q0 83-31.5 156T763-197q-54 54-127 85.5T480-80Zm0-80q134 0 227-93t93-227q0-134-93-227t-227-93q-134 0-227 93t-93 227q0 134 93 227t227 93Zm0-320Z"/></svg>
|
||||
|
After Width: | Height: | Size: 455 B |
@@ -1,19 +1,18 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useNavigate, useParams } from 'react-router-dom';
|
||||
import { useParams } from 'react-router-dom';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
import Spinner from '../components/Spinner';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import Analytics from '../settings/Analytics';
|
||||
import Logs from '../settings/Logs';
|
||||
import AgentPageHeader from './AgentPageHeader';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function AgentLogs() {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const { agentId } = useParams();
|
||||
const token = useSelector(selectToken);
|
||||
|
||||
@@ -37,19 +36,21 @@ export default function AgentLogs() {
|
||||
useEffect(() => {
|
||||
if (agentId) fetchAgent(agentId);
|
||||
}, [agentId, token]);
|
||||
|
||||
const agentEditPath =
|
||||
agent?.agent_type === 'workflow'
|
||||
? `/agents/workflow/edit/${agentId}`
|
||||
: `/agents/edit/${agentId}`;
|
||||
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<div className="flex items-center gap-3 px-4">
|
||||
<button
|
||||
className="border-border text-muted-foreground hover:bg-accent rounded-full border p-3 text-sm"
|
||||
onClick={() => navigate('/agents')}
|
||||
>
|
||||
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
||||
</button>
|
||||
<p className="text-foreground dark:text-foreground mt-px text-sm font-semibold">
|
||||
{t('agents.backToAll')}
|
||||
</p>
|
||||
</div>
|
||||
<AgentPageHeader
|
||||
agentId={agentId}
|
||||
agentName={agent?.name}
|
||||
agentEditPath={agentEditPath}
|
||||
currentPage="logs"
|
||||
className="px-4"
|
||||
/>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-foreground m-0 text-[32px] font-bold md:text-[40px] dark:text-white">
|
||||
{t('agents.logs.title')}
|
||||
@@ -78,7 +79,6 @@ export default function AgentLogs() {
|
||||
)}
|
||||
{loadingAgent ? (
|
||||
<div className="flex h-[55vh] w-full items-center justify-center">
|
||||
{' '}
|
||||
<Spinner />
|
||||
</div>
|
||||
) : (
|
||||
|
||||
155
frontend/src/agents/AgentPageHeader.tsx
Normal file
155
frontend/src/agents/AgentPageHeader.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Link } from 'react-router-dom';
|
||||
|
||||
import {
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
BreadcrumbLink,
|
||||
BreadcrumbList,
|
||||
BreadcrumbPage,
|
||||
BreadcrumbSeparator,
|
||||
} from '@/components/ui/breadcrumb';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
export type AgentPageTab = 'overview' | 'logs' | 'schedules';
|
||||
|
||||
type AgentPageHeaderProps = {
|
||||
agentId?: string;
|
||||
agentName?: string;
|
||||
/** Route shape for the agent's own root page. Defaults to classic edit URL. */
|
||||
agentEditPath?: string;
|
||||
currentPage: AgentPageTab;
|
||||
/** Optional className wrapper for layout tweaks per page. */
|
||||
className?: string;
|
||||
/**
|
||||
* Drop the 1px baseline border under the tabs row. Use when the header is
|
||||
* embedded in a container that already provides its own bottom border
|
||||
* (e.g. the workflow builder's fixed toolbar) to avoid a double rule.
|
||||
*/
|
||||
inline?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Shared chrome for the agent sub-pages (Overview/Edit, Logs, Schedules).
|
||||
*
|
||||
* Top: shadcn Breadcrumb (`Agents > <agent name> > <current page>`).
|
||||
* Bottom: underline-style sub-nav linking between the agent's sub-pages.
|
||||
*/
|
||||
export default function AgentPageHeader({
|
||||
agentId,
|
||||
agentName,
|
||||
agentEditPath,
|
||||
currentPage,
|
||||
className,
|
||||
inline = false,
|
||||
}: AgentPageHeaderProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const editPath =
|
||||
agentEditPath ?? (agentId ? `/agents/edit/${agentId}` : '/agents');
|
||||
const tabs = useMemo(
|
||||
() => [
|
||||
{
|
||||
id: 'overview' as const,
|
||||
label: t('agents.pageHeader.tabs.overview'),
|
||||
href: editPath,
|
||||
},
|
||||
{
|
||||
id: 'logs' as const,
|
||||
label: t('agents.pageHeader.tabs.logs'),
|
||||
href: agentId ? `/agents/logs/${agentId}` : '#',
|
||||
},
|
||||
{
|
||||
id: 'schedules' as const,
|
||||
label: t('agents.pageHeader.tabs.schedules'),
|
||||
href: agentId ? `/agents/schedules/${agentId}` : '#',
|
||||
},
|
||||
],
|
||||
[agentId, editPath, t],
|
||||
);
|
||||
|
||||
const currentTabLabel =
|
||||
tabs.find((tab) => tab.id === currentPage)?.label ?? '';
|
||||
const displayName = agentName?.trim() || t('agents.pageHeader.fallbackName');
|
||||
|
||||
return (
|
||||
<div className={cn('flex flex-col gap-3', className)}>
|
||||
<Breadcrumb>
|
||||
<BreadcrumbList>
|
||||
<BreadcrumbItem>
|
||||
<BreadcrumbLink asChild>
|
||||
<Link to="/agents">{t('agents.pageHeader.crumbs.agents')}</Link>
|
||||
</BreadcrumbLink>
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbSeparator />
|
||||
<BreadcrumbItem>
|
||||
{currentPage === 'overview' ? (
|
||||
<BreadcrumbPage className="max-w-[40ch] truncate">
|
||||
{displayName}
|
||||
</BreadcrumbPage>
|
||||
) : (
|
||||
<BreadcrumbLink asChild>
|
||||
<Link to={editPath} className="max-w-[40ch] truncate">
|
||||
{displayName}
|
||||
</Link>
|
||||
</BreadcrumbLink>
|
||||
)}
|
||||
</BreadcrumbItem>
|
||||
{currentPage !== 'overview' && (
|
||||
<>
|
||||
<BreadcrumbSeparator />
|
||||
<BreadcrumbItem>
|
||||
<BreadcrumbPage>{currentTabLabel}</BreadcrumbPage>
|
||||
</BreadcrumbItem>
|
||||
</>
|
||||
)}
|
||||
</BreadcrumbList>
|
||||
</Breadcrumb>
|
||||
|
||||
<nav
|
||||
aria-label={t('agents.pageHeader.subnavAriaLabel')}
|
||||
className={cn(
|
||||
'flex w-full items-center gap-6',
|
||||
// 1px baseline rule under the whole row; the active tab's 2px
|
||||
// primary underline sits on top of it for the GitHub-style look.
|
||||
!inline && 'border-border border-b',
|
||||
)}
|
||||
>
|
||||
{tabs.map((tab) => {
|
||||
const isActive = tab.id === currentPage;
|
||||
// Always render a 2px bottom border so row height stays constant
|
||||
// between active/inactive; only the color changes.
|
||||
const baseClasses =
|
||||
'whitespace-nowrap border-b-2 pb-2 text-sm font-medium transition-colors';
|
||||
if (isActive) {
|
||||
return (
|
||||
<span
|
||||
key={tab.id}
|
||||
aria-current="page"
|
||||
className={cn(
|
||||
baseClasses,
|
||||
'border-primary text-foreground -mb-px',
|
||||
)}
|
||||
>
|
||||
{tab.label}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Link
|
||||
key={tab.id}
|
||||
to={tab.href}
|
||||
className={cn(
|
||||
baseClasses,
|
||||
'text-muted-foreground hover:text-foreground hover:border-border/60 -mb-px border-transparent',
|
||||
)}
|
||||
>
|
||||
{tab.label}
|
||||
</Link>
|
||||
);
|
||||
})}
|
||||
</nav>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
|
||||
|
||||
import modelService from '../api/services/modelService';
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
import SourceIcon from '../assets/source.svg';
|
||||
import Dropdown from '../components/Dropdown';
|
||||
import { FileUpload } from '../components/FileUpload';
|
||||
@@ -29,6 +28,7 @@ import PromptsModal from '../preferences/PromptsModal';
|
||||
import Prompts from '../settings/Prompts';
|
||||
import { UserToolType } from '../settings/types';
|
||||
import { getToolDisplayName } from '../utils/toolUtils';
|
||||
import AgentPageHeader from './AgentPageHeader';
|
||||
import AgentPreview from './AgentPreview';
|
||||
import { Agent, ToolSummary } from './types';
|
||||
import WorkflowBuilder from './workflow/WorkflowBuilder';
|
||||
@@ -113,7 +113,6 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
buttonText: t('agents.form.buttons.publish'),
|
||||
showDelete: false,
|
||||
showSaveDraft: true,
|
||||
showLogs: false,
|
||||
showAccessDetails: false,
|
||||
trackChanges: false,
|
||||
},
|
||||
@@ -122,7 +121,6 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
buttonText: t('agents.form.buttons.save'),
|
||||
showDelete: true,
|
||||
showSaveDraft: false,
|
||||
showLogs: true,
|
||||
showAccessDetails: true,
|
||||
trackChanges: true,
|
||||
},
|
||||
@@ -131,7 +129,6 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
buttonText: t('agents.form.buttons.publish'),
|
||||
showDelete: true,
|
||||
showSaveDraft: true,
|
||||
showLogs: false,
|
||||
showAccessDetails: false,
|
||||
trackChanges: false,
|
||||
},
|
||||
@@ -439,12 +436,29 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const response = await userService.getUserTools(token);
|
||||
if (!response.ok) throw new Error('Failed to fetch tools');
|
||||
const data = await response.json();
|
||||
// Group ordering: builtins -> defaults -> user tools (sorted via the
|
||||
// MultiSelectPopup first-appearance grouping).
|
||||
const groupFor = (tool: UserToolType): string => {
|
||||
if (tool.builtin) return t('agents.form.toolsPopup.groupBuiltin');
|
||||
if (tool.default) return t('agents.form.toolsPopup.groupDefault');
|
||||
return t('agents.form.toolsPopup.groupCustom');
|
||||
};
|
||||
const tools: OptionType[] = data.tools.map((tool: UserToolType) => ({
|
||||
id: tool.id,
|
||||
label: getToolDisplayName(tool),
|
||||
icon: `/toolIcons/tool_${tool.name}.svg`,
|
||||
name: tool.name,
|
||||
group: groupFor(tool),
|
||||
}));
|
||||
const groupOrder = [
|
||||
t('agents.form.toolsPopup.groupBuiltin'),
|
||||
t('agents.form.toolsPopup.groupDefault'),
|
||||
t('agents.form.toolsPopup.groupCustom'),
|
||||
];
|
||||
tools.sort(
|
||||
(a, b) =>
|
||||
groupOrder.indexOf(a.group || '') - groupOrder.indexOf(b.group || ''),
|
||||
);
|
||||
setUserTools(tools);
|
||||
};
|
||||
const getModels = async () => {
|
||||
@@ -698,19 +712,21 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
jsonSchemaText !== initialJsonSchemaText;
|
||||
setHasChanges(isChanged);
|
||||
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
|
||||
// Only show the agent sub-nav once the agent has an id (i.e. not the bare
|
||||
// ``new`` mode). The sub-nav links to Logs/Schedules which require an id.
|
||||
const showAgentNav = effectiveMode === 'edit' && Boolean(agent.id);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col px-4 pt-4 pb-2 max-[1179px]:min-h-dvh min-[1180px]:h-dvh md:px-12 md:pt-12 md:pb-3">
|
||||
<div className="flex items-center gap-3 px-4">
|
||||
<button
|
||||
className="border-border text-muted-foreground hover:bg-accent rounded-full border p-3 text-sm"
|
||||
onClick={handleCancel}
|
||||
>
|
||||
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
||||
</button>
|
||||
<p className="text-foreground dark:text-foreground mt-px text-sm font-semibold">
|
||||
{t('agents.backToAll')}
|
||||
</p>
|
||||
</div>
|
||||
{showAgentNav ? (
|
||||
<AgentPageHeader
|
||||
agentId={agent.id}
|
||||
agentName={agent.name}
|
||||
agentEditPath={`/agents/edit/${agent.id}`}
|
||||
currentPage="overview"
|
||||
className="px-4"
|
||||
/>
|
||||
) : null}
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-foreground m-0 text-[32px] font-bold lg:text-[40px] dark:text-white">
|
||||
{modeConfig[effectiveMode].heading}
|
||||
@@ -753,15 +769,6 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
{modeConfig[effectiveMode].showAccessDetails && (
|
||||
<button
|
||||
className="group border-primary text-primary hover:bg-primary/90 flex items-center gap-2 rounded-3xl border border-solid px-5 py-2 text-sm font-medium transition-colors hover:text-white"
|
||||
onClick={() => navigate(`/agents/logs/${agent.id}`)}
|
||||
>
|
||||
<span className="block h-5 w-5 bg-[url('/src/assets/monitoring-purple.svg')] bg-contain bg-center bg-no-repeat transition-all group-hover:bg-[url('/src/assets/monitoring-white.svg')]" />
|
||||
{t('agents.form.buttons.logs')}
|
||||
</button>
|
||||
)}
|
||||
{modeConfig[effectiveMode].showAccessDetails && (
|
||||
<button
|
||||
className="border-primary text-primary hover:bg-primary/90 rounded-3xl border border-solid px-5 py-2 text-sm font-medium transition-colors hover:text-white"
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Route, Routes } from 'react-router-dom';
|
||||
import AgentLogs from './AgentLogs';
|
||||
import AgentsList from './AgentsList';
|
||||
import NewAgent from './NewAgent';
|
||||
import SchedulesView from './schedules/SchedulesView';
|
||||
import SharedAgent from './SharedAgent';
|
||||
import WorkflowBuilder from './workflow/WorkflowBuilder';
|
||||
|
||||
@@ -13,6 +14,7 @@ export default function Agents() {
|
||||
<Route path="/new" element={<NewAgent mode="new" />} />
|
||||
<Route path="/edit/:agentId" element={<NewAgent mode="edit" />} />
|
||||
<Route path="/logs/:agentId" element={<AgentLogs />} />
|
||||
<Route path="/schedules/:agentId" element={<SchedulesView />} />
|
||||
<Route path="/shared/:agentId" element={<SharedAgent />} />
|
||||
<Route path="/workflow/new" element={<WorkflowBuilder />} />
|
||||
<Route path="/workflow/edit/:agentId" element={<WorkflowBuilder />} />
|
||||
|
||||
98
frontend/src/agents/schedules/RunDetailDrawer.tsx
Normal file
98
frontend/src/agents/schedules/RunDetailDrawer.tsx
Normal file
@@ -0,0 +1,98 @@
|
||||
import { useEffect } from 'react';
|
||||
|
||||
import type { ScheduleRun } from '../types/schedule';
|
||||
|
||||
export type RunDetailDrawerProps = {
|
||||
run: ScheduleRun | null;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
const formatTimestamp = (value?: string | null): string => {
|
||||
if (!value) return '—';
|
||||
const d = new Date(value);
|
||||
if (Number.isNaN(d.getTime())) return value;
|
||||
return d.toLocaleString();
|
||||
};
|
||||
|
||||
/** Side drawer with a single run's output / error (terminal-state only). */
|
||||
export default function RunDetailDrawer({
|
||||
run,
|
||||
onClose,
|
||||
}: RunDetailDrawerProps) {
|
||||
useEffect(() => {
|
||||
if (!run) return;
|
||||
const onKey = (e: KeyboardEvent) => {
|
||||
if (e.key === 'Escape') onClose();
|
||||
};
|
||||
document.addEventListener('keydown', onKey);
|
||||
return () => document.removeEventListener('keydown', onKey);
|
||||
}, [run, onClose]);
|
||||
|
||||
if (!run) return null;
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
className="fixed inset-0 z-20 bg-black/40"
|
||||
onClick={onClose}
|
||||
aria-hidden="true"
|
||||
/>
|
||||
<aside
|
||||
className="border-border bg-card fixed top-0 right-0 z-30 flex h-full w-full max-w-xl flex-col border-l p-6 shadow-lg"
|
||||
role="dialog"
|
||||
aria-label="Schedule run details"
|
||||
>
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h2 className="text-lg font-semibold">Run details</h2>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="border-border text-muted-foreground rounded-md border px-3 py-1 text-sm"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</div>
|
||||
<dl className="mb-4 grid grid-cols-2 gap-2 text-sm">
|
||||
<dt className="text-muted-foreground">Status</dt>
|
||||
<dd>{run.status}</dd>
|
||||
<dt className="text-muted-foreground">Scheduled for</dt>
|
||||
<dd>{formatTimestamp(run.scheduled_for)}</dd>
|
||||
<dt className="text-muted-foreground">Started</dt>
|
||||
<dd>{formatTimestamp(run.started_at)}</dd>
|
||||
<dt className="text-muted-foreground">Finished</dt>
|
||||
<dd>{formatTimestamp(run.finished_at)}</dd>
|
||||
<dt className="text-muted-foreground">Tokens</dt>
|
||||
<dd>
|
||||
{run.prompt_tokens} prompt · {run.generated_tokens} generated
|
||||
</dd>
|
||||
<dt className="text-muted-foreground">Trigger</dt>
|
||||
<dd>{run.trigger_source}</dd>
|
||||
</dl>
|
||||
{run.error && (
|
||||
<section className="mb-4">
|
||||
<h3 className="text-destructive text-sm font-semibold">
|
||||
Error{run.error_type ? ` (${run.error_type})` : ''}
|
||||
</h3>
|
||||
<pre className="bg-background mt-1 max-h-48 overflow-auto rounded-md p-3 font-mono text-xs">
|
||||
{run.error}
|
||||
</pre>
|
||||
</section>
|
||||
)}
|
||||
{run.output && (
|
||||
<section className="flex-1 overflow-hidden">
|
||||
<h3 className="text-sm font-semibold">
|
||||
Output
|
||||
{run.output_truncated && (
|
||||
<span className="text-muted-foreground ml-1 text-xs">
|
||||
(truncated)
|
||||
</span>
|
||||
)}
|
||||
</h3>
|
||||
<pre className="bg-background mt-1 h-full overflow-auto rounded-md p-3 font-mono text-xs whitespace-pre-wrap">
|
||||
{run.output}
|
||||
</pre>
|
||||
</section>
|
||||
)}
|
||||
</aside>
|
||||
</>
|
||||
);
|
||||
}
|
||||
88
frontend/src/agents/schedules/RunLog.tsx
Normal file
88
frontend/src/agents/schedules/RunLog.tsx
Normal file
@@ -0,0 +1,88 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import { selectToken } from '../../preferences/preferenceSlice';
|
||||
import type { AppDispatch, RootState } from '../../store';
|
||||
import type { ScheduleRun } from '../types/schedule';
|
||||
import { loadRunsForSchedule, selectRunsForSchedule } from './schedulesSlice';
|
||||
|
||||
export type RunLogProps = {
|
||||
scheduleId: string;
|
||||
onSelect?: (run: ScheduleRun) => void;
|
||||
};
|
||||
|
||||
const STATUS_STYLES: Record<string, string> = {
|
||||
success: 'text-green-600',
|
||||
failed: 'text-destructive',
|
||||
timeout: 'text-amber-600',
|
||||
skipped: 'text-muted-foreground',
|
||||
running: 'text-blue-600',
|
||||
pending: 'text-muted-foreground',
|
||||
};
|
||||
|
||||
const formatTimestamp = (value?: string | null): string => {
|
||||
if (!value) return '—';
|
||||
const d = new Date(value);
|
||||
if (Number.isNaN(d.getTime())) return value;
|
||||
return d.toLocaleString();
|
||||
};
|
||||
|
||||
/** Paginated run log for a schedule (SSE updates merge via schedulesSlice). */
|
||||
export default function RunLog({ scheduleId, onSelect }: RunLogProps) {
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const token = useSelector(selectToken);
|
||||
const runs = useSelector((state: RootState) =>
|
||||
selectRunsForSchedule(state, scheduleId),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!scheduleId) return;
|
||||
dispatch(loadRunsForSchedule({ id: scheduleId, token }));
|
||||
}, [dispatch, scheduleId, token]);
|
||||
|
||||
if (runs.length === 0) {
|
||||
return (
|
||||
<p className="text-muted-foreground py-3 text-sm">
|
||||
No runs recorded for this schedule yet.
|
||||
</p>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<table className="w-full text-left text-sm">
|
||||
<thead className="text-muted-foreground text-xs uppercase">
|
||||
<tr>
|
||||
<th className="py-2">When</th>
|
||||
<th className="py-2">Status</th>
|
||||
<th className="py-2">Tokens</th>
|
||||
<th className="py-2">Trigger</th>
|
||||
<th className="py-2"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{runs.map((run) => (
|
||||
<tr key={run.id} className="border-border border-t">
|
||||
<td className="py-2">{formatTimestamp(run.scheduled_for)}</td>
|
||||
<td className={`py-2 ${STATUS_STYLES[run.status] ?? ''}`}>
|
||||
{run.status}
|
||||
{run.error_type ? ` (${run.error_type})` : ''}
|
||||
</td>
|
||||
<td className="py-2">{run.prompt_tokens + run.generated_tokens}</td>
|
||||
<td className="py-2">{run.trigger_source}</td>
|
||||
<td className="py-2">
|
||||
{onSelect && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onSelect(run)}
|
||||
className="text-primary text-xs underline"
|
||||
>
|
||||
Details
|
||||
</button>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
);
|
||||
}
|
||||
557
frontend/src/agents/schedules/ScheduleFormModal.tsx
Normal file
557
frontend/src/agents/schedules/ScheduleFormModal.tsx
Normal file
@@ -0,0 +1,557 @@
|
||||
import { CalendarIcon } from 'lucide-react';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Calendar } from '@/components/ui/calendar';
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/components/ui/popover';
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select';
|
||||
import { TimePicker } from '@/components/ui/time-picker';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
import WrapperModal from '../../modals/WrapperModal';
|
||||
import type { Schedule, ScheduleCreatePayload } from '../types/schedule';
|
||||
import {
|
||||
browserTimezone,
|
||||
buildCron,
|
||||
buildRunAtUtc,
|
||||
parseScheduleToFormValues,
|
||||
supportedTimezones,
|
||||
todayDate,
|
||||
type ScheduleFormValues,
|
||||
type ScheduleFrequency,
|
||||
} from './cronBuilder';
|
||||
import TimezoneCombobox from './TimezoneCombobox';
|
||||
|
||||
export type ScheduleFormModalProps = {
|
||||
open: boolean;
|
||||
initial?: Schedule | null;
|
||||
agentToolIds: string[];
|
||||
onClose: () => void;
|
||||
onSubmit: (payload: ScheduleCreatePayload) => Promise<void> | void;
|
||||
submitting?: boolean;
|
||||
};
|
||||
|
||||
const FREQUENCIES: ScheduleFrequency[] = [
|
||||
'once',
|
||||
'daily',
|
||||
'weekly',
|
||||
'monthly',
|
||||
'yearly',
|
||||
];
|
||||
|
||||
// 0=Sun ... 6=Sat (matches POSIX cron's dow field).
|
||||
const DAY_OPTIONS = [
|
||||
{ value: 1, key: 'mon' },
|
||||
{ value: 2, key: 'tue' },
|
||||
{ value: 3, key: 'wed' },
|
||||
{ value: 4, key: 'thu' },
|
||||
{ value: 5, key: 'fri' },
|
||||
{ value: 6, key: 'sat' },
|
||||
{ value: 0, key: 'sun' },
|
||||
] as const;
|
||||
|
||||
const MONTH_KEYS = [
|
||||
'jan',
|
||||
'feb',
|
||||
'mar',
|
||||
'apr',
|
||||
'may',
|
||||
'jun',
|
||||
'jul',
|
||||
'aug',
|
||||
'sep',
|
||||
'oct',
|
||||
'nov',
|
||||
'dec',
|
||||
] as const;
|
||||
|
||||
/** Parse ``YYYY-MM-DD`` into a local Date (no tz drift for calendar use). */
|
||||
const dateStringToDate = (value: string): Date | undefined => {
|
||||
const m = /^(\d{4})-(\d{1,2})-(\d{1,2})$/.exec(value ?? '');
|
||||
if (!m) return undefined;
|
||||
return new Date(Number(m[1]), Number(m[2]) - 1, Number(m[3]));
|
||||
};
|
||||
|
||||
const dateToDateString = (d: Date): string => {
|
||||
const y = d.getFullYear();
|
||||
const m = String(d.getMonth() + 1).padStart(2, '0');
|
||||
const day = String(d.getDate()).padStart(2, '0');
|
||||
return `${y}-${m}-${day}`;
|
||||
};
|
||||
|
||||
const formatDateLabel = (value: string): string => {
|
||||
const d = dateStringToDate(value);
|
||||
if (!d) return '';
|
||||
return d.toLocaleDateString(undefined, {
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
});
|
||||
};
|
||||
|
||||
/** Create/edit a Schedule via a modal dialog. */
|
||||
export default function ScheduleFormModal({
|
||||
open,
|
||||
initial,
|
||||
agentToolIds,
|
||||
onClose,
|
||||
onSubmit,
|
||||
submitting,
|
||||
}: ScheduleFormModalProps) {
|
||||
const { t } = useTranslation();
|
||||
// Edit mode pre-populates from the saved schedule; create mode uses the browser tz.
|
||||
const initialTimezone = useMemo<string>(
|
||||
() => initial?.timezone || browserTimezone(),
|
||||
[initial?.timezone],
|
||||
);
|
||||
|
||||
const defaults: ScheduleFormValues = useMemo(
|
||||
() =>
|
||||
initial
|
||||
? parseScheduleToFormValues(initial, initialTimezone)
|
||||
: {
|
||||
frequency: 'daily',
|
||||
date: todayDate(initialTimezone),
|
||||
time: '09:00',
|
||||
dayOfWeek: 1,
|
||||
dayOfMonth: 1,
|
||||
month: 1,
|
||||
},
|
||||
[initial, initialTimezone],
|
||||
);
|
||||
|
||||
const [name, setName] = useState<string>(initial?.name ?? '');
|
||||
const [instruction, setInstruction] = useState<string>(
|
||||
initial?.instruction ?? '',
|
||||
);
|
||||
const [values, setValues] = useState<ScheduleFormValues>(defaults);
|
||||
const [timezone, setTimezone] = useState<string>(initialTimezone);
|
||||
const timezoneOptions = useMemo<string[]>(() => {
|
||||
const list = supportedTimezones();
|
||||
// Make sure the current selection is always present, even if absent from
|
||||
// the engine's supported list (e.g. an exotic tz saved on the schedule).
|
||||
return list.includes(timezone) ? list : [timezone, ...list];
|
||||
}, [timezone]);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
if (!open) return null;
|
||||
|
||||
const setFrequency = (frequency: ScheduleFrequency) =>
|
||||
setValues((current) => ({ ...current, frequency }));
|
||||
|
||||
const submit = async () => {
|
||||
if (!instruction.trim()) {
|
||||
setError(t('agents.schedules.modal.errors.instructionRequired'));
|
||||
return;
|
||||
}
|
||||
const payload: ScheduleCreatePayload = {
|
||||
instruction: instruction.trim(),
|
||||
timezone,
|
||||
name: name.trim() || undefined,
|
||||
tool_allowlist: agentToolIds,
|
||||
};
|
||||
if (values.frequency === 'once') {
|
||||
let runAt: string;
|
||||
try {
|
||||
runAt = buildRunAtUtc(values.date, values.time, timezone);
|
||||
} catch {
|
||||
setError(t('agents.schedules.modal.errors.runAtInPast'));
|
||||
return;
|
||||
}
|
||||
if (new Date(runAt).getTime() <= Date.now()) {
|
||||
setError(t('agents.schedules.modal.errors.runAtInPast'));
|
||||
return;
|
||||
}
|
||||
payload.trigger_type = 'once';
|
||||
payload.run_at = runAt;
|
||||
} else {
|
||||
const cron = buildCron(values.frequency, values);
|
||||
if (!cron) {
|
||||
setError(t('agents.schedules.modal.errors.instructionRequired'));
|
||||
return;
|
||||
}
|
||||
payload.trigger_type = 'recurring';
|
||||
payload.cron = cron;
|
||||
}
|
||||
setError(null);
|
||||
await onSubmit(payload);
|
||||
};
|
||||
|
||||
const isEdit = Boolean(initial?.id);
|
||||
|
||||
return (
|
||||
<WrapperModal
|
||||
className="w-[min(560px,92vw)] sm:p-6"
|
||||
contentClassName="max-h-[80vh]"
|
||||
close={onClose}
|
||||
isPerformingTask={submitting}
|
||||
>
|
||||
<div className="flex flex-col gap-5">
|
||||
<div className="flex items-start gap-3 pr-6">
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder={t('agents.schedules.modal.namePlaceholder')}
|
||||
className="text-foreground placeholder:text-muted-foreground w-full bg-transparent text-xl font-semibold outline-none"
|
||||
aria-label={t('agents.schedules.modal.namePlaceholder')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<FrequencyTabs
|
||||
frequency={values.frequency}
|
||||
onChange={setFrequency}
|
||||
labels={{
|
||||
once: t('agents.schedules.modal.frequency.once'),
|
||||
daily: t('agents.schedules.modal.frequency.daily'),
|
||||
weekly: t('agents.schedules.modal.frequency.weekly'),
|
||||
monthly: t('agents.schedules.modal.frequency.monthly'),
|
||||
yearly: t('agents.schedules.modal.frequency.yearly'),
|
||||
}}
|
||||
/>
|
||||
|
||||
<OnPicker
|
||||
values={values}
|
||||
onChange={setValues}
|
||||
tDay={(key) => t(`agents.schedules.modal.days.${key}`)}
|
||||
tMonth={(key) => t(`agents.schedules.modal.months.${key}`)}
|
||||
labels={{
|
||||
on: t('agents.schedules.modal.on'),
|
||||
at: t('agents.schedules.modal.at'),
|
||||
pickDate: t('agents.schedules.modal.pickDate'),
|
||||
}}
|
||||
/>
|
||||
|
||||
<div className="border-border flex flex-wrap items-center justify-between gap-2 rounded-md border p-3">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{t('agents.schedules.modal.timezone')}
|
||||
</span>
|
||||
<div className="w-full max-w-[16rem]">
|
||||
<TimezoneCombobox
|
||||
value={timezone}
|
||||
options={timezoneOptions}
|
||||
onChange={setTimezone}
|
||||
placeholder={t('agents.schedules.modal.timezonePlaceholder')}
|
||||
searchPlaceholder={t(
|
||||
'agents.schedules.modal.timezoneSearchPlaceholder',
|
||||
)}
|
||||
emptyText={t('agents.schedules.modal.timezoneEmpty')}
|
||||
ariaLabel={t('agents.schedules.modal.timezone')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<label className="flex flex-col gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{t('agents.schedules.modal.instructionsLabel')}
|
||||
</span>
|
||||
<textarea
|
||||
value={instruction}
|
||||
onChange={(e) => setInstruction(e.target.value)}
|
||||
placeholder={t('agents.schedules.modal.instructionsPlaceholder')}
|
||||
rows={5}
|
||||
className="border-border bg-background text-foreground placeholder:text-muted-foreground focus:border-ring focus:ring-ring/40 rounded-md border px-3 py-2 text-sm outline-none focus:ring-2"
|
||||
/>
|
||||
</label>
|
||||
|
||||
{error && <p className="text-destructive text-sm">{error}</p>}
|
||||
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
type="button"
|
||||
disabled={submitting}
|
||||
onClick={submit}
|
||||
className="rounded-full px-5"
|
||||
>
|
||||
{submitting
|
||||
? '…'
|
||||
: isEdit
|
||||
? t('agents.schedules.modal.save')
|
||||
: t('agents.schedules.modal.create')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</WrapperModal>
|
||||
);
|
||||
}
|
||||
|
||||
type FrequencyTabsProps = {
|
||||
frequency: ScheduleFrequency;
|
||||
onChange: (f: ScheduleFrequency) => void;
|
||||
labels: Record<ScheduleFrequency, string>;
|
||||
};
|
||||
|
||||
function FrequencyTabs({ frequency, onChange, labels }: FrequencyTabsProps) {
|
||||
return (
|
||||
<div className="bg-muted/60 dark:bg-muted/40 inline-flex w-full gap-1 rounded-full p-1">
|
||||
{FREQUENCIES.map((f) => {
|
||||
const active = f === frequency;
|
||||
return (
|
||||
<button
|
||||
key={f}
|
||||
type="button"
|
||||
onClick={() => onChange(f)}
|
||||
className={cn(
|
||||
'flex-1 rounded-full px-3 py-1.5 text-xs font-medium transition-colors',
|
||||
active
|
||||
? 'bg-card text-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground',
|
||||
)}
|
||||
aria-pressed={active}
|
||||
>
|
||||
{labels[f]}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type OnPickerProps = {
|
||||
values: ScheduleFormValues;
|
||||
onChange: (next: ScheduleFormValues) => void;
|
||||
tDay: (key: string) => string;
|
||||
tMonth: (key: string) => string;
|
||||
labels: { on: string; at: string; pickDate: string };
|
||||
};
|
||||
|
||||
function OnPicker({ values, onChange, tDay, tMonth, labels }: OnPickerProps) {
|
||||
const set = (patch: Partial<ScheduleFormValues>) =>
|
||||
onChange({ ...values, ...patch });
|
||||
|
||||
return (
|
||||
<div className="border-border flex flex-col gap-3 rounded-md border p-3">
|
||||
{values.frequency === 'once' && (
|
||||
<div className="flex flex-wrap items-center justify-between gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{labels.on}
|
||||
</span>
|
||||
<div className="flex items-center gap-2">
|
||||
<DatePicker
|
||||
value={values.date}
|
||||
onChange={(date) => set({ date })}
|
||||
placeholder={labels.pickDate}
|
||||
/>
|
||||
<TimeInput
|
||||
value={values.time}
|
||||
onChange={(time) => set({ time })}
|
||||
ariaLabel={labels.at}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{values.frequency === 'daily' && (
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{labels.at}
|
||||
</span>
|
||||
<TimeInput
|
||||
value={values.time}
|
||||
onChange={(time) => set({ time })}
|
||||
ariaLabel={labels.at}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{values.frequency === 'weekly' && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{DAY_OPTIONS.map((d) => {
|
||||
const active = d.value === values.dayOfWeek;
|
||||
return (
|
||||
<Button
|
||||
key={d.key}
|
||||
type="button"
|
||||
size="sm"
|
||||
variant={active ? 'default' : 'outline'}
|
||||
onClick={() => set({ dayOfWeek: d.value })}
|
||||
className="rounded-full"
|
||||
aria-pressed={active}
|
||||
>
|
||||
{tDay(d.key)}
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{labels.at}
|
||||
</span>
|
||||
<TimeInput
|
||||
value={values.time}
|
||||
onChange={(time) => set({ time })}
|
||||
ariaLabel={labels.at}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{values.frequency === 'monthly' && (
|
||||
<div className="flex flex-wrap items-center justify-between gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{labels.on}
|
||||
</span>
|
||||
<div className="flex items-center gap-2">
|
||||
<DayOfMonthSelect
|
||||
value={values.dayOfMonth}
|
||||
onChange={(dayOfMonth) => set({ dayOfMonth })}
|
||||
ariaLabel={labels.on}
|
||||
/>
|
||||
<TimeInput
|
||||
value={values.time}
|
||||
onChange={(time) => set({ time })}
|
||||
ariaLabel={labels.at}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{values.frequency === 'yearly' && (
|
||||
<div className="flex flex-wrap items-center justify-between gap-2">
|
||||
<span className="text-foreground text-sm font-medium">
|
||||
{labels.on}
|
||||
</span>
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<MonthSelect
|
||||
value={values.month}
|
||||
onChange={(month) => set({ month })}
|
||||
tMonth={tMonth}
|
||||
ariaLabel={labels.on}
|
||||
/>
|
||||
<DayOfMonthSelect
|
||||
value={values.dayOfMonth}
|
||||
onChange={(dayOfMonth) => set({ dayOfMonth })}
|
||||
ariaLabel={labels.on}
|
||||
/>
|
||||
<TimeInput
|
||||
value={values.time}
|
||||
onChange={(time) => set({ time })}
|
||||
ariaLabel={labels.at}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type DatePickerProps = {
|
||||
value: string;
|
||||
onChange: (next: string) => void;
|
||||
placeholder: string;
|
||||
};
|
||||
|
||||
function DatePicker({ value, onChange, placeholder }: DatePickerProps) {
|
||||
const [open, setOpen] = useState<boolean>(false);
|
||||
const selected = dateStringToDate(value);
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
aria-label={placeholder}
|
||||
className={cn(
|
||||
'h-9 justify-start gap-2 px-3 font-normal',
|
||||
!value && 'text-muted-foreground',
|
||||
)}
|
||||
>
|
||||
<CalendarIcon className="size-4 opacity-70" />
|
||||
{value ? formatDateLabel(value) : placeholder}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
{/* z-200 keeps the popover above WrapperModal (z-100); matches SelectContent. */}
|
||||
<PopoverContent className="z-200 w-auto p-0" align="start">
|
||||
<Calendar
|
||||
mode="single"
|
||||
selected={selected}
|
||||
onSelect={(d) => {
|
||||
if (!d) return;
|
||||
onChange(dateToDateString(d));
|
||||
setOpen(false);
|
||||
}}
|
||||
captionLayout="dropdown"
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
type TimeInputProps = {
|
||||
value: string;
|
||||
onChange: (next: string) => void;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
// Theme-aware replacement for <input type="time"> (clock icon + hours/minutes selects).
|
||||
function TimeInput({ value, onChange, ariaLabel }: TimeInputProps) {
|
||||
return <TimePicker value={value} onChange={onChange} ariaLabel={ariaLabel} />;
|
||||
}
|
||||
|
||||
type DayOfMonthSelectProps = {
|
||||
value: number;
|
||||
onChange: (next: number) => void;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
function DayOfMonthSelect({
|
||||
value,
|
||||
onChange,
|
||||
ariaLabel,
|
||||
}: DayOfMonthSelectProps) {
|
||||
return (
|
||||
<Select value={String(value)} onValueChange={(v) => onChange(Number(v))}>
|
||||
<SelectTrigger size="sm" aria-label={ariaLabel} className="h-9 w-[5rem]">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{Array.from({ length: 31 }, (_, i) => i + 1).map((d) => (
|
||||
<SelectItem key={d} value={String(d)}>
|
||||
{d}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
|
||||
type MonthSelectProps = {
|
||||
value: number;
|
||||
onChange: (next: number) => void;
|
||||
tMonth: (key: string) => string;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
function MonthSelect({ value, onChange, tMonth, ariaLabel }: MonthSelectProps) {
|
||||
return (
|
||||
<Select value={String(value)} onValueChange={(v) => onChange(Number(v))}>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
aria-label={ariaLabel}
|
||||
className="h-9 w-[6.5rem]"
|
||||
>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{MONTH_KEYS.map((k, i) => (
|
||||
<SelectItem key={k} value={String(i + 1)}>
|
||||
{tMonth(k)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
37
frontend/src/agents/schedules/SchedulerToolCallCard.test.ts
Normal file
37
frontend/src/agents/schedules/SchedulerToolCallCard.test.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { extractToolError } from './SchedulerToolCallCard';
|
||||
|
||||
// Regression for the iter-6 issue where ``cancel_scheduled_task`` returning
|
||||
// a plain ``"Error: …"`` string still rendered "Scheduled task cancelled."
|
||||
// The fix is to extract the error message so the card can branch on it.
|
||||
describe('extractToolError', () => {
|
||||
it('returns the message for an Error: prefixed string', () => {
|
||||
expect(
|
||||
extractToolError('Error: scheduled task not found or already terminal.'),
|
||||
).toBe('scheduled task not found or already terminal.');
|
||||
});
|
||||
|
||||
it('trims leading whitespace before the prefix', () => {
|
||||
expect(extractToolError(' Error: foo ')).toBe('foo');
|
||||
});
|
||||
|
||||
it('returns null for JSON success payloads', () => {
|
||||
expect(
|
||||
extractToolError(JSON.stringify({ task_id: 'x', status: 'cancelled' })),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for plain non-error strings', () => {
|
||||
expect(extractToolError('done')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for object results', () => {
|
||||
expect(extractToolError({ task_id: 'x' })).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for undefined / null', () => {
|
||||
expect(extractToolError(undefined)).toBeNull();
|
||||
expect(extractToolError(null)).toBeNull();
|
||||
});
|
||||
});
|
||||
176
frontend/src/agents/schedules/SchedulerToolCallCard.tsx
Normal file
176
frontend/src/agents/schedules/SchedulerToolCallCard.tsx
Normal file
@@ -0,0 +1,176 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import { selectToken } from '../../preferences/preferenceSlice';
|
||||
import type { AppDispatch } from '../../store';
|
||||
import { deleteSchedule, loadSchedulesForAgent } from './schedulesSlice';
|
||||
|
||||
export type SchedulerToolCallCardProps = {
|
||||
/** Outcome JSON the scheduler tool returned (action result). */
|
||||
result?: unknown;
|
||||
/** Action name dispatched by the LLM. */
|
||||
actionName: string;
|
||||
/** Status of this tool call (pending → completed). */
|
||||
status?: string;
|
||||
/** Agent id, for live-refresh of the cancel action. */
|
||||
agentId?: string;
|
||||
};
|
||||
|
||||
const formatTimestamp = (value?: string | null): string => {
|
||||
if (!value) return '—';
|
||||
const d = new Date(value);
|
||||
if (Number.isNaN(d.getTime())) return value;
|
||||
return d.toLocaleString();
|
||||
};
|
||||
|
||||
const parseResult = (result: unknown): Record<string, unknown> | null => {
|
||||
if (!result) return null;
|
||||
if (typeof result === 'object') return result as Record<string, unknown>;
|
||||
if (typeof result === 'string') {
|
||||
try {
|
||||
return JSON.parse(result) as Record<string, unknown>;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
/** Tool returns a plain "Error: …" string on failure (cancel-not-found etc). */
|
||||
export const extractToolError = (result: unknown): string | null => {
|
||||
if (typeof result === 'string') {
|
||||
const trimmed = result.trim();
|
||||
if (trimmed.startsWith('Error:')) {
|
||||
return trimmed.slice('Error:'.length).trim();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
/** In-chat card for scheduler.schedule_task with a one-click cancel. */
|
||||
export default function SchedulerToolCallCard({
|
||||
result,
|
||||
actionName,
|
||||
status,
|
||||
agentId,
|
||||
}: SchedulerToolCallCardProps) {
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const token = useSelector(selectToken);
|
||||
const [cancelled, setCancelled] = useState<boolean>(false);
|
||||
const parsed = parseResult(result);
|
||||
const taskId =
|
||||
parsed && typeof parsed.task_id === 'string' ? parsed.task_id : null;
|
||||
const runAt =
|
||||
parsed && typeof parsed.resolved_run_at === 'string'
|
||||
? parsed.resolved_run_at
|
||||
: null;
|
||||
const instruction =
|
||||
parsed && typeof parsed.instruction === 'string'
|
||||
? parsed.instruction
|
||||
: null;
|
||||
const error =
|
||||
parsed && typeof parsed.error === 'string' ? parsed.error : null;
|
||||
|
||||
// Agent-bound chats prime the Schedules tab cache; agentless chats have
|
||||
// no per-agent listing, so skip the fetch.
|
||||
useEffect(() => {
|
||||
if (agentId) dispatch(loadSchedulesForAgent({ agentId, token }));
|
||||
}, [dispatch, agentId, token]);
|
||||
|
||||
const cancel = async () => {
|
||||
if (!taskId) return;
|
||||
setCancelled(true);
|
||||
try {
|
||||
await dispatch(deleteSchedule({ id: taskId, token })).unwrap();
|
||||
} catch (err) {
|
||||
setCancelled(false);
|
||||
console.error(err);
|
||||
}
|
||||
};
|
||||
|
||||
if (actionName.startsWith('cancel_scheduled_task')) {
|
||||
// The tool returns a plain "Error: …" string when the cancel fails
|
||||
// (not found, already terminal, invalid id). Don't claim success.
|
||||
const cancelError = extractToolError(result);
|
||||
if (cancelError) {
|
||||
return (
|
||||
<div className="border-border bg-card rounded-2xl border p-4 text-sm">
|
||||
<p className="text-destructive font-semibold">
|
||||
Cancel failed: {cancelError}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="border-border bg-card rounded-2xl border p-4 text-sm">
|
||||
<p className="font-semibold">Scheduled task cancelled.</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (actionName.startsWith('list_scheduled_tasks')) {
|
||||
const tasks = Array.isArray(parsed?.tasks)
|
||||
? (parsed?.tasks as Array<Record<string, unknown>>)
|
||||
: [];
|
||||
return (
|
||||
<div className="border-border bg-card rounded-2xl border p-4 text-sm">
|
||||
<p className="font-semibold">
|
||||
{tasks.length} pending scheduled task{tasks.length === 1 ? '' : 's'}
|
||||
</p>
|
||||
<ul className="mt-2 flex flex-col gap-1">
|
||||
{tasks.map((task) => (
|
||||
<li key={String(task.task_id)}>
|
||||
{formatTimestamp(task.resolved_run_at as string)} —{' '}
|
||||
{String(task.instruction || task.name || task.task_id)}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ``error`` may be JSON-shaped (``{"error": "…"}``) or a plain
|
||||
// ``"Error: …"`` string returned by the tool on validation failures.
|
||||
const schedulingError = error || extractToolError(result);
|
||||
if (schedulingError) {
|
||||
return (
|
||||
<div className="border-border bg-card rounded-2xl border p-4 text-sm">
|
||||
<p className="text-destructive font-semibold">
|
||||
Scheduling failed: {schedulingError}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border-border bg-card rounded-2xl border p-4 text-sm">
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="font-semibold">
|
||||
{status === 'pending' ? '⏰ Scheduling…' : '⏰ Scheduled task'}
|
||||
</p>
|
||||
{runAt && (
|
||||
<span className="text-muted-foreground text-xs">
|
||||
{formatTimestamp(runAt)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{instruction && (
|
||||
<p className="text-muted-foreground mt-2 text-sm italic">
|
||||
“{instruction}”
|
||||
</p>
|
||||
)}
|
||||
{taskId && !cancelled && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={cancel}
|
||||
className="text-destructive border-border mt-2 rounded-md border px-3 py-1 text-xs"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
)}
|
||||
{cancelled && (
|
||||
<p className="text-muted-foreground mt-2 text-xs">Cancelled.</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
396
frontend/src/agents/schedules/SchedulesView.tsx
Normal file
396
frontend/src/agents/schedules/SchedulesView.tsx
Normal file
@@ -0,0 +1,396 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useParams } from 'react-router-dom';
|
||||
|
||||
import { Button } from '@/components/ui/button';
|
||||
|
||||
import userService from '../../api/services/userService';
|
||||
import Spinner from '../../components/Spinner';
|
||||
import ConfirmationModal from '../../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../../models/misc';
|
||||
import { selectToken } from '../../preferences/preferenceSlice';
|
||||
import type { AppDispatch, RootState } from '../../store';
|
||||
import AgentPageHeader from '../AgentPageHeader';
|
||||
import type { Agent } from '../types';
|
||||
import type {
|
||||
Schedule,
|
||||
ScheduleCreatePayload,
|
||||
ScheduleRun,
|
||||
} from '../types/schedule';
|
||||
import RunDetailDrawer from './RunDetailDrawer';
|
||||
import RunLog from './RunLog';
|
||||
import ScheduleFormModal from './ScheduleFormModal';
|
||||
import { formatCron } from './cronBuilder';
|
||||
import {
|
||||
createSchedule,
|
||||
deleteSchedule,
|
||||
loadSchedulesForAgent,
|
||||
runScheduleNow,
|
||||
selectSchedulesForAgent,
|
||||
setSchedulePaused,
|
||||
updateSchedule,
|
||||
} from './schedulesSlice';
|
||||
|
||||
const formatTimestamp = (value?: string | null): string => {
|
||||
if (!value) return '—';
|
||||
const d = new Date(value);
|
||||
if (Number.isNaN(d.getTime())) return value;
|
||||
return d.toLocaleString();
|
||||
};
|
||||
|
||||
/** Standalone Schedules page for an agent: list, create, edit, pause, run, delete. */
|
||||
export default function SchedulesView() {
|
||||
const { t } = useTranslation();
|
||||
const { agentId } = useParams();
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const token = useSelector(selectToken);
|
||||
|
||||
const [agent, setAgent] = useState<Agent | undefined>();
|
||||
const [loadingAgent, setLoadingAgent] = useState<boolean>(true);
|
||||
const [modalOpen, setModalOpen] = useState<boolean>(false);
|
||||
const [editing, setEditing] = useState<Schedule | null>(null);
|
||||
const [submitting, setSubmitting] = useState<boolean>(false);
|
||||
const [expanded, setExpanded] = useState<string | null>(null);
|
||||
const [activeRun, setActiveRun] = useState<ScheduleRun | null>(null);
|
||||
const [deleteConfirmation, setDeleteConfirmation] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [scheduleToDelete, setScheduleToDelete] = useState<Schedule | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const schedules = useSelector((state: RootState) =>
|
||||
selectSchedulesForAgent(state, agentId ?? ''),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!agentId) return;
|
||||
const fetchAgent = async () => {
|
||||
setLoadingAgent(true);
|
||||
try {
|
||||
const response = await userService.getAgent(agentId, token);
|
||||
if (!response.ok) throw new Error('Failed to fetch agent');
|
||||
const data = await response.json();
|
||||
setAgent(data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
} finally {
|
||||
setLoadingAgent(false);
|
||||
}
|
||||
};
|
||||
fetchAgent();
|
||||
}, [agentId, token]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!agentId) return;
|
||||
dispatch(loadSchedulesForAgent({ agentId, token }));
|
||||
}, [dispatch, agentId, token]);
|
||||
|
||||
const agentToolIds = useMemo<string[]>(() => {
|
||||
if (!agent) return [];
|
||||
const fromDetails = (agent.tool_details ?? []).map((d) => d.id);
|
||||
if (fromDetails.length > 0) return fromDetails;
|
||||
return agent.tools ?? [];
|
||||
}, [agent]);
|
||||
|
||||
const recurring = useMemo(
|
||||
() => schedules.filter((s) => s.trigger_type === 'recurring'),
|
||||
[schedules],
|
||||
);
|
||||
const oneTime = useMemo(
|
||||
() => schedules.filter((s) => s.trigger_type === 'once'),
|
||||
[schedules],
|
||||
);
|
||||
|
||||
const openCreate = () => {
|
||||
setEditing(null);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
const openEdit = (schedule: Schedule) => {
|
||||
setEditing(schedule);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
const closeModal = () => {
|
||||
if (submitting) return;
|
||||
setModalOpen(false);
|
||||
setEditing(null);
|
||||
};
|
||||
|
||||
const requestDelete = (schedule: Schedule) => {
|
||||
setScheduleToDelete(schedule);
|
||||
setDeleteConfirmation('ACTIVE');
|
||||
};
|
||||
|
||||
const confirmDelete = () => {
|
||||
if (!scheduleToDelete) return;
|
||||
dispatch(deleteSchedule({ id: scheduleToDelete.id, token }));
|
||||
setScheduleToDelete(null);
|
||||
};
|
||||
|
||||
const handleSubmit = async (payload: ScheduleCreatePayload) => {
|
||||
if (!agentId) return;
|
||||
setSubmitting(true);
|
||||
try {
|
||||
if (editing?.id) {
|
||||
await dispatch(
|
||||
updateSchedule({ id: editing.id, payload, token }),
|
||||
).unwrap();
|
||||
} else {
|
||||
await dispatch(createSchedule({ agentId, payload, token })).unwrap();
|
||||
}
|
||||
setModalOpen(false);
|
||||
setEditing(null);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const agentEditPath =
|
||||
agent?.agent_type === 'workflow'
|
||||
? `/agents/workflow/edit/${agentId}`
|
||||
: `/agents/edit/${agentId}`;
|
||||
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<AgentPageHeader
|
||||
agentId={agentId}
|
||||
agentName={agent?.name}
|
||||
agentEditPath={agentEditPath}
|
||||
currentPage="schedules"
|
||||
className="px-4"
|
||||
/>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-foreground m-0 text-[32px] font-bold md:text-[40px] dark:text-white">
|
||||
{t('agents.schedules.title')}
|
||||
</h1>
|
||||
</div>
|
||||
<div className="mt-6 flex flex-col gap-3 px-4">
|
||||
{agent && (
|
||||
<div className="flex flex-col gap-1">
|
||||
<p className="text-foreground">{agent.name}</p>
|
||||
<p className="text-muted-foreground text-xs">
|
||||
{agent.last_used_at
|
||||
? t('agents.logs.lastUsedAt') +
|
||||
' ' +
|
||||
new Date(agent.last_used_at).toLocaleString()
|
||||
: t('agents.logs.noUsageHistory')}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{loadingAgent ? (
|
||||
<div className="flex h-[55vh] w-full items-center justify-center">
|
||||
<Spinner />
|
||||
</div>
|
||||
) : (
|
||||
agent && (
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<header className="flex items-center justify-between">
|
||||
<h2 className="text-lg font-semibold">
|
||||
{t('agents.schedules.heading')}
|
||||
</h2>
|
||||
<button
|
||||
type="button"
|
||||
onClick={openCreate}
|
||||
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-3 py-1 text-sm"
|
||||
>
|
||||
{t('agents.schedules.newRecurring')}
|
||||
</button>
|
||||
</header>
|
||||
<section>
|
||||
<h3 className="text-muted-foreground mb-2 text-sm font-semibold uppercase">
|
||||
{t('agents.schedules.recurring')} ({recurring.length})
|
||||
</h3>
|
||||
{recurring.length === 0 ? (
|
||||
<p className="text-muted-foreground text-sm">
|
||||
{t('agents.schedules.noRecurring')}
|
||||
</p>
|
||||
) : (
|
||||
<ul className="flex flex-col gap-3">
|
||||
{recurring.map((schedule) => (
|
||||
<li
|
||||
key={schedule.id}
|
||||
className="border-border bg-card rounded-lg border p-3"
|
||||
>
|
||||
<div className="flex items-start justify-between">
|
||||
<div>
|
||||
<p className="font-semibold">
|
||||
{schedule.name || schedule.instruction.slice(0, 80)}
|
||||
</p>
|
||||
<p className="text-muted-foreground text-xs">
|
||||
{formatCron(schedule.cron)} · tz:{' '}
|
||||
{schedule.timezone} · status: {schedule.status} ·
|
||||
next: {formatTimestamp(schedule.next_run_at)}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => openEdit(schedule)}
|
||||
className="h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{t('agents.schedules.edit')}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setSchedulePaused({
|
||||
id: schedule.id,
|
||||
action:
|
||||
schedule.status === 'active'
|
||||
? 'pause'
|
||||
: 'resume',
|
||||
token,
|
||||
}),
|
||||
)
|
||||
}
|
||||
className="h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{schedule.status === 'active'
|
||||
? t('agents.schedules.pause')
|
||||
: t('agents.schedules.resume')}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
runScheduleNow({ id: schedule.id, token }),
|
||||
)
|
||||
}
|
||||
className="h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{t('agents.schedules.runNow')}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => requestDelete(schedule)}
|
||||
className="text-destructive hover:bg-destructive/10 hover:text-destructive h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{t('agents.schedules.delete')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setExpanded(
|
||||
expanded === schedule.id ? null : schedule.id,
|
||||
)
|
||||
}
|
||||
className="text-primary mt-2 text-xs underline"
|
||||
>
|
||||
{expanded === schedule.id
|
||||
? t('agents.schedules.hideRuns')
|
||||
: t('agents.schedules.showRuns')}
|
||||
</button>
|
||||
{expanded === schedule.id && (
|
||||
<div className="mt-2">
|
||||
<RunLog
|
||||
scheduleId={schedule.id}
|
||||
onSelect={(run) => setActiveRun(run)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</section>
|
||||
<section>
|
||||
<h3 className="text-muted-foreground mb-2 text-sm font-semibold uppercase">
|
||||
{t('agents.schedules.oneTime')} ({oneTime.length})
|
||||
</h3>
|
||||
{oneTime.length === 0 ? (
|
||||
<p className="text-muted-foreground text-sm">
|
||||
{t('agents.schedules.noOneTime')}
|
||||
</p>
|
||||
) : (
|
||||
<ul className="flex flex-col gap-2">
|
||||
{oneTime.map((schedule) => (
|
||||
<li
|
||||
key={schedule.id}
|
||||
className="border-border bg-card rounded-lg border p-3 text-sm"
|
||||
>
|
||||
<div className="flex items-start justify-between">
|
||||
<div>
|
||||
<p className="font-semibold">
|
||||
{schedule.name || schedule.instruction.slice(0, 80)}
|
||||
</p>
|
||||
<p className="text-muted-foreground text-xs">
|
||||
runs at {formatTimestamp(schedule.run_at)} · status:{' '}
|
||||
{schedule.status}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-1">
|
||||
{schedule.status === 'active' && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => openEdit(schedule)}
|
||||
className="h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{t('agents.schedules.edit')}
|
||||
</Button>
|
||||
)}
|
||||
{schedule.status === 'active' && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => requestDelete(schedule)}
|
||||
className="text-destructive hover:bg-destructive/10 hover:text-destructive h-auto px-2 py-1 text-xs"
|
||||
>
|
||||
{t('agents.schedules.cancel')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</section>
|
||||
<RunDetailDrawer
|
||||
run={activeRun}
|
||||
onClose={() => setActiveRun(null)}
|
||||
/>
|
||||
{modalOpen && (
|
||||
<ScheduleFormModal
|
||||
key={editing?.id ?? 'create'}
|
||||
open={modalOpen}
|
||||
initial={editing}
|
||||
agentToolIds={agentToolIds}
|
||||
onClose={closeModal}
|
||||
onSubmit={handleSubmit}
|
||||
submitting={submitting}
|
||||
/>
|
||||
)}
|
||||
<ConfirmationModal
|
||||
message={t('agents.schedules.deleteConfirm')}
|
||||
modalState={deleteConfirmation}
|
||||
setModalState={setDeleteConfirmation}
|
||||
submitLabel={t('agents.schedules.delete')}
|
||||
handleSubmit={confirmDelete}
|
||||
handleCancel={() => setScheduleToDelete(null)}
|
||||
variant="danger"
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
63
frontend/src/agents/schedules/TimezoneCombobox.test.ts
Normal file
63
frontend/src/agents/schedules/TimezoneCombobox.test.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { getTimezoneOffsetLabel, matchesTimezone } from './TimezoneCombobox';
|
||||
|
||||
describe('matchesTimezone', () => {
|
||||
it('matches case-insensitively by substring', () => {
|
||||
expect(matchesTimezone('Europe/Warsaw', 'war')).toBe(true);
|
||||
expect(matchesTimezone('Europe/Warsaw', 'WARSAW')).toBe(true);
|
||||
});
|
||||
|
||||
it('treats path separators as spaces so "europe war" matches', () => {
|
||||
expect(matchesTimezone('Europe/Warsaw', 'europe war')).toBe(true);
|
||||
});
|
||||
|
||||
it('treats underscores as spaces so "los angeles" matches "America/Los_Angeles"', () => {
|
||||
expect(matchesTimezone('America/Los_Angeles', 'los angeles')).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects non-matching queries', () => {
|
||||
expect(matchesTimezone('Europe/Warsaw', 'tokyo')).toBe(false);
|
||||
expect(matchesTimezone('Asia/Tokyo', 'warsaw')).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true for an empty query (no filter)', () => {
|
||||
expect(matchesTimezone('Anywhere', '')).toBe(true);
|
||||
expect(matchesTimezone('Anywhere', ' ')).toBe(true);
|
||||
});
|
||||
|
||||
it('requires all tokens to match (AND semantics)', () => {
|
||||
expect(matchesTimezone('Europe/Warsaw', 'europe tokyo')).toBe(false);
|
||||
expect(matchesTimezone('America/New_York', 'new york')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTimezoneOffsetLabel', () => {
|
||||
it('returns a UTC+ string for Europe/Warsaw (DST-dependent, so just check prefix)', () => {
|
||||
const label = getTimezoneOffsetLabel('Europe/Warsaw');
|
||||
expect(label.startsWith('UTC+')).toBe(true);
|
||||
});
|
||||
|
||||
it('renders the half-hour offset for Asia/Kolkata', () => {
|
||||
expect(getTimezoneOffsetLabel('Asia/Kolkata')).toContain('5:30');
|
||||
});
|
||||
|
||||
it('returns exactly "UTC" for the UTC zone (no +0 suffix)', () => {
|
||||
expect(getTimezoneOffsetLabel('UTC')).toBe('UTC');
|
||||
});
|
||||
|
||||
it('returns a UTC- string for America/Los_Angeles (always west of UTC)', () => {
|
||||
const label = getTimezoneOffsetLabel('America/Los_Angeles');
|
||||
expect(label.startsWith('UTC-')).toBe(true);
|
||||
});
|
||||
|
||||
it('is stable: repeat calls return the same value (cache hit)', () => {
|
||||
const first = getTimezoneOffsetLabel('Asia/Kolkata');
|
||||
const second = getTimezoneOffsetLabel('Asia/Kolkata');
|
||||
expect(second).toBe(first);
|
||||
});
|
||||
|
||||
it('degrades gracefully on an invalid timezone (returns input)', () => {
|
||||
expect(getTimezoneOffsetLabel('Not/A_Zone_xyz')).toBe('Not/A_Zone_xyz');
|
||||
});
|
||||
});
|
||||
199
frontend/src/agents/schedules/TimezoneCombobox.tsx
Normal file
199
frontend/src/agents/schedules/TimezoneCombobox.tsx
Normal file
@@ -0,0 +1,199 @@
|
||||
import { Check, ChevronsUpDown } from 'lucide-react';
|
||||
import { useMemo, useState } from 'react';
|
||||
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from '@/components/ui/command';
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/components/ui/popover';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
export type TimezoneComboboxProps = {
|
||||
value: string;
|
||||
options: string[];
|
||||
onChange: (next: string) => void;
|
||||
placeholder?: string;
|
||||
searchPlaceholder?: string;
|
||||
emptyText?: string;
|
||||
ariaLabel?: string;
|
||||
className?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Case-insensitive substring match against the tz string with separators
|
||||
* normalized to spaces — so typing "warsaw", "Warsaw", or "europe war" all
|
||||
* match ``Europe/Warsaw``.
|
||||
*/
|
||||
export function matchesTimezone(option: string, query: string): boolean {
|
||||
const q = query.trim().toLowerCase();
|
||||
if (!q) return true;
|
||||
const haystack = option.toLowerCase().replace(/[/_]/g, ' ');
|
||||
return q
|
||||
.split(/\s+/)
|
||||
.filter(Boolean)
|
||||
.every((token) => haystack.includes(token));
|
||||
}
|
||||
|
||||
// Process-lifetime cache. Offsets are DST-dependent so they're correct for
|
||||
// "now" — matches what users see in Google Calendar et al. We trade a stale
|
||||
// offset across a DST boundary mid-session for a much cheaper render path.
|
||||
const offsetCache = new Map<string, string>();
|
||||
|
||||
/**
|
||||
* Current UTC offset for an IANA timezone, e.g. ``UTC+1``, ``UTC+5:30``,
|
||||
* ``UTC-3:30``, or just ``UTC`` for GMT. Returns the raw input on invalid
|
||||
* timezones so the UI degrades gracefully. Memoized for the process lifetime.
|
||||
*/
|
||||
export function getTimezoneOffsetLabel(tz: string): string {
|
||||
const cached = offsetCache.get(tz);
|
||||
if (cached !== undefined) return cached;
|
||||
const label = computeTimezoneOffsetLabel(tz);
|
||||
offsetCache.set(tz, label);
|
||||
return label;
|
||||
}
|
||||
|
||||
function computeTimezoneOffsetLabel(tz: string): string {
|
||||
let raw: string | undefined;
|
||||
try {
|
||||
const fmt = new Intl.DateTimeFormat('en', {
|
||||
timeZone: tz,
|
||||
timeZoneName: 'shortOffset',
|
||||
});
|
||||
raw = fmt
|
||||
.formatToParts(new Date())
|
||||
.find((p) => p.type === 'timeZoneName')?.value;
|
||||
} catch {
|
||||
return tz;
|
||||
}
|
||||
if (!raw) return 'UTC';
|
||||
// Normalize ``GMT+1`` → ``UTC+1``, ``GMT-05:30`` → ``UTC-5:30``,
|
||||
// and the bare ``GMT`` (UTC zone) → ``UTC``.
|
||||
const normalized = raw.replace(/^GMT/, 'UTC');
|
||||
if (
|
||||
normalized === 'UTC' ||
|
||||
normalized === 'UTC+0' ||
|
||||
normalized === 'UTC-0'
|
||||
) {
|
||||
return 'UTC';
|
||||
}
|
||||
// Strip leading zero in the hour part: ``UTC+05:30`` → ``UTC+5:30``.
|
||||
return normalized.replace(
|
||||
/^UTC([+-])0?(\d+)(?::(\d{2}))?$/,
|
||||
(_, sign, h, m) => (m ? `UTC${sign}${h}:${m}` : `UTC${sign}${h}`),
|
||||
);
|
||||
}
|
||||
|
||||
/** Searchable IANA timezone picker (Popover + Command). */
|
||||
export default function TimezoneCombobox({
|
||||
value,
|
||||
options,
|
||||
onChange,
|
||||
placeholder = 'Select timezone',
|
||||
searchPlaceholder = 'Search timezone…',
|
||||
emptyText = 'No timezone found.',
|
||||
ariaLabel,
|
||||
className,
|
||||
}: TimezoneComboboxProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [query, setQuery] = useState('');
|
||||
|
||||
// Precompute (tz, offset) once per options array — ~400 zones is fast but
|
||||
// not free, and we re-render on every keystroke during filtering.
|
||||
const optionsWithOffset = useMemo(
|
||||
() => options.map((tz) => ({ tz, offset: getTimezoneOffsetLabel(tz) })),
|
||||
[options],
|
||||
);
|
||||
|
||||
const filtered = useMemo(
|
||||
() => optionsWithOffset.filter(({ tz }) => matchesTimezone(tz, query)),
|
||||
[optionsWithOffset, query],
|
||||
);
|
||||
|
||||
const selectedOffset = value ? getTimezoneOffsetLabel(value) : '';
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
size="sm"
|
||||
aria-expanded={open}
|
||||
aria-label={ariaLabel}
|
||||
className={cn(
|
||||
'h-9 w-full justify-between gap-2 px-3 font-normal',
|
||||
!value && 'text-muted-foreground',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{value ? (
|
||||
<span className="flex min-w-0 flex-1 items-center justify-between gap-3">
|
||||
<span className="truncate">{value}</span>
|
||||
<span className="text-muted-foreground shrink-0 text-xs">
|
||||
{selectedOffset}
|
||||
</span>
|
||||
</span>
|
||||
) : (
|
||||
<span className="truncate">{placeholder}</span>
|
||||
)}
|
||||
<ChevronsUpDown className="size-4 shrink-0 opacity-50" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
{/* z-200 keeps the popover above WrapperModal (z-100); matches DatePicker. */}
|
||||
<PopoverContent
|
||||
className="z-200 w-[min(20rem,calc(100vw-2rem))] p-0"
|
||||
align="start"
|
||||
>
|
||||
<Command shouldFilter={false}>
|
||||
<CommandInput
|
||||
placeholder={searchPlaceholder}
|
||||
value={query}
|
||||
onValueChange={setQuery}
|
||||
/>
|
||||
<CommandList>
|
||||
<CommandEmpty>{emptyText}</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{filtered.map(({ tz, offset }) => {
|
||||
const selected = tz === value;
|
||||
return (
|
||||
<CommandItem
|
||||
key={tz}
|
||||
value={tz}
|
||||
onSelect={() => {
|
||||
onChange(tz);
|
||||
setOpen(false);
|
||||
setQuery('');
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
'mr-2 size-4 shrink-0',
|
||||
selected ? 'opacity-100' : 'opacity-0',
|
||||
)}
|
||||
/>
|
||||
<div className="flex w-full min-w-0 items-center justify-between gap-3">
|
||||
<span className="truncate">{tz}</span>
|
||||
<span className="text-muted-foreground shrink-0 text-xs">
|
||||
{offset}
|
||||
</span>
|
||||
</div>
|
||||
</CommandItem>
|
||||
);
|
||||
})}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
325
frontend/src/agents/schedules/cronBuilder.test.ts
Normal file
325
frontend/src/agents/schedules/cronBuilder.test.ts
Normal file
@@ -0,0 +1,325 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import type { Schedule } from '../types/schedule';
|
||||
import {
|
||||
browserTimezone,
|
||||
buildCron,
|
||||
buildRunAtUtc,
|
||||
formatCron,
|
||||
parseCron,
|
||||
parseScheduleToFormValues,
|
||||
parseTime,
|
||||
supportedTimezones,
|
||||
} from './cronBuilder';
|
||||
|
||||
const baseValues = {
|
||||
time: '09:00',
|
||||
dayOfWeek: 1,
|
||||
dayOfMonth: 1,
|
||||
month: 1,
|
||||
};
|
||||
|
||||
describe('buildCron', () => {
|
||||
it('Daily 22:30 → "30 22 * * *"', () => {
|
||||
expect(buildCron('daily', { ...baseValues, time: '22:30' })).toBe(
|
||||
'30 22 * * *',
|
||||
);
|
||||
});
|
||||
|
||||
it('Weekly Mon 09:00 → "0 9 * * 1"', () => {
|
||||
expect(
|
||||
buildCron('weekly', { ...baseValues, time: '09:00', dayOfWeek: 1 }),
|
||||
).toBe('0 9 * * 1');
|
||||
});
|
||||
|
||||
it('Monthly day-15 10:00 → "0 10 15 * *"', () => {
|
||||
expect(
|
||||
buildCron('monthly', { ...baseValues, time: '10:00', dayOfMonth: 15 }),
|
||||
).toBe('0 10 15 * *');
|
||||
});
|
||||
|
||||
it('Yearly March 15 08:00 → "0 8 15 3 *"', () => {
|
||||
expect(
|
||||
buildCron('yearly', {
|
||||
...baseValues,
|
||||
time: '08:00',
|
||||
dayOfMonth: 15,
|
||||
month: 3,
|
||||
}),
|
||||
).toBe('0 8 15 3 *');
|
||||
});
|
||||
|
||||
it('Once returns null cron', () => {
|
||||
expect(buildCron('once', baseValues)).toBeNull();
|
||||
});
|
||||
|
||||
it('clamps out-of-range time inputs', () => {
|
||||
expect(buildCron('daily', { ...baseValues, time: '99:99' })).toBe(
|
||||
'59 23 * * *',
|
||||
);
|
||||
});
|
||||
|
||||
it('clamps day-of-month and month for yearly', () => {
|
||||
expect(
|
||||
buildCron('yearly', {
|
||||
...baseValues,
|
||||
time: '00:00',
|
||||
dayOfMonth: 99,
|
||||
month: 0,
|
||||
}),
|
||||
).toBe('0 0 31 1 *');
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseTime', () => {
|
||||
it('parses "HH:MM"', () => {
|
||||
expect(parseTime('07:05')).toEqual({ hour: 7, minute: 5 });
|
||||
});
|
||||
|
||||
it('falls back to 09:00 on bad input', () => {
|
||||
expect(parseTime('garbage')).toEqual({ hour: 9, minute: 0 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildRunAtUtc', () => {
|
||||
it('UTC noon → UTC noon (no offset)', () => {
|
||||
expect(buildRunAtUtc('2026-06-15', '12:00', 'UTC')).toBe(
|
||||
'2026-06-15T12:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('Europe/Warsaw 12:00 in summer (CEST, UTC+2) → 10:00Z', () => {
|
||||
expect(buildRunAtUtc('2026-06-15', '12:00', 'Europe/Warsaw')).toBe(
|
||||
'2026-06-15T10:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('Europe/Warsaw 12:00 in winter (CET, UTC+1) → 11:00Z', () => {
|
||||
expect(buildRunAtUtc('2026-12-15', '12:00', 'Europe/Warsaw')).toBe(
|
||||
'2026-12-15T11:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('America/Los_Angeles 09:00 in summer (PDT, UTC-7) → 16:00Z', () => {
|
||||
expect(buildRunAtUtc('2026-07-04', '09:00', 'America/Los_Angeles')).toBe(
|
||||
'2026-07-04T16:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('Asia/Tokyo 09:00 (JST, UTC+9, no DST) → 00:00Z', () => {
|
||||
// System tz is irrelevant here — the helper must honour the tz parameter.
|
||||
expect(buildRunAtUtc('2026-06-15', '09:00', 'Asia/Tokyo')).toBe(
|
||||
'2026-06-15T00:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('Australia/Sydney 09:00 in Jul winter (AEST, UTC+10) → 23:00Z prev day', () => {
|
||||
expect(buildRunAtUtc('2026-07-15', '09:00', 'Australia/Sydney')).toBe(
|
||||
'2026-07-14T23:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('honours an arbitrary picked tz independently of the system tz', () => {
|
||||
// Picking Asia/Singapore (UTC+8, no DST) maps 10:00 local → 02:00Z.
|
||||
expect(buildRunAtUtc('2026-03-10', '10:00', 'Asia/Singapore')).toBe(
|
||||
'2026-03-10T02:00:00.000Z',
|
||||
);
|
||||
});
|
||||
|
||||
it('throws on invalid date', () => {
|
||||
expect(() => buildRunAtUtc('not-a-date', '12:00', 'UTC')).toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('supportedTimezones', () => {
|
||||
it('returns a non-empty list of IANA tz strings', () => {
|
||||
const list = supportedTimezones();
|
||||
expect(Array.isArray(list)).toBe(true);
|
||||
expect(list.length).toBeGreaterThan(0);
|
||||
// Common zones every fallback / modern engine should expose.
|
||||
expect(list).toContain('UTC');
|
||||
});
|
||||
|
||||
it('includes the browser zone when modern Intl is available', () => {
|
||||
// happy-dom exposes Intl.supportedValuesOf; this is the modern path.
|
||||
const list = supportedTimezones();
|
||||
expect(list).toContain('Europe/Warsaw');
|
||||
expect(list).toContain('Asia/Tokyo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseCron', () => {
|
||||
it('round-trips daily cron', () => {
|
||||
expect(parseCron('30 22 * * *')).toMatchObject({
|
||||
frequency: 'daily',
|
||||
minute: 30,
|
||||
hour: 22,
|
||||
});
|
||||
});
|
||||
|
||||
it('round-trips weekly cron', () => {
|
||||
expect(parseCron('0 9 * * 1')).toMatchObject({
|
||||
frequency: 'weekly',
|
||||
minute: 0,
|
||||
hour: 9,
|
||||
dow: 1,
|
||||
});
|
||||
});
|
||||
|
||||
it('round-trips monthly cron', () => {
|
||||
expect(parseCron('0 10 15 * *')).toMatchObject({
|
||||
frequency: 'monthly',
|
||||
minute: 0,
|
||||
hour: 10,
|
||||
dom: 15,
|
||||
});
|
||||
});
|
||||
|
||||
it('round-trips yearly cron', () => {
|
||||
expect(parseCron('0 8 15 3 *')).toMatchObject({
|
||||
frequency: 'yearly',
|
||||
minute: 0,
|
||||
hour: 8,
|
||||
dom: 15,
|
||||
mon: 3,
|
||||
});
|
||||
});
|
||||
|
||||
it('returns null for unsupported shapes (weekday range)', () => {
|
||||
expect(parseCron('0 9 * * 1-5')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for non-5-field input', () => {
|
||||
expect(parseCron('* * *')).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('browserTimezone', () => {
|
||||
it('returns a non-empty IANA-looking string', () => {
|
||||
const tz = browserTimezone();
|
||||
expect(typeof tz).toBe('string');
|
||||
expect(tz.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseScheduleToFormValues', () => {
|
||||
const makeSchedule = (overrides: Partial<Schedule>): Schedule => ({
|
||||
id: 's',
|
||||
user_id: 'u',
|
||||
agent_id: 'a',
|
||||
trigger_type: 'recurring',
|
||||
instruction: 'do thing',
|
||||
status: 'active',
|
||||
timezone: 'UTC',
|
||||
tool_allowlist: [],
|
||||
created_via: 'ui',
|
||||
consecutive_failure_count: 0,
|
||||
created_at: '2026-05-19T12:00:00Z',
|
||||
updated_at: '2026-05-19T12:00:00Z',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
it('reconstructs weekly from a cron schedule', () => {
|
||||
const s = makeSchedule({ cron: '0 9 * * 1' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(v.frequency).toBe('weekly');
|
||||
expect(v.time).toBe('09:00');
|
||||
expect(v.dayOfWeek).toBe(1);
|
||||
});
|
||||
|
||||
it('reconstructs once from run_at', () => {
|
||||
const s = makeSchedule({
|
||||
trigger_type: 'once',
|
||||
cron: null,
|
||||
run_at: '2026-06-15T12:00:00Z',
|
||||
});
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(v.frequency).toBe('once');
|
||||
expect(v.date).toBe('2026-06-15');
|
||||
expect(v.time).toBe('12:00');
|
||||
});
|
||||
|
||||
it('falls back to daily 09:00 when cron is unrecognized', () => {
|
||||
const s = makeSchedule({ cron: '0 9 * * 1-5' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(v.frequency).toBe('daily');
|
||||
});
|
||||
|
||||
it('round-trips weekly cron → form values → cron', () => {
|
||||
const s = makeSchedule({ cron: '30 14 * * 4' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(buildCron(v.frequency, v)).toBe('30 14 * * 4');
|
||||
});
|
||||
|
||||
it('round-trips monthly cron → form values → cron', () => {
|
||||
const s = makeSchedule({ cron: '5 7 22 * *' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(buildCron(v.frequency, v)).toBe('5 7 22 * *');
|
||||
});
|
||||
|
||||
it('round-trips yearly cron → form values → cron', () => {
|
||||
const s = makeSchedule({ cron: '0 8 15 3 *' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(buildCron(v.frequency, v)).toBe('0 8 15 3 *');
|
||||
});
|
||||
|
||||
it('round-trips daily cron → form values → cron', () => {
|
||||
const s = makeSchedule({ cron: '45 6 * * *' });
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(buildCron(v.frequency, v)).toBe('45 6 * * *');
|
||||
});
|
||||
|
||||
it('round-trips once schedule → run_at preserved through tz', () => {
|
||||
const runAt = '2026-06-15T12:00:00.000Z';
|
||||
const s = makeSchedule({
|
||||
trigger_type: 'once',
|
||||
cron: null,
|
||||
run_at: runAt,
|
||||
});
|
||||
const v = parseScheduleToFormValues(s, 'UTC');
|
||||
expect(buildRunAtUtc(v.date, v.time, 'UTC')).toBe(runAt);
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatCron', () => {
|
||||
it('renders daily 9:00 AM', () => {
|
||||
expect(formatCron('0 9 * * *')).toBe('Daily at 9:00 AM');
|
||||
});
|
||||
|
||||
it('renders daily 10:30 PM', () => {
|
||||
expect(formatCron('30 22 * * *')).toBe('Daily at 10:30 PM');
|
||||
});
|
||||
|
||||
it('renders weekly Monday 9:00 AM', () => {
|
||||
expect(formatCron('0 9 * * 1')).toBe('Weekly on Monday at 9:00 AM');
|
||||
});
|
||||
|
||||
it('renders weekly Sunday', () => {
|
||||
expect(formatCron('15 7 * * 0')).toBe('Weekly on Sunday at 7:15 AM');
|
||||
});
|
||||
|
||||
it('renders monthly day 15 at 10:00 AM', () => {
|
||||
expect(formatCron('0 10 15 * *')).toBe('Monthly on day 15 at 10:00 AM');
|
||||
});
|
||||
|
||||
it('renders yearly March 15 8:00 AM', () => {
|
||||
expect(formatCron('0 8 15 3 *')).toBe('Yearly on March 15 at 8:00 AM');
|
||||
});
|
||||
|
||||
it('renders midnight as 12:00 AM', () => {
|
||||
expect(formatCron('0 0 * * *')).toBe('Daily at 12:00 AM');
|
||||
});
|
||||
|
||||
it('renders noon as 12:00 PM', () => {
|
||||
expect(formatCron('0 12 * * *')).toBe('Daily at 12:00 PM');
|
||||
});
|
||||
|
||||
it('falls back to custom for unsupported cron', () => {
|
||||
expect(formatCron('0 9 * * 1-5')).toBe('Custom: 0 9 * * 1-5');
|
||||
});
|
||||
|
||||
it('returns empty string for null/undefined', () => {
|
||||
expect(formatCron(null)).toBe('');
|
||||
expect(formatCron(undefined)).toBe('');
|
||||
});
|
||||
});
|
||||
334
frontend/src/agents/schedules/cronBuilder.ts
Normal file
334
frontend/src/agents/schedules/cronBuilder.ts
Normal file
@@ -0,0 +1,334 @@
|
||||
import type { Schedule } from '../types/schedule';
|
||||
|
||||
export type ScheduleFrequency =
|
||||
| 'once'
|
||||
| 'daily'
|
||||
| 'weekly'
|
||||
| 'monthly'
|
||||
| 'yearly';
|
||||
|
||||
export type ScheduleFormValues = {
|
||||
frequency: ScheduleFrequency;
|
||||
date: string; // YYYY-MM-DD (used by 'once')
|
||||
time: string; // HH:MM (24h)
|
||||
dayOfWeek: number; // 0=Sun … 6=Sat (used by 'weekly')
|
||||
dayOfMonth: number; // 1..31 (used by 'monthly' / 'yearly')
|
||||
month: number; // 1..12 (used by 'yearly')
|
||||
};
|
||||
|
||||
const clamp = (n: number, lo: number, hi: number): number =>
|
||||
Math.max(lo, Math.min(hi, Math.floor(n)));
|
||||
|
||||
const pad2 = (n: number): string => String(n).padStart(2, '0');
|
||||
|
||||
/** Parse "HH:MM" into [hour, minute]; defaults on bad input. */
|
||||
export function parseTime(time: string): { hour: number; minute: number } {
|
||||
const m = /^(\d{1,2}):(\d{1,2})$/.exec(time?.trim() ?? '');
|
||||
if (!m) return { hour: 9, minute: 0 };
|
||||
return {
|
||||
hour: clamp(Number(m[1]), 0, 23),
|
||||
minute: clamp(Number(m[2]), 0, 59),
|
||||
};
|
||||
}
|
||||
|
||||
/** Detect the browser's IANA timezone (e.g. ``Europe/Warsaw``). */
|
||||
export function browserTimezone(): string {
|
||||
try {
|
||||
const tz = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
return tz || 'UTC';
|
||||
} catch {
|
||||
return 'UTC';
|
||||
}
|
||||
}
|
||||
|
||||
// Minimal fallback list for engines without ``Intl.supportedValuesOf``.
|
||||
const FALLBACK_TIMEZONES: readonly string[] = [
|
||||
'UTC',
|
||||
'Europe/London',
|
||||
'Europe/Berlin',
|
||||
'Europe/Warsaw',
|
||||
'Europe/Moscow',
|
||||
'America/New_York',
|
||||
'America/Chicago',
|
||||
'America/Denver',
|
||||
'America/Los_Angeles',
|
||||
'America/Sao_Paulo',
|
||||
'Asia/Dubai',
|
||||
'Asia/Kolkata',
|
||||
'Asia/Singapore',
|
||||
'Asia/Tokyo',
|
||||
'Australia/Sydney',
|
||||
'Pacific/Auckland',
|
||||
] as const;
|
||||
|
||||
/** Full IANA timezone list via ``Intl.supportedValuesOf``; falls back for older engines. */
|
||||
export function supportedTimezones(): string[] {
|
||||
try {
|
||||
const intlAny = Intl as unknown as {
|
||||
supportedValuesOf?: (key: 'timeZone') => string[];
|
||||
};
|
||||
if (typeof intlAny.supportedValuesOf === 'function') {
|
||||
const values = intlAny.supportedValuesOf('timeZone');
|
||||
if (Array.isArray(values) && values.length > 0) {
|
||||
// ``supportedValuesOf`` omits the ``UTC`` alias on most engines; ensure it
|
||||
// is always pickable as it's the universal default.
|
||||
return values.includes('UTC') ? values : ['UTC', ...values];
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// fall through to the fallback list
|
||||
}
|
||||
return [...FALLBACK_TIMEZONES];
|
||||
}
|
||||
|
||||
/** Build a 5-field cron expression for recurring frequencies; ``null`` for 'once'. */
|
||||
export function buildCron(
|
||||
frequency: ScheduleFrequency,
|
||||
values: Pick<
|
||||
ScheduleFormValues,
|
||||
'time' | 'dayOfWeek' | 'dayOfMonth' | 'month'
|
||||
>,
|
||||
): string | null {
|
||||
if (frequency === 'once') return null;
|
||||
const { hour, minute } = parseTime(values.time);
|
||||
switch (frequency) {
|
||||
case 'daily':
|
||||
return `${minute} ${hour} * * *`;
|
||||
case 'weekly':
|
||||
return `${minute} ${hour} * * ${clamp(values.dayOfWeek, 0, 6)}`;
|
||||
case 'monthly':
|
||||
return `${minute} ${hour} ${clamp(values.dayOfMonth, 1, 31)} * *`;
|
||||
case 'yearly':
|
||||
return `${minute} ${hour} ${clamp(values.dayOfMonth, 1, 31)} ${clamp(values.month, 1, 12)} *`;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Convert a local date/time + IANA tz to a UTC ISO 8601 string. */
|
||||
export function buildRunAtUtc(
|
||||
date: string,
|
||||
time: string,
|
||||
timezone: string,
|
||||
): string {
|
||||
const { hour, minute } = parseTime(time);
|
||||
const dm = /^(\d{4})-(\d{1,2})-(\d{1,2})$/.exec(date?.trim() ?? '');
|
||||
if (!dm) throw new Error('invalid date');
|
||||
const year = Number(dm[1]);
|
||||
const month = clamp(Number(dm[2]), 1, 12);
|
||||
const day = clamp(Number(dm[3]), 1, 31);
|
||||
// Compute UTC offset of the chosen tz at the chosen wall-clock instant by
|
||||
// formatting an interim UTC date and reading back the tz parts.
|
||||
const guess = Date.UTC(year, month - 1, day, hour, minute, 0);
|
||||
const parts = formatInTimeZone(guess, timezone);
|
||||
const wallUtc = Date.UTC(
|
||||
parts.year,
|
||||
parts.month - 1,
|
||||
parts.day,
|
||||
parts.hour,
|
||||
parts.minute,
|
||||
0,
|
||||
);
|
||||
const offset = wallUtc - guess;
|
||||
return new Date(guess - offset).toISOString();
|
||||
}
|
||||
|
||||
type TzParts = {
|
||||
year: number;
|
||||
month: number;
|
||||
day: number;
|
||||
hour: number;
|
||||
minute: number;
|
||||
};
|
||||
|
||||
const formatInTimeZone = (utcMs: number, timezone: string): TzParts => {
|
||||
const fmt = new Intl.DateTimeFormat('en-US', {
|
||||
timeZone: timezone,
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
hour12: false,
|
||||
});
|
||||
const map: Record<string, string> = {};
|
||||
for (const p of fmt.formatToParts(new Date(utcMs))) {
|
||||
if (p.type !== 'literal') map[p.type] = p.value;
|
||||
}
|
||||
return {
|
||||
year: Number(map.year),
|
||||
month: Number(map.month),
|
||||
day: Number(map.day),
|
||||
// Intl returns "24" at midnight in some engines; normalize to 0.
|
||||
hour: Number(map.hour) % 24,
|
||||
minute: Number(map.minute),
|
||||
};
|
||||
};
|
||||
|
||||
/** Derive form initial values from an existing schedule (edit mode). */
|
||||
export function parseScheduleToFormValues(
|
||||
schedule: Schedule,
|
||||
timezone: string,
|
||||
): ScheduleFormValues {
|
||||
const fallback: ScheduleFormValues = {
|
||||
frequency: 'daily',
|
||||
date: todayDate(timezone),
|
||||
time: '09:00',
|
||||
dayOfWeek: 1,
|
||||
dayOfMonth: 1,
|
||||
month: 1,
|
||||
};
|
||||
if (schedule.trigger_type === 'once' && schedule.run_at) {
|
||||
const parts = formatInTimeZone(
|
||||
new Date(schedule.run_at).getTime(),
|
||||
timezone,
|
||||
);
|
||||
return {
|
||||
...fallback,
|
||||
frequency: 'once',
|
||||
date: `${parts.year}-${pad2(parts.month)}-${pad2(parts.day)}`,
|
||||
time: `${pad2(parts.hour)}:${pad2(parts.minute)}`,
|
||||
};
|
||||
}
|
||||
if (!schedule.cron) return fallback;
|
||||
const parsed = parseCron(schedule.cron);
|
||||
if (!parsed) return fallback;
|
||||
const { frequency, minute, hour, dom, mon, dow } = parsed;
|
||||
return {
|
||||
frequency,
|
||||
date: fallback.date,
|
||||
time: `${pad2(hour)}:${pad2(minute)}`,
|
||||
dayOfWeek: dow ?? 1,
|
||||
dayOfMonth: dom ?? 1,
|
||||
month: mon ?? 1,
|
||||
};
|
||||
}
|
||||
|
||||
type ParsedCron = {
|
||||
frequency: Exclude<ScheduleFrequency, 'once'>;
|
||||
minute: number;
|
||||
hour: number;
|
||||
dom: number | null;
|
||||
mon: number | null;
|
||||
dow: number | null;
|
||||
};
|
||||
|
||||
/** Recognize the cron shapes ``buildCron`` produces; otherwise ``null``. */
|
||||
export function parseCron(expression: string): ParsedCron | null {
|
||||
const parts = expression.trim().split(/\s+/);
|
||||
if (parts.length !== 5) return null;
|
||||
const [mn, hr, dom, mon, dow] = parts;
|
||||
const m = Number(mn);
|
||||
const h = Number(hr);
|
||||
if (!Number.isFinite(m) || !Number.isFinite(h)) return null;
|
||||
// yearly: explicit dom + explicit mon
|
||||
if (dom !== '*' && mon !== '*' && dow === '*') {
|
||||
const d = Number(dom);
|
||||
const mm = Number(mon);
|
||||
if (!Number.isFinite(d) || !Number.isFinite(mm)) return null;
|
||||
return {
|
||||
frequency: 'yearly',
|
||||
minute: m,
|
||||
hour: h,
|
||||
dom: d,
|
||||
mon: mm,
|
||||
dow: null,
|
||||
};
|
||||
}
|
||||
// monthly: explicit dom, * mon, * dow
|
||||
if (dom !== '*' && mon === '*' && dow === '*') {
|
||||
const d = Number(dom);
|
||||
if (!Number.isFinite(d)) return null;
|
||||
return {
|
||||
frequency: 'monthly',
|
||||
minute: m,
|
||||
hour: h,
|
||||
dom: d,
|
||||
mon: null,
|
||||
dow: null,
|
||||
};
|
||||
}
|
||||
// weekly: * dom, * mon, explicit dow (single value)
|
||||
if (dom === '*' && mon === '*' && dow !== '*' && !dow.includes(',')) {
|
||||
const d = Number(dow);
|
||||
if (!Number.isFinite(d)) return null;
|
||||
return {
|
||||
frequency: 'weekly',
|
||||
minute: m,
|
||||
hour: h,
|
||||
dom: null,
|
||||
mon: null,
|
||||
dow: d,
|
||||
};
|
||||
}
|
||||
// daily: * dom, * mon, * dow
|
||||
if (dom === '*' && mon === '*' && dow === '*') {
|
||||
return {
|
||||
frequency: 'daily',
|
||||
minute: m,
|
||||
hour: h,
|
||||
dom: null,
|
||||
mon: null,
|
||||
dow: null,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Today's date in ``YYYY-MM-DD`` for the given IANA timezone. */
|
||||
export function todayDate(timezone: string): string {
|
||||
const p = formatInTimeZone(Date.now(), timezone);
|
||||
return `${p.year}-${pad2(p.month)}-${pad2(p.day)}`;
|
||||
}
|
||||
|
||||
const DAY_NAMES = [
|
||||
'Sunday',
|
||||
'Monday',
|
||||
'Tuesday',
|
||||
'Wednesday',
|
||||
'Thursday',
|
||||
'Friday',
|
||||
'Saturday',
|
||||
];
|
||||
|
||||
const MONTH_NAMES = [
|
||||
'January',
|
||||
'February',
|
||||
'March',
|
||||
'April',
|
||||
'May',
|
||||
'June',
|
||||
'July',
|
||||
'August',
|
||||
'September',
|
||||
'October',
|
||||
'November',
|
||||
'December',
|
||||
];
|
||||
|
||||
const formatTime12h = (hour: number, minute: number): string => {
|
||||
const period = hour >= 12 ? 'PM' : 'AM';
|
||||
const h12 = hour % 12 === 0 ? 12 : hour % 12;
|
||||
return `${h12}:${pad2(minute)} ${period}`;
|
||||
};
|
||||
|
||||
/** Human-readable label for a cron string the form emits; falls back for custom shapes. */
|
||||
export function formatCron(expression?: string | null): string {
|
||||
if (!expression) return '';
|
||||
const parsed = parseCron(expression);
|
||||
if (!parsed) return `Custom: ${expression}`;
|
||||
const { frequency, hour, minute, dom, mon, dow } = parsed;
|
||||
const time = formatTime12h(hour, minute);
|
||||
switch (frequency) {
|
||||
case 'daily':
|
||||
return `Daily at ${time}`;
|
||||
case 'weekly':
|
||||
return `Weekly on ${DAY_NAMES[(dow ?? 0) % 7]} at ${time}`;
|
||||
case 'monthly':
|
||||
return `Monthly on day ${dom} at ${time}`;
|
||||
case 'yearly':
|
||||
return `Yearly on ${MONTH_NAMES[((mon ?? 1) - 1) % 12]} ${dom} at ${time}`;
|
||||
default:
|
||||
return `Custom: ${expression}`;
|
||||
}
|
||||
}
|
||||
222
frontend/src/agents/schedules/schedulesSlice.test.ts
Normal file
222
frontend/src/agents/schedules/schedulesSlice.test.ts
Normal file
@@ -0,0 +1,222 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
sseEventReceived,
|
||||
type SSEEvent,
|
||||
} from '../../notifications/notificationsSlice';
|
||||
import type { Schedule, ScheduleRun } from '../types/schedule';
|
||||
import reducer, {
|
||||
applyEvent,
|
||||
selectRunsForSchedule,
|
||||
selectSchedulesForAgent,
|
||||
type SchedulesState,
|
||||
} from './schedulesSlice';
|
||||
|
||||
const sampleSchedule = (overrides: Partial<Schedule> = {}): Schedule => ({
|
||||
id: 'sched-1',
|
||||
user_id: 'alice',
|
||||
agent_id: 'agent-1',
|
||||
trigger_type: 'recurring',
|
||||
instruction: 'do it',
|
||||
status: 'active',
|
||||
timezone: 'UTC',
|
||||
tool_allowlist: [],
|
||||
created_via: 'ui',
|
||||
consecutive_failure_count: 0,
|
||||
created_at: '2026-05-19T12:00:00Z',
|
||||
updated_at: '2026-05-19T12:00:00Z',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const sampleRun = (overrides: Partial<ScheduleRun> = {}): ScheduleRun => ({
|
||||
id: 'run-1',
|
||||
schedule_id: 'sched-1',
|
||||
user_id: 'alice',
|
||||
agent_id: 'agent-1',
|
||||
status: 'pending',
|
||||
scheduled_for: '2026-05-19T12:00:00Z',
|
||||
trigger_source: 'cron',
|
||||
output_truncated: false,
|
||||
prompt_tokens: 0,
|
||||
generated_tokens: 0,
|
||||
created_at: '2026-05-19T12:00:00Z',
|
||||
updated_at: '2026-05-19T12:00:00Z',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedState = () => reducer(undefined, { type: '@@INIT' });
|
||||
|
||||
const seedWithSchedule = (): SchedulesState => {
|
||||
let state = seedState();
|
||||
state = reducer(
|
||||
state,
|
||||
applyEvent({ type: 'noop', scheduleId: 'sched-1', run: sampleRun() }),
|
||||
);
|
||||
return {
|
||||
...state,
|
||||
byAgent: { 'agent-1': [sampleSchedule()] } as Record<string, Schedule[]>,
|
||||
};
|
||||
};
|
||||
|
||||
describe('schedulesSlice SSE event handling', () => {
|
||||
it('schedule.run.completed upserts run + bumps last_run_at', () => {
|
||||
let state = seedWithSchedule();
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-1',
|
||||
ts: '2026-05-19T12:05:00Z',
|
||||
type: 'schedule.run.completed',
|
||||
payload: {
|
||||
run_id: 'run-1',
|
||||
schedule_id: 'sched-1',
|
||||
status: 'success',
|
||||
},
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const runs = selectRunsForSchedule({ schedules: state }, 'sched-1');
|
||||
expect(runs[0].status).toBe('success');
|
||||
const schedules = selectSchedulesForAgent({ schedules: state }, 'agent-1');
|
||||
expect(schedules[0].last_run_at).toBe('2026-05-19T12:05:00Z');
|
||||
});
|
||||
|
||||
it('schedule.run.failed marks the run as failed and carries error_type', () => {
|
||||
let state = seedWithSchedule();
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-2',
|
||||
ts: '2026-05-19T12:06:00Z',
|
||||
type: 'schedule.run.failed',
|
||||
payload: {
|
||||
run_id: 'run-1',
|
||||
schedule_id: 'sched-1',
|
||||
error_type: 'agent_error',
|
||||
error: 'LLM exploded',
|
||||
},
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const runs = selectRunsForSchedule({ schedules: state }, 'sched-1');
|
||||
expect(runs[0].status).toBe('failed');
|
||||
expect(runs[0].error_type).toBe('agent_error');
|
||||
expect(runs[0].error).toBe('LLM exploded');
|
||||
});
|
||||
|
||||
it('schedule.autopaused flips the schedule status to paused', () => {
|
||||
let state = seedWithSchedule();
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-3',
|
||||
ts: '2026-05-19T12:07:00Z',
|
||||
type: 'schedule.autopaused',
|
||||
payload: { schedule_id: 'sched-1' },
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const schedules = selectSchedulesForAgent({ schedules: state }, 'agent-1');
|
||||
expect(schedules[0].status).toBe('paused');
|
||||
});
|
||||
|
||||
it('schedule.message.appended is acknowledged without mutating run state', () => {
|
||||
let state = seedWithSchedule();
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-4',
|
||||
ts: '2026-05-19T12:08:00Z',
|
||||
type: 'schedule.message.appended',
|
||||
payload: {
|
||||
schedule_id: 'sched-1',
|
||||
run_id: 'run-1',
|
||||
conversation_id: 'conv-1',
|
||||
message_id: 'msg-1',
|
||||
},
|
||||
};
|
||||
const before = JSON.stringify(state);
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
expect(JSON.stringify(state)).toBe(before);
|
||||
});
|
||||
|
||||
it('ignores envelopes without a schedule_id payload', () => {
|
||||
let state = seedWithSchedule();
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-5',
|
||||
type: 'schedule.run.completed',
|
||||
payload: { run_id: 'run-1' },
|
||||
};
|
||||
const before = JSON.stringify(state);
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
expect(JSON.stringify(state)).toBe(before);
|
||||
});
|
||||
|
||||
it('inserts a stub run row when the envelope arrives before the run log is loaded', () => {
|
||||
let state = seedState();
|
||||
state = {
|
||||
...state,
|
||||
byAgent: { 'agent-1': [sampleSchedule()] } as Record<string, Schedule[]>,
|
||||
};
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-6',
|
||||
ts: '2026-05-19T12:09:00Z',
|
||||
type: 'schedule.run.completed',
|
||||
payload: {
|
||||
run_id: 'run-new',
|
||||
schedule_id: 'sched-1',
|
||||
},
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const runs = selectRunsForSchedule({ schedules: state }, 'sched-1');
|
||||
expect(runs[0].id).toBe('run-new');
|
||||
expect(runs[0].status).toBe('success');
|
||||
});
|
||||
|
||||
it('seeds stub-insert run rows with safe defaults so RunLog never renders NaN', () => {
|
||||
let state = seedState();
|
||||
state = {
|
||||
...state,
|
||||
byAgent: { 'agent-1': [sampleSchedule()] } as Record<string, Schedule[]>,
|
||||
};
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-7',
|
||||
ts: '2026-05-19T12:10:00Z',
|
||||
type: 'schedule.run.completed',
|
||||
payload: { run_id: 'run-stub', schedule_id: 'sched-1' },
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const stub = selectRunsForSchedule({ schedules: state }, 'sched-1')[0];
|
||||
|
||||
expect(stub.prompt_tokens).toBe(0);
|
||||
expect(stub.generated_tokens).toBe(0);
|
||||
expect(stub.prompt_tokens + stub.generated_tokens).toBe(0);
|
||||
expect(Number.isNaN(stub.prompt_tokens + stub.generated_tokens)).toBe(
|
||||
false,
|
||||
);
|
||||
expect(stub.trigger_source).toBe('cron');
|
||||
expect(stub.output_truncated).toBe(false);
|
||||
expect(stub.scheduled_for).toBe('2026-05-19T12:10:00Z');
|
||||
expect(stub.started_at).toBe('2026-05-19T12:10:00Z');
|
||||
expect(stub.finished_at).toBe('2026-05-19T12:10:00Z');
|
||||
expect(stub.status).toBe('success');
|
||||
expect(stub.error).toBeNull();
|
||||
expect(stub.error_type).toBeNull();
|
||||
});
|
||||
|
||||
it('stub-insert seeds defaults for failed runs too', () => {
|
||||
let state = seedState();
|
||||
state = {
|
||||
...state,
|
||||
byAgent: { 'agent-1': [sampleSchedule()] } as Record<string, Schedule[]>,
|
||||
};
|
||||
const envelope: SSEEvent = {
|
||||
id: 'evt-8',
|
||||
ts: '2026-05-19T12:11:00Z',
|
||||
type: 'schedule.run.failed',
|
||||
payload: {
|
||||
run_id: 'run-stub-failed',
|
||||
schedule_id: 'sched-1',
|
||||
error_type: 'agent_error',
|
||||
error: 'boom',
|
||||
},
|
||||
};
|
||||
state = reducer(state, sseEventReceived(envelope));
|
||||
const stub = selectRunsForSchedule({ schedules: state }, 'sched-1')[0];
|
||||
expect(stub.status).toBe('failed');
|
||||
expect(stub.error).toBe('boom');
|
||||
expect(stub.error_type).toBe('agent_error');
|
||||
expect(stub.prompt_tokens).toBe(0);
|
||||
expect(stub.generated_tokens).toBe(0);
|
||||
expect(stub.trigger_source).toBe('cron');
|
||||
});
|
||||
});
|
||||
319
frontend/src/agents/schedules/schedulesSlice.ts
Normal file
319
frontend/src/agents/schedules/schedulesSlice.ts
Normal file
@@ -0,0 +1,319 @@
|
||||
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
|
||||
import schedulesService from '../../api/services/schedulesService';
|
||||
import {
|
||||
sseEventReceived,
|
||||
type SSEEvent,
|
||||
} from '../../notifications/notificationsSlice';
|
||||
import type {
|
||||
Schedule,
|
||||
ScheduleCreatePayload,
|
||||
ScheduleRun,
|
||||
ScheduleUpdatePayload,
|
||||
} from '../types/schedule';
|
||||
|
||||
export type SchedulesState = {
|
||||
byAgent: Record<string, Schedule[]>;
|
||||
runsBySchedule: Record<string, ScheduleRun[]>;
|
||||
loading: boolean;
|
||||
error: string | null;
|
||||
};
|
||||
|
||||
const initialState: SchedulesState = {
|
||||
byAgent: {},
|
||||
runsBySchedule: {},
|
||||
loading: false,
|
||||
error: null,
|
||||
};
|
||||
|
||||
export const loadSchedulesForAgent = createAsyncThunk<
|
||||
{ agentId: string; schedules: Schedule[] },
|
||||
{ agentId: string; token: string | null }
|
||||
>('schedules/loadForAgent', async ({ agentId, token }) => {
|
||||
const r = await schedulesService.listForAgent(agentId, token);
|
||||
return { agentId, schedules: r.schedules };
|
||||
});
|
||||
|
||||
export const createSchedule = createAsyncThunk<
|
||||
Schedule,
|
||||
{
|
||||
agentId: string;
|
||||
payload: ScheduleCreatePayload;
|
||||
token: string | null;
|
||||
}
|
||||
>('schedules/create', async ({ agentId, payload, token }) => {
|
||||
const r = await schedulesService.create(agentId, payload, token);
|
||||
return r.schedule;
|
||||
});
|
||||
|
||||
export const updateSchedule = createAsyncThunk<
|
||||
Schedule,
|
||||
{
|
||||
id: string;
|
||||
payload: ScheduleUpdatePayload;
|
||||
token: string | null;
|
||||
}
|
||||
>('schedules/update', async ({ id, payload, token }) => {
|
||||
const r = await schedulesService.update(id, payload, token);
|
||||
return r.schedule;
|
||||
});
|
||||
|
||||
export const setSchedulePaused = createAsyncThunk<
|
||||
Schedule,
|
||||
{ id: string; action: 'pause' | 'resume'; token: string | null }
|
||||
>('schedules/setPaused', async ({ id, action, token }) => {
|
||||
const r = await schedulesService.setPaused(id, action, token);
|
||||
return r.schedule;
|
||||
});
|
||||
|
||||
export const deleteSchedule = createAsyncThunk<
|
||||
string,
|
||||
{ id: string; token: string | null }
|
||||
>('schedules/delete', async ({ id, token }) => {
|
||||
await schedulesService.remove(id, token);
|
||||
return id;
|
||||
});
|
||||
|
||||
export const runScheduleNow = createAsyncThunk<
|
||||
{ scheduleId: string; run: ScheduleRun },
|
||||
{ id: string; token: string | null }
|
||||
>('schedules/runNow', async ({ id, token }) => {
|
||||
const r = await schedulesService.runNow(id, token);
|
||||
return { scheduleId: id, run: r.run };
|
||||
});
|
||||
|
||||
export const loadRunsForSchedule = createAsyncThunk<
|
||||
{ scheduleId: string; runs: ScheduleRun[] },
|
||||
{
|
||||
id: string;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
token: string | null;
|
||||
}
|
||||
>('schedules/loadRuns', async ({ id, limit, offset, token }) => {
|
||||
const r = await schedulesService.listRuns(id, limit, offset, token);
|
||||
return { scheduleId: id, runs: r.runs };
|
||||
});
|
||||
|
||||
const upsert = (list: Schedule[], next: Schedule): Schedule[] => {
|
||||
const idx = list.findIndex((s) => s.id === next.id);
|
||||
if (idx === -1) return [next, ...list];
|
||||
const copy = list.slice();
|
||||
copy[idx] = next;
|
||||
return copy;
|
||||
};
|
||||
|
||||
const removeFrom = (list: Schedule[], id: string): Schedule[] =>
|
||||
list.filter((s) => s.id !== id);
|
||||
|
||||
// SSE delivers a partial schedule_run; stub the missing fields so RunLog
|
||||
// renders cleanly until the next list refetch.
|
||||
const stubRunDefaults = (
|
||||
scheduleId: string,
|
||||
ts: string | undefined,
|
||||
): Omit<ScheduleRun, 'id' | 'status'> => {
|
||||
const now = ts ?? new Date().toISOString();
|
||||
return {
|
||||
schedule_id: scheduleId,
|
||||
user_id: '',
|
||||
agent_id: '',
|
||||
scheduled_for: now,
|
||||
trigger_source: 'cron',
|
||||
started_at: now,
|
||||
finished_at: now,
|
||||
output: null,
|
||||
output_truncated: false,
|
||||
error: null,
|
||||
error_type: null,
|
||||
prompt_tokens: 0,
|
||||
generated_tokens: 0,
|
||||
conversation_id: null,
|
||||
message_id: null,
|
||||
celery_task_id: null,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
};
|
||||
|
||||
const upsertRunDelta = (
|
||||
state: SchedulesState,
|
||||
scheduleId: string,
|
||||
delta: Partial<ScheduleRun> & { id: string; status: ScheduleRun['status'] },
|
||||
ts: string | undefined,
|
||||
): void => {
|
||||
const list = state.runsBySchedule[scheduleId] ?? [];
|
||||
const idx = list.findIndex((r) => r.id === delta.id);
|
||||
if (idx === -1) {
|
||||
const stub: ScheduleRun = { ...stubRunDefaults(scheduleId, ts), ...delta };
|
||||
state.runsBySchedule[scheduleId] = [stub, ...list];
|
||||
return;
|
||||
}
|
||||
list[idx] = { ...list[idx], ...delta };
|
||||
};
|
||||
|
||||
const findAgentForSchedule = (
|
||||
state: SchedulesState,
|
||||
scheduleId: string,
|
||||
): { agentId: string; schedule: Schedule } | null => {
|
||||
for (const agentId of Object.keys(state.byAgent)) {
|
||||
const list = state.byAgent[agentId];
|
||||
const schedule = list.find((s) => s.id === scheduleId);
|
||||
if (schedule) return { agentId, schedule };
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const schedulesSlice = createSlice({
|
||||
name: 'schedules',
|
||||
initialState,
|
||||
reducers: {
|
||||
applyEvent: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
type: string;
|
||||
scheduleId: string;
|
||||
run?: ScheduleRun;
|
||||
}>,
|
||||
) => {
|
||||
const { scheduleId, run } = action.payload;
|
||||
if (run) {
|
||||
const existing = state.runsBySchedule[scheduleId] ?? [];
|
||||
const idx = existing.findIndex((r) => r.id === run.id);
|
||||
if (idx === -1) {
|
||||
state.runsBySchedule[scheduleId] = [run, ...existing];
|
||||
} else {
|
||||
existing[idx] = run;
|
||||
}
|
||||
}
|
||||
},
|
||||
resetSchedules: () => initialState,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder
|
||||
.addCase(loadSchedulesForAgent.pending, (state) => {
|
||||
state.loading = true;
|
||||
state.error = null;
|
||||
})
|
||||
.addCase(loadSchedulesForAgent.fulfilled, (state, action) => {
|
||||
state.byAgent[action.payload.agentId] = action.payload.schedules;
|
||||
state.loading = false;
|
||||
})
|
||||
.addCase(loadSchedulesForAgent.rejected, (state, action) => {
|
||||
state.loading = false;
|
||||
state.error = action.error.message ?? 'failed to load schedules';
|
||||
})
|
||||
// Agentless schedules (``agent_id === null``) skip the byAgent cache —
|
||||
// they have no Schedules tab home. The inline ⏰ card is the only UI.
|
||||
.addCase(createSchedule.fulfilled, (state, action) => {
|
||||
const next = action.payload;
|
||||
if (!next.agent_id) return;
|
||||
const list = state.byAgent[next.agent_id] ?? [];
|
||||
state.byAgent[next.agent_id] = upsert(list, next);
|
||||
})
|
||||
.addCase(updateSchedule.fulfilled, (state, action) => {
|
||||
const next = action.payload;
|
||||
if (!next.agent_id) return;
|
||||
const list = state.byAgent[next.agent_id] ?? [];
|
||||
state.byAgent[next.agent_id] = upsert(list, next);
|
||||
})
|
||||
.addCase(setSchedulePaused.fulfilled, (state, action) => {
|
||||
const next = action.payload;
|
||||
if (!next.agent_id) return;
|
||||
const list = state.byAgent[next.agent_id] ?? [];
|
||||
state.byAgent[next.agent_id] = upsert(list, next);
|
||||
})
|
||||
.addCase(deleteSchedule.fulfilled, (state, action) => {
|
||||
const id = action.payload;
|
||||
Object.keys(state.byAgent).forEach((agentId) => {
|
||||
state.byAgent[agentId] = removeFrom(state.byAgent[agentId], id);
|
||||
});
|
||||
delete state.runsBySchedule[id];
|
||||
})
|
||||
.addCase(runScheduleNow.fulfilled, (state, action) => {
|
||||
const { scheduleId, run } = action.payload;
|
||||
const list = state.runsBySchedule[scheduleId] ?? [];
|
||||
state.runsBySchedule[scheduleId] = [run, ...list];
|
||||
})
|
||||
.addCase(loadRunsForSchedule.fulfilled, (state, action) => {
|
||||
const { scheduleId, runs } = action.payload;
|
||||
state.runsBySchedule[scheduleId] = runs;
|
||||
})
|
||||
// SSE envelopes from scheduler_worker.py; unknown shapes are no-ops.
|
||||
.addMatcher(
|
||||
(action) => action.type === sseEventReceived.type,
|
||||
(state, action: PayloadAction<SSEEvent>) => {
|
||||
const envelope = action.payload;
|
||||
const payload = (envelope.payload || {}) as Record<string, unknown>;
|
||||
const scheduleId = (payload.schedule_id as string | undefined) || '';
|
||||
if (!scheduleId) return;
|
||||
switch (envelope.type) {
|
||||
case 'schedule.run.completed':
|
||||
case 'schedule.run.failed': {
|
||||
const runId = (payload.run_id as string | undefined) || '';
|
||||
if (runId) {
|
||||
const status =
|
||||
envelope.type === 'schedule.run.completed'
|
||||
? 'success'
|
||||
: 'failed';
|
||||
upsertRunDelta(
|
||||
state,
|
||||
scheduleId,
|
||||
{
|
||||
id: runId,
|
||||
schedule_id: scheduleId,
|
||||
status: status as ScheduleRun['status'],
|
||||
error_type:
|
||||
(payload.error_type as ScheduleRun['error_type']) ?? null,
|
||||
error: (payload.error as string | undefined) ?? null,
|
||||
finished_at: envelope.ts ?? null,
|
||||
},
|
||||
envelope.ts,
|
||||
);
|
||||
}
|
||||
const found = findAgentForSchedule(state, scheduleId);
|
||||
if (found && envelope.ts) {
|
||||
const next: Schedule = {
|
||||
...found.schedule,
|
||||
last_run_at: envelope.ts,
|
||||
};
|
||||
state.byAgent[found.agentId] = upsert(
|
||||
state.byAgent[found.agentId],
|
||||
next,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'schedule.autopaused': {
|
||||
const found = findAgentForSchedule(state, scheduleId);
|
||||
if (found) {
|
||||
const next: Schedule = { ...found.schedule, status: 'paused' };
|
||||
state.byAgent[found.agentId] = upsert(
|
||||
state.byAgent[found.agentId],
|
||||
next,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'schedule.message.appended':
|
||||
// Handled by conversationSlice; nothing to mutate here.
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
export const { applyEvent, resetSchedules } = schedulesSlice.actions;
|
||||
export default schedulesSlice.reducer;
|
||||
|
||||
export const selectSchedulesForAgent = (
|
||||
state: { schedules: SchedulesState },
|
||||
agentId: string,
|
||||
): Schedule[] => state.schedules.byAgent[agentId] ?? [];
|
||||
|
||||
export const selectRunsForSchedule = (
|
||||
state: { schedules: SchedulesState },
|
||||
scheduleId: string,
|
||||
): ScheduleRun[] => state.schedules.runsBySchedule[scheduleId] ?? [];
|
||||
@@ -47,4 +47,5 @@ export type AgentFolder = {
|
||||
updated_at?: string;
|
||||
};
|
||||
|
||||
export * from './schedule';
|
||||
export * from './workflow';
|
||||
|
||||
94
frontend/src/agents/types/schedule.ts
Normal file
94
frontend/src/agents/types/schedule.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
export type ScheduleTriggerType = 'once' | 'recurring';
|
||||
|
||||
export type ScheduleStatus = 'active' | 'paused' | 'completed' | 'cancelled';
|
||||
|
||||
export type ScheduleRunStatus =
|
||||
| 'pending'
|
||||
| 'running'
|
||||
| 'success'
|
||||
| 'failed'
|
||||
| 'skipped'
|
||||
| 'timeout';
|
||||
|
||||
export type ScheduleRunErrorType =
|
||||
| 'auth_expired'
|
||||
| 'tool_not_allowed'
|
||||
| 'budget_exceeded'
|
||||
| 'timeout'
|
||||
| 'agent_error'
|
||||
| 'internal'
|
||||
| 'missed'
|
||||
| 'overlap';
|
||||
|
||||
export type Schedule = {
|
||||
id: string;
|
||||
user_id: string;
|
||||
// Null for agentless one-time tasks (migration 0011).
|
||||
agent_id: string | null;
|
||||
trigger_type: ScheduleTriggerType;
|
||||
name?: string | null;
|
||||
instruction: string;
|
||||
status: ScheduleStatus;
|
||||
cron?: string | null;
|
||||
run_at?: string | null;
|
||||
timezone: string;
|
||||
next_run_at?: string | null;
|
||||
last_run_at?: string | null;
|
||||
end_at?: string | null;
|
||||
tool_allowlist: string[];
|
||||
model_id?: string | null;
|
||||
token_budget?: number | null;
|
||||
origin_conversation_id?: string | null;
|
||||
created_via: 'chat' | 'ui';
|
||||
consecutive_failure_count: number;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
};
|
||||
|
||||
export type ScheduleRun = {
|
||||
id: string;
|
||||
schedule_id: string;
|
||||
user_id: string;
|
||||
// Null for runs of agentless schedules (migration 0011).
|
||||
agent_id: string | null;
|
||||
status: ScheduleRunStatus;
|
||||
scheduled_for: string;
|
||||
trigger_source: 'cron' | 'manual';
|
||||
started_at?: string | null;
|
||||
finished_at?: string | null;
|
||||
output?: string | null;
|
||||
output_truncated: boolean;
|
||||
error?: string | null;
|
||||
error_type?: ScheduleRunErrorType | null;
|
||||
prompt_tokens: number;
|
||||
generated_tokens: number;
|
||||
conversation_id?: string | null;
|
||||
message_id?: string | null;
|
||||
celery_task_id?: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
};
|
||||
|
||||
export type ScheduleListResponse = { schedules: Schedule[] };
|
||||
export type ScheduleResponse = { schedule: Schedule };
|
||||
export type ScheduleRunListResponse = {
|
||||
runs: ScheduleRun[];
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
export type ScheduleRunResponse = { run: ScheduleRun };
|
||||
|
||||
export type ScheduleCreatePayload = {
|
||||
instruction: string;
|
||||
trigger_type?: ScheduleTriggerType;
|
||||
cron?: string;
|
||||
run_at?: string; // ISO 8601 UTC; set for trigger_type === 'once'
|
||||
timezone?: string;
|
||||
name?: string;
|
||||
end_at?: string;
|
||||
tool_allowlist?: string[];
|
||||
model_id?: string;
|
||||
token_budget?: number;
|
||||
};
|
||||
|
||||
export type ScheduleUpdatePayload = Partial<ScheduleCreatePayload>;
|
||||
@@ -3,7 +3,6 @@ import 'reactflow/dist/style.css';
|
||||
import {
|
||||
AlertCircle,
|
||||
Bot,
|
||||
ChartColumn,
|
||||
Database,
|
||||
Flag,
|
||||
GitBranch,
|
||||
@@ -52,7 +51,6 @@ import { Sheet, SheetContent } from '@/components/ui/sheet';
|
||||
|
||||
import modelService from '../../api/services/modelService';
|
||||
import userService from '../../api/services/userService';
|
||||
import ArrowLeft from '../../assets/arrow-left.svg';
|
||||
import { FileUpload } from '../../components/FileUpload';
|
||||
import AgentDetailsModal from '../../modals/AgentDetailsModal';
|
||||
import ConfirmationModal from '../../modals/ConfirmationModal';
|
||||
@@ -62,6 +60,7 @@ import {
|
||||
selectToken,
|
||||
} from '../../preferences/preferenceSlice';
|
||||
import { getToolDisplayName } from '../../utils/toolUtils';
|
||||
import AgentPageHeader from '../AgentPageHeader';
|
||||
import { Agent } from '../types';
|
||||
import { ConditionCase, WorkflowNode } from '../types/workflow';
|
||||
import MobileBlocker from './components/MobileBlocker';
|
||||
@@ -1373,12 +1372,22 @@ function WorkflowBuilderInner() {
|
||||
<div className="bg-background fixed inset-0 z-50 hidden h-screen w-full flex-col md:flex">
|
||||
<div className="border-border bg-card dark:bg-background flex items-center justify-between border-b px-6 py-4">
|
||||
<div className="flex items-center gap-4">
|
||||
<button
|
||||
onClick={navigateBackToAgents}
|
||||
className="border-border text-muted-foreground hover:bg-accent rounded-full border p-3 text-sm"
|
||||
>
|
||||
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
||||
</button>
|
||||
{canManageAgent ? (
|
||||
<AgentPageHeader
|
||||
agentId={effectiveAgentId}
|
||||
agentName={workflowName}
|
||||
agentEditPath={`/agents/workflow/edit/${effectiveAgentId}`}
|
||||
currentPage="overview"
|
||||
inline
|
||||
/>
|
||||
) : (
|
||||
<button
|
||||
onClick={navigateBackToAgents}
|
||||
className="border-border text-muted-foreground hover:bg-accent rounded-full border px-4 py-2 text-sm"
|
||||
>
|
||||
{t('agents.backToAll')}
|
||||
</button>
|
||||
)}
|
||||
<div className="group relative flex items-center gap-2">
|
||||
<div>
|
||||
<div
|
||||
@@ -1524,15 +1533,6 @@ function WorkflowBuilderInner() {
|
||||
<Settings2 size={16} />
|
||||
Details
|
||||
</button>
|
||||
{canManageAgent && (
|
||||
<button
|
||||
onClick={() => navigate(`/agents/logs/${effectiveAgentId}`)}
|
||||
className="border-border bg-card hover:bg-accent flex items-center gap-2 rounded-full border px-4 py-2 text-sm font-medium text-gray-700 transition-colors dark:text-gray-200"
|
||||
>
|
||||
<ChartColumn size={16} />
|
||||
Logs
|
||||
</button>
|
||||
)}
|
||||
{canManageAgent && (
|
||||
<button
|
||||
onClick={() => setAgentDetails('ACTIVE')}
|
||||
|
||||
@@ -84,6 +84,13 @@ const endpoints = {
|
||||
CUSTOM_MODEL: (id: string) => `/api/user/models/${id}`,
|
||||
CUSTOM_MODEL_TEST: (id: string) => `/api/user/models/${id}/test`,
|
||||
CUSTOM_MODEL_TEST_PAYLOAD: '/api/user/models/test',
|
||||
AGENT_SCHEDULES: (agentId: string) => `/api/agents/${agentId}/schedules`,
|
||||
SCHEDULE: (id: string) => `/api/schedules/${id}`,
|
||||
SCHEDULE_RUN_NOW: (id: string) => `/api/schedules/${id}/run`,
|
||||
SCHEDULE_RUNS: (id: string, limit?: number, offset?: number) =>
|
||||
`/api/schedules/${id}/runs?limit=${limit ?? 50}&offset=${offset ?? 0}`,
|
||||
SCHEDULE_RUN: (id: string, runId: string) =>
|
||||
`/api/schedules/${id}/runs/${runId}`,
|
||||
},
|
||||
V1: {
|
||||
CHAT_COMPLETIONS: '/v1/chat/completions',
|
||||
|
||||
116
frontend/src/api/services/schedulesService.ts
Normal file
116
frontend/src/api/services/schedulesService.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import apiClient from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
import type {
|
||||
ScheduleCreatePayload,
|
||||
ScheduleListResponse,
|
||||
ScheduleResponse,
|
||||
ScheduleRunListResponse,
|
||||
ScheduleRunResponse,
|
||||
ScheduleUpdatePayload,
|
||||
} from '../../agents/types/schedule';
|
||||
|
||||
const json = async (response: Response | unknown) => {
|
||||
const r = response as Response;
|
||||
if (!('json' in r) || typeof r.json !== 'function') return r as unknown;
|
||||
return r.json();
|
||||
};
|
||||
|
||||
const schedulesService = {
|
||||
listForAgent: async (
|
||||
agentId: string,
|
||||
token: string | null,
|
||||
): Promise<ScheduleListResponse> => {
|
||||
const r = await apiClient.get(
|
||||
endpoints.USER.AGENT_SCHEDULES(agentId),
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleListResponse;
|
||||
},
|
||||
|
||||
create: async (
|
||||
agentId: string,
|
||||
payload: ScheduleCreatePayload,
|
||||
token: string | null,
|
||||
): Promise<ScheduleResponse> => {
|
||||
const r = await apiClient.post(
|
||||
endpoints.USER.AGENT_SCHEDULES(agentId),
|
||||
payload,
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleResponse;
|
||||
},
|
||||
|
||||
get: async (id: string, token: string | null): Promise<ScheduleResponse> => {
|
||||
const r = await apiClient.get(endpoints.USER.SCHEDULE(id), token);
|
||||
return (await json(r)) as ScheduleResponse;
|
||||
},
|
||||
|
||||
update: async (
|
||||
id: string,
|
||||
payload: ScheduleUpdatePayload,
|
||||
token: string | null,
|
||||
): Promise<ScheduleResponse> => {
|
||||
const r = await apiClient.put(endpoints.USER.SCHEDULE(id), payload, token);
|
||||
return (await json(r)) as ScheduleResponse;
|
||||
},
|
||||
|
||||
setPaused: async (
|
||||
id: string,
|
||||
action: 'pause' | 'resume',
|
||||
token: string | null,
|
||||
): Promise<ScheduleResponse> => {
|
||||
const r = await apiClient.patch(
|
||||
endpoints.USER.SCHEDULE(id),
|
||||
{ action },
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleResponse;
|
||||
},
|
||||
|
||||
remove: async (
|
||||
id: string,
|
||||
token: string | null,
|
||||
): Promise<{ success: boolean }> => {
|
||||
const r = await apiClient.delete(endpoints.USER.SCHEDULE(id), token);
|
||||
return (await json(r)) as { success: boolean };
|
||||
},
|
||||
|
||||
runNow: async (
|
||||
id: string,
|
||||
token: string | null,
|
||||
): Promise<ScheduleRunResponse> => {
|
||||
const r = await apiClient.post(
|
||||
endpoints.USER.SCHEDULE_RUN_NOW(id),
|
||||
{},
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleRunResponse;
|
||||
},
|
||||
|
||||
listRuns: async (
|
||||
id: string,
|
||||
limit: number | undefined,
|
||||
offset: number | undefined,
|
||||
token: string | null,
|
||||
): Promise<ScheduleRunListResponse> => {
|
||||
const r = await apiClient.get(
|
||||
endpoints.USER.SCHEDULE_RUNS(id, limit, offset),
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleRunListResponse;
|
||||
},
|
||||
|
||||
getRun: async (
|
||||
id: string,
|
||||
runId: string,
|
||||
token: string | null,
|
||||
): Promise<ScheduleRunResponse> => {
|
||||
const r = await apiClient.get(
|
||||
endpoints.USER.SCHEDULE_RUN(id, runId),
|
||||
token,
|
||||
);
|
||||
return (await json(r)) as ScheduleRunResponse;
|
||||
},
|
||||
};
|
||||
|
||||
export default schedulesService;
|
||||
3
frontend/src/assets/clock-purple.svg
Normal file
3
frontend/src/assets/clock-purple.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 3.5C7.30558 3.5 3.5 7.30558 3.5 12C3.5 16.6944 7.30558 20.5 12 20.5C16.6944 20.5 20.5 16.6944 20.5 12C20.5 7.30558 16.6944 3.5 12 3.5ZM12 5C15.866 5 19 8.134 19 12C19 15.866 15.866 19 12 19C8.134 19 5 15.866 5 12C5 8.134 8.134 5 12 5ZM11.25 7C11.0511 7 10.8603 7.07902 10.7197 7.21967C10.579 7.36032 10.5 7.55109 10.5 7.75V12C10.5 12.1989 10.579 12.3897 10.7197 12.5303L13.4697 15.2803C13.6103 15.421 13.8011 15.5 14 15.5C14.1989 15.5 14.3897 15.421 14.5303 15.2803C14.671 15.1397 14.75 14.9489 14.75 14.75C14.75 14.5511 14.671 14.3603 14.5303 14.2197L12 11.6893V7.75C12 7.55109 11.921 7.36032 11.7803 7.21967C11.6397 7.07902 11.4489 7 11.25 7Z" fill="#7D54D1" stroke="#7D54D1" stroke-width="0.3"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 815 B |
3
frontend/src/assets/clock-white.svg
Normal file
3
frontend/src/assets/clock-white.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 3.5C7.30558 3.5 3.5 7.30558 3.5 12C3.5 16.6944 7.30558 20.5 12 20.5C16.6944 20.5 20.5 16.6944 20.5 12C20.5 7.30558 16.6944 3.5 12 3.5ZM12 5C15.866 5 19 8.134 19 12C19 15.866 15.866 19 12 19C8.134 19 5 15.866 5 12C5 8.134 8.134 5 12 5ZM11.25 7C11.0511 7 10.8603 7.07902 10.7197 7.21967C10.579 7.36032 10.5 7.55109 10.5 7.75V12C10.5 12.1989 10.579 12.3897 10.7197 12.5303L13.4697 15.2803C13.6103 15.421 13.8011 15.5 14 15.5C14.1989 15.5 14.3897 15.421 14.5303 15.2803C14.671 15.1397 14.75 14.9489 14.75 14.75C14.75 14.5511 14.671 14.3603 14.5303 14.2197L12 11.6893V7.75C12 7.55109 11.921 7.36032 11.7803 7.21967C11.6397 7.07902 11.4489 7 11.25 7Z" fill="#FFFFFF" stroke="#FFFFFF" stroke-width="0.3"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 815 B |
27
frontend/src/components/ToolsPopup.test.ts
Normal file
27
frontend/src/components/ToolsPopup.test.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { isChatToolVisible } from './ToolsPopup';
|
||||
|
||||
// Regression for the filter drift introduced when ``scheduler`` was
|
||||
// dual-registered (both ``default: true`` and ``builtin: true``). The
|
||||
// chat-popup previously filtered ``!tool.builtin`` and dropped scheduler.
|
||||
describe('isChatToolVisible', () => {
|
||||
it('keeps dual-registered tools (default + builtin, e.g. scheduler)', () => {
|
||||
expect(isChatToolVisible({ default: true, builtin: true })).toBe(true);
|
||||
});
|
||||
|
||||
it('keeps default-only chat tools (memory, read_webpage before dual-reg)', () => {
|
||||
expect(isChatToolVisible({ default: true, builtin: false })).toBe(true);
|
||||
expect(isChatToolVisible({ default: true })).toBe(true);
|
||||
});
|
||||
|
||||
it('keeps regular user_tools (neither flag set)', () => {
|
||||
expect(isChatToolVisible({})).toBe(true);
|
||||
expect(isChatToolVisible({ default: false, builtin: false })).toBe(true);
|
||||
});
|
||||
|
||||
it('drops pure builtins (agent-only, e.g. a future builtin without default)', () => {
|
||||
expect(isChatToolVisible({ builtin: true })).toBe(false);
|
||||
expect(isChatToolVisible({ default: false, builtin: true })).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -12,6 +12,15 @@ import NoFilesDarkIcon from '../assets/no-files-dark.svg';
|
||||
import CheckmarkIcon from '../assets/checkmark.svg';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
|
||||
// Chat-popup visibility rule: show defaults (so users can toggle the
|
||||
// agentless chat tools on/off) plus any non-builtin user_tools row. Hide
|
||||
// pure builtins (agent-only). Dual-registered tools like ``scheduler``
|
||||
// carry BOTH flags and stay visible via the ``default`` branch.
|
||||
export const isChatToolVisible = (tool: {
|
||||
default?: boolean;
|
||||
builtin?: boolean;
|
||||
}): boolean => Boolean(tool.default) || !tool.builtin;
|
||||
|
||||
interface ToolsPopupProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
@@ -104,7 +113,8 @@ export default function ToolsPopup({
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setUserTools(data.tools);
|
||||
const filtered = (data.tools || []).filter(isChatToolVisible);
|
||||
setUserTools(filtered);
|
||||
setLoading(false);
|
||||
})
|
||||
.catch((error) => {
|
||||
|
||||
109
frontend/src/components/ui/breadcrumb.tsx
Normal file
109
frontend/src/components/ui/breadcrumb.tsx
Normal file
@@ -0,0 +1,109 @@
|
||||
import * as React from 'react';
|
||||
import { ChevronRight, MoreHorizontal } from 'lucide-react';
|
||||
import { Slot } from 'radix-ui';
|
||||
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
function Breadcrumb({ ...props }: React.ComponentProps<'nav'>) {
|
||||
return <nav aria-label="breadcrumb" data-slot="breadcrumb" {...props} />;
|
||||
}
|
||||
|
||||
function BreadcrumbList({ className, ...props }: React.ComponentProps<'ol'>) {
|
||||
return (
|
||||
<ol
|
||||
data-slot="breadcrumb-list"
|
||||
className={cn(
|
||||
'text-muted-foreground flex flex-wrap items-center gap-1.5 text-sm break-words sm:gap-2.5',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function BreadcrumbItem({ className, ...props }: React.ComponentProps<'li'>) {
|
||||
return (
|
||||
<li
|
||||
data-slot="breadcrumb-item"
|
||||
className={cn('inline-flex items-center gap-1.5', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function BreadcrumbLink({
|
||||
asChild,
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<'a'> & {
|
||||
asChild?: boolean;
|
||||
}) {
|
||||
const Comp = asChild ? Slot.Root : 'a';
|
||||
|
||||
return (
|
||||
<Comp
|
||||
data-slot="breadcrumb-link"
|
||||
className={cn('hover:text-foreground transition-colors', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function BreadcrumbPage({ className, ...props }: React.ComponentProps<'span'>) {
|
||||
return (
|
||||
<span
|
||||
data-slot="breadcrumb-page"
|
||||
role="link"
|
||||
aria-disabled="true"
|
||||
aria-current="page"
|
||||
className={cn('text-foreground font-normal', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function BreadcrumbSeparator({
|
||||
children,
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<'li'>) {
|
||||
return (
|
||||
<li
|
||||
data-slot="breadcrumb-separator"
|
||||
role="presentation"
|
||||
aria-hidden="true"
|
||||
className={cn('[&>svg]:size-3.5', className)}
|
||||
{...props}
|
||||
>
|
||||
{children ?? <ChevronRight />}
|
||||
</li>
|
||||
);
|
||||
}
|
||||
|
||||
function BreadcrumbEllipsis({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<'span'>) {
|
||||
return (
|
||||
<span
|
||||
data-slot="breadcrumb-ellipsis"
|
||||
role="presentation"
|
||||
aria-hidden="true"
|
||||
className={cn('flex size-9 items-center justify-center', className)}
|
||||
{...props}
|
||||
>
|
||||
<MoreHorizontal className="size-4" />
|
||||
<span className="sr-only">More</span>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
Breadcrumb,
|
||||
BreadcrumbList,
|
||||
BreadcrumbItem,
|
||||
BreadcrumbLink,
|
||||
BreadcrumbPage,
|
||||
BreadcrumbSeparator,
|
||||
BreadcrumbEllipsis,
|
||||
};
|
||||
221
frontend/src/components/ui/calendar.tsx
Normal file
221
frontend/src/components/ui/calendar.tsx
Normal file
@@ -0,0 +1,221 @@
|
||||
import {
|
||||
ChevronDownIcon,
|
||||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
} from 'lucide-react';
|
||||
import * as React from 'react';
|
||||
import {
|
||||
DayPicker,
|
||||
type DayButton,
|
||||
getDefaultClassNames,
|
||||
} from 'react-day-picker';
|
||||
|
||||
import { Button, buttonVariants } from '@/components/ui/button';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
function Calendar({
|
||||
className,
|
||||
classNames,
|
||||
showOutsideDays = true,
|
||||
captionLayout = 'label',
|
||||
buttonVariant = 'ghost',
|
||||
formatters,
|
||||
components,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DayPicker> & {
|
||||
buttonVariant?: React.ComponentProps<typeof Button>['variant'];
|
||||
}) {
|
||||
const defaultClassNames = getDefaultClassNames();
|
||||
|
||||
return (
|
||||
<DayPicker
|
||||
showOutsideDays={showOutsideDays}
|
||||
className={cn(
|
||||
'group/calendar bg-background p-3 [--cell-size:--spacing(8)] [[data-slot=card-content]_&]:bg-transparent [[data-slot=popover-content]_&]:bg-transparent',
|
||||
String.raw`rtl:**:[.rdp-button\_next>svg]:rotate-180`,
|
||||
String.raw`rtl:**:[.rdp-button\_previous>svg]:rotate-180`,
|
||||
className,
|
||||
)}
|
||||
captionLayout={captionLayout}
|
||||
formatters={{
|
||||
formatMonthDropdown: (date) =>
|
||||
date.toLocaleString('default', { month: 'short' }),
|
||||
...formatters,
|
||||
}}
|
||||
classNames={{
|
||||
root: cn('w-fit', defaultClassNames.root),
|
||||
months: cn(
|
||||
'relative flex flex-col gap-4 md:flex-row',
|
||||
defaultClassNames.months,
|
||||
),
|
||||
month: cn('flex w-full flex-col gap-4', defaultClassNames.month),
|
||||
nav: cn(
|
||||
'absolute inset-x-0 top-0 flex w-full items-center justify-between gap-1',
|
||||
defaultClassNames.nav,
|
||||
),
|
||||
button_previous: cn(
|
||||
buttonVariants({ variant: buttonVariant }),
|
||||
'size-(--cell-size) p-0 select-none aria-disabled:opacity-50',
|
||||
defaultClassNames.button_previous,
|
||||
),
|
||||
button_next: cn(
|
||||
buttonVariants({ variant: buttonVariant }),
|
||||
'size-(--cell-size) p-0 select-none aria-disabled:opacity-50',
|
||||
defaultClassNames.button_next,
|
||||
),
|
||||
month_caption: cn(
|
||||
'flex h-(--cell-size) w-full items-center justify-center px-(--cell-size)',
|
||||
defaultClassNames.month_caption,
|
||||
),
|
||||
dropdowns: cn(
|
||||
'flex h-(--cell-size) w-full items-center justify-center gap-1.5 text-sm font-medium',
|
||||
defaultClassNames.dropdowns,
|
||||
),
|
||||
dropdown_root: cn(
|
||||
'relative rounded-md border border-input shadow-xs has-focus:border-ring has-focus:ring-[3px] has-focus:ring-ring/50',
|
||||
defaultClassNames.dropdown_root,
|
||||
),
|
||||
dropdown: cn(
|
||||
'absolute inset-0 bg-popover opacity-0',
|
||||
defaultClassNames.dropdown,
|
||||
),
|
||||
caption_label: cn(
|
||||
'font-medium select-none',
|
||||
captionLayout === 'label'
|
||||
? 'text-sm'
|
||||
: 'flex h-8 items-center gap-1 rounded-md pr-1 pl-2 text-sm [&>svg]:size-3.5 [&>svg]:text-muted-foreground',
|
||||
defaultClassNames.caption_label,
|
||||
),
|
||||
month_grid: 'w-full border-collapse',
|
||||
weekdays: cn('flex', defaultClassNames.weekdays),
|
||||
weekday: cn(
|
||||
'flex-1 rounded-md text-[0.8rem] font-normal text-muted-foreground select-none',
|
||||
defaultClassNames.weekday,
|
||||
),
|
||||
week: cn('mt-2 flex w-full', defaultClassNames.week),
|
||||
week_number_header: cn(
|
||||
'w-(--cell-size) select-none',
|
||||
defaultClassNames.week_number_header,
|
||||
),
|
||||
week_number: cn(
|
||||
'text-[0.8rem] text-muted-foreground select-none',
|
||||
defaultClassNames.week_number,
|
||||
),
|
||||
day: cn(
|
||||
'group/day relative aspect-square h-full w-full p-0 text-center select-none [&:last-child[data-selected=true]_button]:rounded-r-md',
|
||||
props.showWeekNumber
|
||||
? '[&:nth-child(2)[data-selected=true]_button]:rounded-l-md'
|
||||
: '[&:first-child[data-selected=true]_button]:rounded-l-md',
|
||||
defaultClassNames.day,
|
||||
),
|
||||
range_start: cn(
|
||||
'rounded-l-md bg-accent',
|
||||
defaultClassNames.range_start,
|
||||
),
|
||||
range_middle: cn('rounded-none', defaultClassNames.range_middle),
|
||||
range_end: cn('rounded-r-md bg-accent', defaultClassNames.range_end),
|
||||
today: cn(
|
||||
'rounded-md bg-accent text-accent-foreground data-[selected=true]:rounded-none',
|
||||
defaultClassNames.today,
|
||||
),
|
||||
outside: cn(
|
||||
'text-muted-foreground aria-selected:text-muted-foreground',
|
||||
defaultClassNames.outside,
|
||||
),
|
||||
disabled: cn(
|
||||
'text-muted-foreground opacity-50',
|
||||
defaultClassNames.disabled,
|
||||
),
|
||||
hidden: cn('invisible', defaultClassNames.hidden),
|
||||
...classNames,
|
||||
}}
|
||||
components={{
|
||||
Root: ({ className, rootRef, ...rootProps }) => {
|
||||
return (
|
||||
<div
|
||||
data-slot="calendar"
|
||||
ref={rootRef}
|
||||
className={cn(className)}
|
||||
{...rootProps}
|
||||
/>
|
||||
);
|
||||
},
|
||||
Chevron: ({ className, orientation, ...chevronProps }) => {
|
||||
if (orientation === 'left') {
|
||||
return (
|
||||
<ChevronLeftIcon
|
||||
className={cn('size-4', className)}
|
||||
{...chevronProps}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (orientation === 'right') {
|
||||
return (
|
||||
<ChevronRightIcon
|
||||
className={cn('size-4', className)}
|
||||
{...chevronProps}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<ChevronDownIcon
|
||||
className={cn('size-4', className)}
|
||||
{...chevronProps}
|
||||
/>
|
||||
);
|
||||
},
|
||||
DayButton: CalendarDayButton,
|
||||
WeekNumber: ({ children, ...weekProps }) => {
|
||||
return (
|
||||
<td {...weekProps}>
|
||||
<div className="flex size-(--cell-size) items-center justify-center text-center">
|
||||
{children}
|
||||
</div>
|
||||
</td>
|
||||
);
|
||||
},
|
||||
...components,
|
||||
}}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function CalendarDayButton({
|
||||
className,
|
||||
day,
|
||||
modifiers,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DayButton>) {
|
||||
const defaultClassNames = getDefaultClassNames();
|
||||
const ref = React.useRef<HTMLButtonElement>(null);
|
||||
React.useEffect(() => {
|
||||
if (modifiers.focused) ref.current?.focus();
|
||||
}, [modifiers.focused]);
|
||||
|
||||
return (
|
||||
<Button
|
||||
ref={ref}
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-day={day.date.toLocaleDateString()}
|
||||
data-selected-single={
|
||||
modifiers.selected &&
|
||||
!modifiers.range_start &&
|
||||
!modifiers.range_end &&
|
||||
!modifiers.range_middle
|
||||
}
|
||||
data-range-start={modifiers.range_start}
|
||||
data-range-end={modifiers.range_end}
|
||||
data-range-middle={modifiers.range_middle}
|
||||
className={cn(
|
||||
'group-data-[focused=true]/day:border-ring group-data-[focused=true]/day:ring-ring/50 data-[range-end=true]:bg-primary data-[range-end=true]:text-primary-foreground data-[range-middle=true]:bg-accent data-[range-middle=true]:text-accent-foreground data-[range-start=true]:bg-primary data-[range-start=true]:text-primary-foreground data-[selected-single=true]:bg-primary data-[selected-single=true]:text-primary-foreground dark:hover:text-accent-foreground flex aspect-square size-auto w-full min-w-(--cell-size) flex-col gap-1 leading-none font-normal group-data-[focused=true]/day:relative group-data-[focused=true]/day:z-10 group-data-[focused=true]/day:ring-[3px] data-[range-end=true]:rounded-md data-[range-end=true]:rounded-r-md data-[range-middle=true]:rounded-none data-[range-start=true]:rounded-md data-[range-start=true]:rounded-l-md [&>span]:text-xs [&>span]:opacity-70',
|
||||
defaultClassNames.day,
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export { Calendar, CalendarDayButton };
|
||||
113
frontend/src/components/ui/time-picker.tsx
Normal file
113
frontend/src/components/ui/time-picker.tsx
Normal file
@@ -0,0 +1,113 @@
|
||||
import { Clock } from 'lucide-react';
|
||||
import * as React from 'react';
|
||||
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select';
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
export interface TimePickerProps {
|
||||
/** 24-hour "HH:MM" value (matches the native <input type="time"> contract). */
|
||||
value: string;
|
||||
onChange: (next: string) => void;
|
||||
ariaLabel?: string;
|
||||
className?: string;
|
||||
/** Minute step within the dropdown. Default 1 (minute precision). */
|
||||
minuteStep?: number;
|
||||
/** Hide the leading clock icon when ``false``. */
|
||||
showIcon?: boolean;
|
||||
}
|
||||
|
||||
const pad2 = (n: number): string => String(n).padStart(2, '0');
|
||||
|
||||
const parseValue = (value: string): { hour: number; minute: number } => {
|
||||
const m = /^(\d{1,2}):(\d{1,2})$/.exec(value ?? '');
|
||||
if (!m) return { hour: 9, minute: 0 };
|
||||
const hour = Math.max(0, Math.min(23, Number(m[1])));
|
||||
const minute = Math.max(0, Math.min(59, Number(m[2])));
|
||||
return { hour, minute };
|
||||
};
|
||||
|
||||
/**
|
||||
* Shadcn-style time picker composed of two Selects (hours + minutes).
|
||||
* Theme-aware (avoids the native <input type="time"> styling issues in dark mode).
|
||||
*/
|
||||
export function TimePicker({
|
||||
value,
|
||||
onChange,
|
||||
ariaLabel,
|
||||
className,
|
||||
minuteStep = 1,
|
||||
showIcon = true,
|
||||
}: TimePickerProps) {
|
||||
const { hour, minute } = React.useMemo(() => parseValue(value), [value]);
|
||||
|
||||
const hourOptions = React.useMemo(
|
||||
() => Array.from({ length: 24 }, (_, i) => i),
|
||||
[],
|
||||
);
|
||||
const minuteOptions = React.useMemo(() => {
|
||||
const step = Math.max(1, Math.floor(minuteStep));
|
||||
return Array.from({ length: Math.ceil(60 / step) }, (_, i) => i * step);
|
||||
}, [minuteStep]);
|
||||
|
||||
const emit = (h: number, m: number) => onChange(`${pad2(h)}:${pad2(m)}`);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('inline-flex items-center gap-1.5', className)}
|
||||
role="group"
|
||||
aria-label={ariaLabel}
|
||||
>
|
||||
{showIcon && (
|
||||
<Clock
|
||||
className="text-muted-foreground size-4 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
<Select
|
||||
value={String(hour)}
|
||||
onValueChange={(v) => emit(Number(v), minute)}
|
||||
>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
aria-label={ariaLabel ? `${ariaLabel} hours` : 'Hours'}
|
||||
className="h-9 w-[4.25rem]"
|
||||
>
|
||||
<SelectValue>{pad2(hour)}</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-h-60">
|
||||
{hourOptions.map((h) => (
|
||||
<SelectItem key={h} value={String(h)}>
|
||||
{pad2(h)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<span className="text-muted-foreground text-sm select-none">:</span>
|
||||
<Select
|
||||
value={String(minute)}
|
||||
onValueChange={(v) => emit(hour, Number(v))}
|
||||
>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
aria-label={ariaLabel ? `${ariaLabel} minutes` : 'Minutes'}
|
||||
className="h-9 w-[4.25rem]"
|
||||
>
|
||||
<SelectValue>{pad2(minute)}</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-h-60">
|
||||
{minuteOptions.map((m) => (
|
||||
<SelectItem key={m} value={String(m)}>
|
||||
{pad2(m)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -333,6 +333,7 @@ export default function Conversation() {
|
||||
onOpenArtifact={handleOpenArtifact}
|
||||
onToolAction={handleToolAction}
|
||||
isSplitView={isSplitArtifactOpen}
|
||||
agentId={selectedAgent?.id}
|
||||
headerContent={
|
||||
selectedAgent ? (
|
||||
<div className="flex w-full items-center justify-center py-4">
|
||||
|
||||
@@ -13,6 +13,7 @@ import rehypeKatex from 'rehype-katex';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import remarkMath from 'remark-math';
|
||||
|
||||
import SchedulerToolCallCard from '../agents/schedules/SchedulerToolCallCard';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import Cloud from '../assets/cloud.svg';
|
||||
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
||||
@@ -70,6 +71,8 @@ const ConversationBubble = forwardRef<
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
/** Active agent id; refreshes the Schedules tab from SchedulerToolCallCard. */
|
||||
agentId?: string;
|
||||
}
|
||||
>(function ConversationBubble(
|
||||
{
|
||||
@@ -89,6 +92,7 @@ const ConversationBubble = forwardRef<
|
||||
filesAttached,
|
||||
onOpenArtifact,
|
||||
onToolAction,
|
||||
agentId,
|
||||
},
|
||||
ref,
|
||||
) {
|
||||
@@ -423,7 +427,11 @@ const ConversationBubble = forwardRef<
|
||||
)}
|
||||
{research && <ResearchProgress research={research} />}
|
||||
{toolCalls && toolCalls.length > 0 && (
|
||||
<ToolCalls toolCalls={toolCalls} onToolAction={onToolAction} />
|
||||
<ToolCalls
|
||||
toolCalls={toolCalls}
|
||||
onToolAction={onToolAction}
|
||||
agentId={agentId}
|
||||
/>
|
||||
)}
|
||||
{!message && primaryArtifactCall?.artifact_id && onOpenArtifact && (
|
||||
<div className="my-2 ml-2 flex justify-start">
|
||||
@@ -1005,6 +1013,7 @@ function ToolCallApprovalBar({
|
||||
function ToolCalls({
|
||||
toolCalls,
|
||||
onToolAction,
|
||||
agentId,
|
||||
}: {
|
||||
toolCalls: ToolCallsType[];
|
||||
onToolAction?: (
|
||||
@@ -1012,6 +1021,7 @@ function ToolCalls({
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
agentId?: string;
|
||||
}) {
|
||||
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
|
||||
|
||||
@@ -1023,7 +1033,7 @@ function ToolCalls({
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mb-4 relative flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="relative mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
{/* Approval bars — always visible, compact inline */}
|
||||
{awaitingCalls.length > 0 && (
|
||||
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
@@ -1066,88 +1076,101 @@ function ToolCalls({
|
||||
{isToolCallsOpen && (
|
||||
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
<div className="grid grid-cols-1 gap-2">
|
||||
{resolvedCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(
|
||||
toolCall.arguments,
|
||||
null,
|
||||
2,
|
||||
)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
{resolvedCalls.map((toolCall, index) => {
|
||||
if (toolCall.tool_name === 'scheduler') {
|
||||
return (
|
||||
<SchedulerToolCallCard
|
||||
key={`scheduler-${toolCall.call_id ?? index}`}
|
||||
result={toolCall.result}
|
||||
actionName={toolCall.action_name}
|
||||
status={toolCall.status}
|
||||
agentId={agentId}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(
|
||||
toolCall.arguments,
|
||||
null,
|
||||
2,
|
||||
)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-5.75"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
))}
|
||||
</Accordion>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -44,6 +44,8 @@ type ConversationMessagesProps = {
|
||||
comment?: string,
|
||||
) => void;
|
||||
isSplitView?: boolean;
|
||||
/** Active agent id; threaded into SchedulerToolCallCard. */
|
||||
agentId?: string;
|
||||
};
|
||||
|
||||
export default function ConversationMessages({
|
||||
@@ -57,6 +59,7 @@ export default function ConversationMessages({
|
||||
onOpenArtifact,
|
||||
onToolAction,
|
||||
isSplitView = false,
|
||||
agentId,
|
||||
}: ConversationMessagesProps) {
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const { t } = useTranslation();
|
||||
@@ -302,6 +305,7 @@ export default function ConversationMessages({
|
||||
onToolAction={onToolAction}
|
||||
feedback={query.feedback}
|
||||
isStreaming={isCurrentlyStreaming}
|
||||
agentId={agentId}
|
||||
handleFeedback={
|
||||
handleFeedback
|
||||
? (feedback) => handleFeedback(query, feedback, index)
|
||||
|
||||
216
frontend/src/conversation/conversationListener.test.ts
Normal file
216
frontend/src/conversation/conversationListener.test.ts
Normal file
@@ -0,0 +1,216 @@
|
||||
import { configureStore } from '@reduxjs/toolkit';
|
||||
import {
|
||||
afterEach,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import {
|
||||
sseEventReceived,
|
||||
type SSEEvent,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import * as preferenceApi from '../preferences/preferenceApi';
|
||||
import { type Preference, prefSlice } from '../preferences/preferenceSlice';
|
||||
import { type ConversationState } from './conversationModels';
|
||||
import {
|
||||
conversationListenerMiddleware,
|
||||
conversationSlice,
|
||||
setConversation,
|
||||
} from './conversationSlice';
|
||||
|
||||
vi.mock('../api/services/conversationService', () => ({
|
||||
default: {
|
||||
getConversation: vi.fn(),
|
||||
tailMessage: vi.fn(),
|
||||
getConversations: vi.fn(),
|
||||
answer: vi.fn(),
|
||||
answerStream: vi.fn(),
|
||||
search: vi.fn(),
|
||||
feedback: vi.fn(),
|
||||
shareConversation: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('../preferences/preferenceApi', async () => {
|
||||
const actual = await vi.importActual<typeof preferenceApi>(
|
||||
'../preferences/preferenceApi',
|
||||
);
|
||||
return { ...actual, getConversations: vi.fn() };
|
||||
});
|
||||
|
||||
const ENVELOPE = (overrides: Partial<SSEEvent> = {}): SSEEvent => ({
|
||||
id: 'evt-msg-1',
|
||||
ts: '2026-05-19T12:34:56Z',
|
||||
type: 'schedule.message.appended',
|
||||
payload: {
|
||||
conversation_id: 'conv-1',
|
||||
message_id: 'msg-1',
|
||||
schedule_id: 'sched-1',
|
||||
run_id: 'run-1',
|
||||
},
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const makeStore = (
|
||||
initialConversationId: string | null = null,
|
||||
initialStatus: ConversationState['status'] = 'idle',
|
||||
) => {
|
||||
const preference: Preference = {
|
||||
apiKey: '',
|
||||
prompt: { name: 'default', id: 'default', type: 'public' },
|
||||
prompts: [],
|
||||
chunks: '2',
|
||||
selectedDocs: [],
|
||||
sourceDocs: null,
|
||||
conversations: { data: null, loading: false },
|
||||
token: 'tok-1',
|
||||
modalState: 'INACTIVE',
|
||||
paginatedDocuments: null,
|
||||
templateAgents: null,
|
||||
agents: null,
|
||||
sharedAgents: null,
|
||||
selectedAgent: null,
|
||||
selectedModel: null,
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
agentFolders: null,
|
||||
};
|
||||
const conversation: ConversationState = {
|
||||
queries: [],
|
||||
status: initialStatus,
|
||||
conversationId: initialConversationId,
|
||||
};
|
||||
return configureStore({
|
||||
reducer: {
|
||||
preference: prefSlice.reducer,
|
||||
conversation: conversationSlice.reducer,
|
||||
},
|
||||
preloadedState: { preference, conversation },
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware().concat(conversationListenerMiddleware.middleware),
|
||||
});
|
||||
};
|
||||
|
||||
describe('conversation listener — schedule.message.appended', () => {
|
||||
beforeEach(() => {
|
||||
(conversationService.getConversation as unknown as Mock).mockReset();
|
||||
(preferenceApi.getConversations as unknown as Mock).mockReset();
|
||||
(conversationService.getConversation as unknown as Mock).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
queries: [
|
||||
{ prompt: 'hi', response: 'hello', status: 'complete' },
|
||||
{
|
||||
prompt: '',
|
||||
response: 'scheduled run output',
|
||||
status: 'complete',
|
||||
},
|
||||
],
|
||||
}),
|
||||
});
|
||||
(preferenceApi.getConversations as unknown as Mock).mockResolvedValue({
|
||||
data: [{ id: 'conv-1', name: 'Scheduled chat', agent_id: 'agent-1' }],
|
||||
loading: false,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('refetches the open conversation when the appended message lands on it', async () => {
|
||||
const store = makeStore('conv-1');
|
||||
store.dispatch(sseEventReceived(ENVELOPE()));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(conversationService.getConversation).toHaveBeenCalledWith(
|
||||
'conv-1',
|
||||
'tok-1',
|
||||
);
|
||||
const state = store.getState();
|
||||
expect(state.conversation.queries).toHaveLength(2);
|
||||
expect(state.conversation.queries[1].response).toBe('scheduled run output');
|
||||
expect(state.conversation.conversationId).toBe('conv-1');
|
||||
});
|
||||
|
||||
it('refreshes the conversations sidebar list so the bumped chat reorders', async () => {
|
||||
const store = makeStore('conv-other');
|
||||
store.dispatch(sseEventReceived(ENVELOPE()));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(preferenceApi.getConversations).toHaveBeenCalledWith('tok-1');
|
||||
const list = store.getState().preference.conversations;
|
||||
expect(list.data).toEqual([
|
||||
{ id: 'conv-1', name: 'Scheduled chat', agent_id: 'agent-1' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('does not refetch the open conversation when the appended message targets a different chat', async () => {
|
||||
const store = makeStore('conv-other');
|
||||
store.dispatch(sseEventReceived(ENVELOPE()));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(conversationService.getConversation).not.toHaveBeenCalled();
|
||||
expect(preferenceApi.getConversations).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('ignores envelopes without a conversation_id', async () => {
|
||||
const store = makeStore('conv-1');
|
||||
store.dispatch(
|
||||
sseEventReceived(
|
||||
ENVELOPE({ payload: { schedule_id: 'sched-1', run_id: 'run-1' } }),
|
||||
),
|
||||
);
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(conversationService.getConversation).not.toHaveBeenCalled();
|
||||
expect(preferenceApi.getConversations).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('skips refetching the open conversation while a live stream is in flight', async () => {
|
||||
// Mid-stream: refetching would flip status to 'idle' and the next chunk
|
||||
// would die on the updateStreamingQuery guard.
|
||||
const store = makeStore('conv-1', 'loading');
|
||||
store.dispatch(sseEventReceived(ENVELOPE()));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(conversationService.getConversation).not.toHaveBeenCalled();
|
||||
expect(store.getState().conversation.status).toBe('loading');
|
||||
expect(preferenceApi.getConversations).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('ignores non-scheduler SSE envelopes', async () => {
|
||||
const store = makeStore('conv-1');
|
||||
store.dispatch(
|
||||
sseEventReceived({
|
||||
id: 'evt-2',
|
||||
type: 'source.ingest.progress',
|
||||
payload: { conversation_id: 'conv-1' },
|
||||
}),
|
||||
);
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect(conversationService.getConversation).not.toHaveBeenCalled();
|
||||
expect(preferenceApi.getConversations).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('listener middleware export hygiene', () => {
|
||||
it('exports the listener middleware so the store can wire it', () => {
|
||||
expect(conversationListenerMiddleware).toBeDefined();
|
||||
expect(typeof conversationListenerMiddleware.middleware).toBe('function');
|
||||
});
|
||||
|
||||
it('still exports the slice actions consumers rely on', () => {
|
||||
expect(typeof setConversation).toBe('function');
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,15 @@
|
||||
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import {
|
||||
createAsyncThunk,
|
||||
createListenerMiddleware,
|
||||
createSlice,
|
||||
PayloadAction,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import {
|
||||
sseEventReceived,
|
||||
type SSEEvent,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import { getConversations } from '../preferences/preferenceApi';
|
||||
import { setConversations } from '../preferences/preferenceSlice';
|
||||
import store from '../store';
|
||||
@@ -1052,3 +1061,45 @@ export const {
|
||||
updateMessageMeta,
|
||||
} = conversationSlice.actions;
|
||||
export default conversationSlice.reducer;
|
||||
|
||||
// Listener (not a reducer) so a scheduled message appended to the open
|
||||
// chat can dispatch loadConversation + sidebar refresh.
|
||||
export const conversationListenerMiddleware = createListenerMiddleware();
|
||||
|
||||
conversationListenerMiddleware.startListening({
|
||||
actionCreator: sseEventReceived,
|
||||
effect: async (action: PayloadAction<SSEEvent>, listenerApi) => {
|
||||
const envelope = action.payload;
|
||||
if (envelope.type !== 'schedule.message.appended') return;
|
||||
const payload = (envelope.payload || {}) as Record<string, unknown>;
|
||||
const conversationId =
|
||||
(payload.conversation_id as string | undefined) || '';
|
||||
if (!conversationId) return;
|
||||
|
||||
const state = listenerApi.getState() as RootState;
|
||||
const token = state.preference.token;
|
||||
|
||||
// Skip mid-stream: loadConversation -> updateConversationId flips status
|
||||
// to 'idle', and the next SSE chunk dies on the 'idle' guard in
|
||||
// updateStreamingQuery. Defer the refresh to the user's next navigation.
|
||||
if (
|
||||
state.conversation.conversationId === conversationId &&
|
||||
state.conversation.status !== 'loading'
|
||||
) {
|
||||
listenerApi.dispatch(
|
||||
loadConversation({ id: conversationId, force: true }),
|
||||
);
|
||||
}
|
||||
|
||||
// Refresh sidebar; server reorders by updated_at which just bumped.
|
||||
try {
|
||||
const fetched = await getConversations(token);
|
||||
listenerApi.dispatch(setConversations(fetched));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'schedule.message.appended: conversations refresh failed',
|
||||
error,
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -46,4 +46,15 @@ describe('dispatchSSEEvent', () => {
|
||||
'mystery.event',
|
||||
]);
|
||||
});
|
||||
|
||||
it.each([
|
||||
'schedule.run.completed',
|
||||
'schedule.run.failed',
|
||||
'schedule.autopaused',
|
||||
'schedule.message.appended',
|
||||
])('treats %s as a known envelope (no debug noise)', (type) => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
dispatchSSEEvent({ id: `e-${type}`, type }, dispatch);
|
||||
expect(debugSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -24,6 +24,11 @@ const KNOWN_TYPES: ReadonlySet<string> = new Set([
|
||||
'mcp.oauth.completed',
|
||||
'mcp.oauth.failed',
|
||||
'tool.approval.required',
|
||||
// Scheduler envelopes (scheduler_worker.py); consumed by schedulesSlice.
|
||||
'schedule.run.completed',
|
||||
'schedule.run.failed',
|
||||
'schedule.autopaused',
|
||||
'schedule.message.appended',
|
||||
]);
|
||||
|
||||
/**
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "Zu den Werkzeugen",
|
||||
"edit": "Bearbeiten",
|
||||
"delete": "Löschen",
|
||||
"builtIn": "Integriert",
|
||||
"deleteWarning": "Bist du sicher, dass du das Werkzeug \"{{toolName}}\" löschen möchtest?",
|
||||
"unsavedChanges": "Du hast ungespeicherte Änderungen, die verloren gehen, wenn du ohne Speichern verlässt.",
|
||||
"leaveWithoutSaving": "Ohne Speichern verlassen",
|
||||
|
||||
@@ -162,6 +162,7 @@
|
||||
"edit": "Edit",
|
||||
"delete": "Delete",
|
||||
"reconnect": "Reconnect",
|
||||
"builtIn": "Built-in",
|
||||
"authStatus": {
|
||||
"connected": "Connected",
|
||||
"needsAuth": "Needs Auth",
|
||||
@@ -630,6 +631,18 @@
|
||||
"description": "Discover and create custom versions of DocsGPT that combine instructions, extra knowledge, and any combination of skills",
|
||||
"newAgent": "New Agent",
|
||||
"backToAll": "Back to all agents",
|
||||
"pageHeader": {
|
||||
"crumbs": {
|
||||
"agents": "Agents"
|
||||
},
|
||||
"tabs": {
|
||||
"overview": "Overview",
|
||||
"logs": "Logs",
|
||||
"schedules": "Schedules"
|
||||
},
|
||||
"fallbackName": "Untitled agent",
|
||||
"subnavAriaLabel": "Agent sub-navigation"
|
||||
},
|
||||
"searchPlaceholder": "Search...",
|
||||
"noSearchResults": "No agents found",
|
||||
"tryDifferentSearch": "Try a different search term",
|
||||
@@ -669,6 +682,7 @@
|
||||
"cancel": "Cancel",
|
||||
"delete": "Delete",
|
||||
"logs": "Logs",
|
||||
"schedules": "Schedules",
|
||||
"accessDetails": "Access Details",
|
||||
"add": "Add"
|
||||
},
|
||||
@@ -702,7 +716,10 @@
|
||||
"toolsPopup": {
|
||||
"title": "Select Tools",
|
||||
"searchPlaceholder": "Search tools...",
|
||||
"noOptionsMessage": "No tools available"
|
||||
"noOptionsMessage": "No tools available",
|
||||
"groupBuiltin": "Built-in",
|
||||
"groupDefault": "Default",
|
||||
"groupCustom": "Custom"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "Select Models",
|
||||
@@ -743,6 +760,75 @@
|
||||
"noUsageHistory": "No usage history",
|
||||
"tableHeader": "Agent endpoint logs"
|
||||
},
|
||||
"schedules": {
|
||||
"title": "Agent Schedules",
|
||||
"heading": "Schedules",
|
||||
"newRecurring": "New schedule",
|
||||
"closeForm": "Close form",
|
||||
"edit": "Edit",
|
||||
"recurring": "Recurring",
|
||||
"oneTime": "One-time tasks",
|
||||
"noRecurring": "No recurring schedules yet.",
|
||||
"noOneTime": "No one-time tasks yet.",
|
||||
"pause": "Pause",
|
||||
"resume": "Resume",
|
||||
"runNow": "Run now",
|
||||
"delete": "Delete",
|
||||
"deleteConfirm": "Delete this schedule? This will also delete its run history. This action cannot be undone.",
|
||||
"cancel": "Cancel",
|
||||
"showRuns": "Show runs",
|
||||
"hideRuns": "Hide runs",
|
||||
"modal": {
|
||||
"titleCreate": "New schedule",
|
||||
"titleEdit": "Edit schedule",
|
||||
"namePlaceholder": "Name of task",
|
||||
"frequency": {
|
||||
"once": "Once",
|
||||
"daily": "Daily",
|
||||
"weekly": "Weekly",
|
||||
"monthly": "Monthly",
|
||||
"yearly": "Yearly"
|
||||
},
|
||||
"on": "On",
|
||||
"at": "At",
|
||||
"pickDate": "Pick a date",
|
||||
"timezone": "Timezone",
|
||||
"timezonePlaceholder": "Select timezone",
|
||||
"timezoneSearchPlaceholder": "Search timezone…",
|
||||
"timezoneEmpty": "No timezone found.",
|
||||
"instructionsLabel": "Instructions",
|
||||
"instructionsPlaceholder": "Enter prompt here.",
|
||||
"create": "Create task",
|
||||
"save": "Save changes",
|
||||
"errors": {
|
||||
"instructionRequired": "Instructions are required.",
|
||||
"runAtInPast": "Pick a date/time in the future."
|
||||
},
|
||||
"days": {
|
||||
"mon": "Mon",
|
||||
"tue": "Tue",
|
||||
"wed": "Wed",
|
||||
"thu": "Thu",
|
||||
"fri": "Fri",
|
||||
"sat": "Sat",
|
||||
"sun": "Sun"
|
||||
},
|
||||
"months": {
|
||||
"jan": "Jan",
|
||||
"feb": "Feb",
|
||||
"mar": "Mar",
|
||||
"apr": "Apr",
|
||||
"may": "May",
|
||||
"jun": "Jun",
|
||||
"jul": "Jul",
|
||||
"aug": "Aug",
|
||||
"sep": "Sep",
|
||||
"oct": "Oct",
|
||||
"nov": "Nov",
|
||||
"dec": "Dec"
|
||||
}
|
||||
}
|
||||
},
|
||||
"shared": {
|
||||
"notFound": "No agent found. Please ensure the agent is shared."
|
||||
},
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "Ir a Herramientas",
|
||||
"edit": "Editar",
|
||||
"delete": "Eliminar",
|
||||
"builtIn": "Integrada",
|
||||
"deleteWarning": "¿Estás seguro de que deseas eliminar la herramienta \"{{toolName}}\"?",
|
||||
"unsavedChanges": "Tienes cambios sin guardar que se perderán si sales sin guardar.",
|
||||
"leaveWithoutSaving": "Salir sin Guardar",
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "ツールへ移動",
|
||||
"edit": "編集",
|
||||
"delete": "削除",
|
||||
"builtIn": "ビルトイン",
|
||||
"deleteWarning": "ツール \"{{toolName}}\" を削除してもよろしいですか?",
|
||||
"unsavedChanges": "保存されていない変更があります。保存せずに離れると失われます。",
|
||||
"leaveWithoutSaving": "保存せずに離れる",
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "Перейти к инструментам",
|
||||
"edit": "Редактировать",
|
||||
"delete": "Удалить",
|
||||
"builtIn": "Встроенный",
|
||||
"deleteWarning": "Вы уверены, что хотите удалить инструмент \"{{toolName}}\"?",
|
||||
"unsavedChanges": "У вас есть несохраненные изменения, которые будут потеряны, если вы уйдете без сохранения.",
|
||||
"leaveWithoutSaving": "Уйти без сохранения",
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "前往工具",
|
||||
"edit": "編輯",
|
||||
"delete": "刪除",
|
||||
"builtIn": "內建",
|
||||
"deleteWarning": "您確定要刪除工具 \"{{toolName}}\" 嗎?",
|
||||
"unsavedChanges": "您有未儲存的變更,如果不儲存就離開將會遺失。",
|
||||
"leaveWithoutSaving": "不儲存離開",
|
||||
|
||||
@@ -161,6 +161,7 @@
|
||||
"manageTools": "前往工具",
|
||||
"edit": "编辑",
|
||||
"delete": "删除",
|
||||
"builtIn": "内置",
|
||||
"deleteWarning": "您确定要删除工具 \"{{toolName}}\" 吗?",
|
||||
"unsavedChanges": "您有未保存的更改,如果不保存就离开将会丢失。",
|
||||
"leaveWithoutSaving": "不保存离开",
|
||||
|
||||
@@ -144,7 +144,15 @@ export default function Tools() {
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setUserTools(data.tools);
|
||||
// Pure builtins (agent-only, e.g. a future builtin without an
|
||||
// agentless path) carry no per-user state and only apply when
|
||||
// added to an agent — hide them from the management page. Dual-
|
||||
// registered tools (``scheduler``: builtin + default) stay visible
|
||||
// here so the user can toggle the default off in agentless chats.
|
||||
const filtered = (data.tools || []).filter(
|
||||
(tool: UserToolType) => tool.default || !tool.builtin,
|
||||
);
|
||||
setUserTools(filtered);
|
||||
setLoading(false);
|
||||
})
|
||||
.catch((error) => {
|
||||
@@ -282,32 +290,34 @@ export default function Tools() {
|
||||
key={index}
|
||||
className="bg-muted hover:bg-accent relative flex h-52 w-[300px] flex-col justify-between overflow-hidden rounded-2xl p-6"
|
||||
>
|
||||
<div
|
||||
ref={menuRefs.current[tool.id]}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setActiveMenuId(
|
||||
activeMenuId === tool.id ? null : tool.id,
|
||||
);
|
||||
}}
|
||||
className="absolute top-4 right-4 z-10 cursor-pointer"
|
||||
>
|
||||
<img
|
||||
src={ThreeDotsIcon}
|
||||
alt={t('settings.tools.settingsIconAlt')}
|
||||
className="h-[19px] w-[19px]"
|
||||
/>
|
||||
<ContextMenu
|
||||
isOpen={activeMenuId === tool.id}
|
||||
setIsOpen={(isOpen) => {
|
||||
setActiveMenuId(isOpen ? tool.id : null);
|
||||
{!tool.default && (
|
||||
<div
|
||||
ref={menuRefs.current[tool.id]}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setActiveMenuId(
|
||||
activeMenuId === tool.id ? null : tool.id,
|
||||
);
|
||||
}}
|
||||
options={getMenuOptions(tool)}
|
||||
anchorRef={menuRefs.current[tool.id]}
|
||||
position="bottom-right"
|
||||
offset={{ x: 0, y: 0 }}
|
||||
/>
|
||||
</div>
|
||||
className="absolute top-4 right-4 z-10 cursor-pointer"
|
||||
>
|
||||
<img
|
||||
src={ThreeDotsIcon}
|
||||
alt={t('settings.tools.settingsIconAlt')}
|
||||
className="h-[19px] w-[19px]"
|
||||
/>
|
||||
<ContextMenu
|
||||
isOpen={activeMenuId === tool.id}
|
||||
setIsOpen={(isOpen) => {
|
||||
setActiveMenuId(isOpen ? tool.id : null);
|
||||
}}
|
||||
options={getMenuOptions(tool)}
|
||||
anchorRef={menuRefs.current[tool.id]}
|
||||
position="bottom-right"
|
||||
offset={{ x: 0, y: 0 }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-2 px-1">
|
||||
<img
|
||||
@@ -315,6 +325,11 @@ export default function Tools() {
|
||||
alt={`${tool.displayName} icon`}
|
||||
className="h-6 w-6"
|
||||
/>
|
||||
{tool.default && (
|
||||
<span className="inline-flex items-center rounded-full bg-gray-100 px-2 py-0.5 text-[10px] leading-none font-medium text-gray-600 dark:bg-gray-700/40 dark:text-gray-300">
|
||||
{t('settings.tools.builtIn')}
|
||||
</span>
|
||||
)}
|
||||
{tool.name === 'mcp_tool' &&
|
||||
mcpStatuses[tool.id] && (
|
||||
<span
|
||||
|
||||
@@ -47,6 +47,14 @@ export type UserToolType = {
|
||||
customName?: string;
|
||||
description: string;
|
||||
status: boolean;
|
||||
// True for built-in default chat tools — managed via the opt-out list,
|
||||
// not a user_tools row; not deletable. ``scheduler`` is dual-registered
|
||||
// (both ``default`` and ``builtin``).
|
||||
default?: boolean;
|
||||
// True for agent-selectable builtins (e.g. ``scheduler``) — hidden
|
||||
// from the Add-Tool modal; surfaced to the agent picker. May coexist
|
||||
// with ``default`` for dual-registered tools.
|
||||
builtin?: boolean;
|
||||
config: {
|
||||
[key: string]: any;
|
||||
};
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import { configureStore } from '@reduxjs/toolkit';
|
||||
|
||||
import agentPreviewReducer from './agents/agentPreviewSlice';
|
||||
import schedulesReducer from './agents/schedules/schedulesSlice';
|
||||
import workflowPreviewReducer from './agents/workflow/workflowPreviewSlice';
|
||||
import { conversationSlice } from './conversation/conversationSlice';
|
||||
import {
|
||||
conversationListenerMiddleware,
|
||||
conversationSlice,
|
||||
} from './conversation/conversationSlice';
|
||||
import { sharedConversationSlice } from './conversation/sharedConversationSlice';
|
||||
import notificationsReducer from './notifications/notificationsSlice';
|
||||
import { getStoredRecentDocs } from './preferences/preferenceApi';
|
||||
@@ -69,9 +73,13 @@ const store = configureStore({
|
||||
agentPreview: agentPreviewReducer,
|
||||
workflowPreview: workflowPreviewReducer,
|
||||
notifications: notificationsReducer,
|
||||
schedules: schedulesReducer,
|
||||
},
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware().concat(prefListenerMiddleware.middleware),
|
||||
getDefaultMiddleware().concat(
|
||||
prefListenerMiddleware.middleware,
|
||||
conversationListenerMiddleware.middleware,
|
||||
),
|
||||
});
|
||||
|
||||
export type RootState = ReturnType<typeof store.getState>;
|
||||
|
||||
@@ -239,9 +239,14 @@ class TestBaseAgentTools:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
|
||||
assert len(tools) == 2
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
|
||||
assert len(tools) == 2 + len(loaded_default_tools())
|
||||
assert "0" in tools
|
||||
assert "1" in tools
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert {"tool1", "tool2"}.issubset(names)
|
||||
assert set(loaded_default_tools()).issubset(names)
|
||||
|
||||
def test_get_user_tools_filters_by_status(
|
||||
self,
|
||||
@@ -268,7 +273,12 @@ class TestBaseAgentTools:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
|
||||
assert len(tools) == 1
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
|
||||
assert len(tools) == 1 + len(loaded_default_tools())
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert "tool1" in names
|
||||
assert "tool2" not in names
|
||||
|
||||
def test_get_tools_by_api_key(
|
||||
self,
|
||||
@@ -305,7 +315,13 @@ class TestBaseAgentTools:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_tools("api_key_123")
|
||||
|
||||
assert tool_id in tools
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
|
||||
# Agent-bound: exactly agents.tools, no defaults.
|
||||
assert set(tools) == {tool_id}
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert names == {"api_tool"}
|
||||
assert not (set(loaded_default_tools()) & names)
|
||||
|
||||
def test_build_tool_parameters(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
|
||||
398
tests/agents/test_default_tools.py
Normal file
398
tests/agents/test_default_tools.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""Tests for application.agents.default_tools — the default chat tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents import default_tools
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tool_cache():
|
||||
"""Drop the module caches so settings overrides take effect."""
|
||||
def _clear():
|
||||
default_tools._tool_cache.clear()
|
||||
default_tools._ids_cache.clear()
|
||||
default_tools._loaded_cache.clear()
|
||||
default_tools._builtin_ids_cache.clear()
|
||||
default_tools._builtin_loaded_cache.clear()
|
||||
|
||||
_clear()
|
||||
yield
|
||||
_clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic ids
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestSyntheticIds:
|
||||
def test_default_tool_id_is_a_valid_uuid(self):
|
||||
tool_id = default_tools.default_tool_id("memory")
|
||||
assert str(uuid.UUID(tool_id)) == tool_id
|
||||
|
||||
def test_default_tool_id_is_deterministic(self):
|
||||
assert default_tools.default_tool_id("memory") == default_tools.default_tool_id(
|
||||
"memory"
|
||||
)
|
||||
|
||||
def test_distinct_names_get_distinct_ids(self):
|
||||
assert default_tools.default_tool_id("memory") != default_tools.default_tool_id(
|
||||
"read_webpage"
|
||||
)
|
||||
|
||||
def test_default_tool_ids_covers_configured_set(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "scheduler"]
|
||||
)
|
||||
ids = default_tools.default_tool_ids()
|
||||
assert set(ids) == {"memory", "scheduler"}
|
||||
|
||||
def test_default_tool_ids_is_memoized(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "scheduler"]
|
||||
)
|
||||
first = default_tools.default_tool_ids()
|
||||
assert default_tools.default_tool_ids() is first
|
||||
|
||||
def test_default_tool_ids_rebuilds_when_setting_changes(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory"]
|
||||
)
|
||||
assert set(default_tools.default_tool_ids()) == {"memory"}
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "read_webpage"]
|
||||
)
|
||||
assert set(default_tools.default_tool_ids()) == {"memory", "read_webpage"}
|
||||
|
||||
def test_is_default_tool_id_recognises_synthetic_ids(self):
|
||||
assert default_tools.is_default_tool_id(
|
||||
default_tools.default_tool_id("memory")
|
||||
)
|
||||
|
||||
def test_is_default_tool_id_rejects_random_uuid(self):
|
||||
assert not default_tools.is_default_tool_id(str(uuid.uuid4()))
|
||||
|
||||
def test_is_default_tool_id_rejects_empty(self):
|
||||
assert not default_tools.is_default_tool_id(None)
|
||||
assert not default_tools.is_default_tool_id("")
|
||||
|
||||
def test_name_for_id_round_trip(self):
|
||||
tool_id = default_tools.default_tool_id("read_webpage")
|
||||
assert default_tools.default_tool_name_for_id(tool_id) == "read_webpage"
|
||||
|
||||
def test_name_for_id_unknown_returns_none(self):
|
||||
assert default_tools.default_tool_name_for_id(str(uuid.uuid4())) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup validation
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestValidation:
|
||||
def test_unimplemented_tool_is_skipped_not_an_error(self, monkeypatch, caplog):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings,
|
||||
"DEFAULT_CHAT_TOOLS",
|
||||
["memory", "read_webpage", "future_tool_x"],
|
||||
)
|
||||
with caplog.at_level("DEBUG", logger="application.agents.default_tools"):
|
||||
usable = default_tools.validate_default_chat_tools()
|
||||
assert "future_tool_x" not in usable
|
||||
assert "memory" in usable and "read_webpage" in usable
|
||||
assert any(
|
||||
"future_tool_x" in rec.message and rec.levelname == "DEBUG"
|
||||
for rec in caplog.records
|
||||
)
|
||||
assert not any(rec.levelname == "WARNING" for rec in caplog.records)
|
||||
|
||||
def test_loaded_default_tools_is_silent(self, monkeypatch, caplog):
|
||||
# Runs per request — must never log.
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings,
|
||||
"DEFAULT_CHAT_TOOLS",
|
||||
["memory", "read_webpage", "future_tool_x"],
|
||||
)
|
||||
with caplog.at_level("DEBUG", logger="application.agents.default_tools"):
|
||||
default_tools.loaded_default_tools()
|
||||
assert caplog.records == []
|
||||
|
||||
def test_fk_bound_tool_is_rejected(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "notes"]
|
||||
)
|
||||
with pytest.raises(ValueError, match="notes"):
|
||||
default_tools.validate_default_chat_tools()
|
||||
|
||||
def test_fk_bound_todo_list_is_rejected(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "todo_list"]
|
||||
)
|
||||
with pytest.raises(ValueError, match="todo_list"):
|
||||
default_tools.validate_default_chat_tools()
|
||||
|
||||
def test_fully_unknown_name_is_skipped(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings,
|
||||
"DEFAULT_CHAT_TOOLS",
|
||||
["memory", "definitely_not_a_real_tool"],
|
||||
)
|
||||
usable = default_tools.validate_default_chat_tools()
|
||||
assert usable == ["memory"]
|
||||
|
||||
def test_config_free_tools_pass(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "read_webpage"]
|
||||
)
|
||||
assert default_tools.validate_default_chat_tools() == [
|
||||
"memory",
|
||||
"read_webpage",
|
||||
]
|
||||
|
||||
def test_scheduler_is_config_free(self, monkeypatch):
|
||||
# Dual-registration only works if scheduler passes the config-free
|
||||
# assertion — otherwise startup would reject DEFAULT_CHAT_TOOLS.
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["scheduler"]
|
||||
)
|
||||
assert default_tools.validate_default_chat_tools() == ["scheduler"]
|
||||
|
||||
def test_tool_with_required_config_is_rejected(self, monkeypatch):
|
||||
# ``brave`` needs an API key.
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings, "DEFAULT_CHAT_TOOLS", ["memory", "brave"]
|
||||
)
|
||||
with pytest.raises(ValueError, match="brave"):
|
||||
default_tools.validate_default_chat_tools()
|
||||
|
||||
def test_loaded_default_tools_filters_unimplemented(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
default_tools.settings,
|
||||
"DEFAULT_CHAT_TOOLS",
|
||||
["memory", "read_webpage", "future_tool_x"],
|
||||
)
|
||||
assert default_tools.loaded_default_tools() == ["memory", "read_webpage"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthesized rows
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestSynthesize:
|
||||
def test_synthesize_returns_row_shaped_entry(self):
|
||||
row = default_tools.synthesize_default_tool("memory")
|
||||
assert row is not None
|
||||
assert row["name"] == "memory"
|
||||
assert row["id"] == default_tools.default_tool_id("memory")
|
||||
assert row["id"] == row["_id"]
|
||||
assert row["config"] == {}
|
||||
assert row["config_requirements"] == {}
|
||||
assert row["status"] is True
|
||||
assert row["default"] is True
|
||||
assert isinstance(row["actions"], list) and row["actions"]
|
||||
|
||||
def test_synthesize_unknown_tool_returns_none(self):
|
||||
assert default_tools.synthesize_default_tool("future_tool_x") is None
|
||||
assert default_tools.synthesize_default_tool("nope") is None
|
||||
|
||||
def test_synthesize_includes_display_name(self):
|
||||
row = default_tools.synthesize_default_tool("read_webpage")
|
||||
assert row["display_name"]
|
||||
assert isinstance(row["description"], str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Opt-out list
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestDisabledList:
|
||||
def test_none_user_doc_yields_empty(self):
|
||||
assert default_tools.disabled_default_tools(None) == []
|
||||
|
||||
def test_missing_preferences_yields_empty(self):
|
||||
assert default_tools.disabled_default_tools({"user_id": "u"}) == []
|
||||
|
||||
def test_reads_disabled_list(self):
|
||||
doc = {"tool_preferences": {"disabled_default_tools": ["read_webpage"]}}
|
||||
assert default_tools.disabled_default_tools(doc) == ["read_webpage"]
|
||||
|
||||
def test_malformed_preferences_yields_empty(self):
|
||||
assert default_tools.disabled_default_tools(
|
||||
{"tool_preferences": "not-a-dict"}
|
||||
) == []
|
||||
assert default_tools.disabled_default_tools(
|
||||
{"tool_preferences": {"disabled_default_tools": "x"}}
|
||||
) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chat resolver — synthesized defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestSynthesizedDefaults:
|
||||
def test_all_defaults_present_when_nothing_disabled(self):
|
||||
rows = default_tools.synthesized_default_tools(None)
|
||||
names = {r["name"] for r in rows}
|
||||
assert names == set(default_tools.loaded_default_tools())
|
||||
|
||||
def test_opt_out_removes_a_tool(self):
|
||||
doc = {"tool_preferences": {"disabled_default_tools": ["read_webpage"]}}
|
||||
rows = default_tools.synthesized_default_tools(doc)
|
||||
names = {r["name"] for r in rows}
|
||||
assert "read_webpage" not in names
|
||||
assert "memory" in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# default_tools_for_management — the tool-management listing
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestDefaultToolsForManagement:
|
||||
def test_lists_every_loaded_default(self):
|
||||
rows = default_tools.default_tools_for_management(None)
|
||||
assert {r["name"] for r in rows} == set(
|
||||
default_tools.loaded_default_tools()
|
||||
)
|
||||
|
||||
def test_all_enabled_when_nothing_disabled(self):
|
||||
rows = default_tools.default_tools_for_management(None)
|
||||
assert all(r["status"] is True for r in rows)
|
||||
|
||||
def test_disabled_default_still_listed_with_status_false(self):
|
||||
doc = {"tool_preferences": {"disabled_default_tools": ["read_webpage"]}}
|
||||
rows = default_tools.default_tools_for_management(doc)
|
||||
by_name = {r["name"]: r for r in rows}
|
||||
assert "read_webpage" in by_name
|
||||
assert by_name["read_webpage"]["status"] is False
|
||||
assert by_name["memory"]["status"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_tool_by_id
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestResolveToolById:
|
||||
def test_synthetic_id_resolves_in_memory(self):
|
||||
tool_id = default_tools.default_tool_id("memory")
|
||||
row = default_tools.resolve_tool_by_id(tool_id, "user-x")
|
||||
assert row is not None
|
||||
assert row["name"] == "memory"
|
||||
assert row["id"] == tool_id
|
||||
|
||||
def test_non_default_id_delegates_to_repo(self):
|
||||
sentinel = {"id": "real", "name": "brave"}
|
||||
|
||||
class _Repo:
|
||||
def get_any(self, tool_id, user):
|
||||
assert user == "user-x"
|
||||
return sentinel
|
||||
|
||||
row = default_tools.resolve_tool_by_id(
|
||||
str(uuid.uuid4()), "user-x", user_tools_repo=_Repo()
|
||||
)
|
||||
assert row is sentinel
|
||||
|
||||
def test_non_default_id_without_repo_returns_none(self):
|
||||
assert default_tools.resolve_tool_by_id(str(uuid.uuid4()), "user-x") is None
|
||||
|
||||
def test_builtin_agent_tool_id_resolves_in_memory(self):
|
||||
"""Dual-registered scheduler resolves with BOTH ``default`` and
|
||||
``builtin`` flags so either path can branch on the discriminator."""
|
||||
tool_id = default_tools.default_tool_id("scheduler")
|
||||
row = default_tools.resolve_tool_by_id(tool_id, "user-x")
|
||||
assert row is not None
|
||||
assert row["name"] == "scheduler"
|
||||
assert row["builtin"] is True
|
||||
assert row["default"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent-selectable builtins (scheduler) — synthesized like defaults but
|
||||
# hidden from agentless-chat synthesis and from /api/available_tools.
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestBuiltinAgentTools:
|
||||
def test_scheduler_is_a_builtin(self):
|
||||
assert "scheduler" in default_tools.BUILTIN_AGENT_TOOLS
|
||||
|
||||
def test_scheduler_dual_registered_in_default_chat_tools(self):
|
||||
# Revised decision #8: scheduler is dual-registered as a default
|
||||
# chat tool (auto-on in agentless chats) AND a builtin agent tool
|
||||
# (opt-in via the agent picker). Both registries share the same
|
||||
# ``_DEFAULT_TOOL_NAMESPACE`` so the synthetic id is one stable uuid5.
|
||||
assert "scheduler" in default_tools.settings.DEFAULT_CHAT_TOOLS
|
||||
|
||||
def test_dual_registration_produces_one_synthetic_id(self):
|
||||
# Same uuid5 namespace → same id whether reached via defaults or builtins.
|
||||
as_default = default_tools.default_tool_id("scheduler")
|
||||
assert default_tools.is_default_tool_id(as_default)
|
||||
assert default_tools.is_builtin_agent_tool_id(as_default)
|
||||
|
||||
def test_builtin_id_is_recognised(self):
|
||||
tool_id = default_tools.default_tool_id("scheduler")
|
||||
assert default_tools.is_builtin_agent_tool_id(tool_id)
|
||||
assert default_tools.builtin_agent_tool_name_for_id(tool_id) == "scheduler"
|
||||
|
||||
def test_synthesize_builtin_marks_flags_correctly(self):
|
||||
row = default_tools.synthesize_builtin_agent_tool("scheduler")
|
||||
assert row is not None
|
||||
assert row["name"] == "scheduler"
|
||||
assert row["default"] is False
|
||||
assert row["builtin"] is True
|
||||
assert isinstance(row["actions"], list) and row["actions"]
|
||||
|
||||
def test_builtin_agent_tools_for_management_lists_scheduler(self):
|
||||
rows = default_tools.builtin_agent_tools_for_management()
|
||||
names = {r["name"] for r in rows}
|
||||
assert "scheduler" in names
|
||||
for row in rows:
|
||||
assert row["builtin"] is True
|
||||
assert row["default"] is False
|
||||
|
||||
def test_synthesized_default_chat_now_includes_scheduler(self):
|
||||
# Revised decision #8: scheduler is dual-registered → it appears in
|
||||
# ``synthesized_default_tools`` so agentless chats can use it.
|
||||
rows = default_tools.synthesized_default_tools(None)
|
||||
assert "scheduler" in {r["name"] for r in rows}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _FK_BOUND_TOOLS — schema introspection guard against rot
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.unit
|
||||
class TestFkBoundToolsIsInSync:
|
||||
# Table name -> tool module name (``application/agents/tools/<name>``).
|
||||
_TABLE_TO_TOOL = {
|
||||
"notes": "notes",
|
||||
"todos": "todo_list",
|
||||
}
|
||||
|
||||
def test_fk_bound_tools_matches_metadata(self):
|
||||
from application.storage.db.models import metadata
|
||||
|
||||
fk_bound_tables = set()
|
||||
for tbl in metadata.tables.values():
|
||||
tool_id_col = tbl.columns.get("tool_id")
|
||||
if tool_id_col is None:
|
||||
continue
|
||||
for fk in tool_id_col.foreign_keys:
|
||||
if fk.target_fullname == "user_tools.id":
|
||||
fk_bound_tables.add(tbl.name)
|
||||
break
|
||||
|
||||
unmapped = fk_bound_tables - set(self._TABLE_TO_TOOL)
|
||||
assert not unmapped, (
|
||||
f"New FK-bound table(s) without a tool mapping: {sorted(unmapped)}. "
|
||||
"Add an entry to _TABLE_TO_TOOL here AND to "
|
||||
"application.agents.default_tools._FK_BOUND_TOOLS."
|
||||
)
|
||||
derived_names = {
|
||||
self._TABLE_TO_TOOL[name] for name in fk_bound_tables
|
||||
}
|
||||
assert derived_names == set(default_tools._FK_BOUND_TOOLS), (
|
||||
"_FK_BOUND_TOOLS is out of sync with schema-derived names: "
|
||||
f"derived={sorted(derived_names)} "
|
||||
f"declared={sorted(default_tools._FK_BOUND_TOOLS)}"
|
||||
)
|
||||
187
tests/agents/test_scheduler_agent_builtin.py
Normal file
187
tests/agents/test_scheduler_agent_builtin.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Regression: scheduler stays out of the Add-Tool catalog but reaches the
|
||||
agent picker, the LLM tool schema, and the schedules table on execute."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
# Pre-import to stabilise the ToolManager.load_tools walk's import order.
|
||||
import application.api.user.tools.mcp # noqa: F401
|
||||
|
||||
from application.agents.default_tools import ( # noqa: E402
|
||||
BUILTIN_AGENT_TOOLS,
|
||||
builtin_agent_tools_for_management,
|
||||
default_tool_id,
|
||||
resolve_tool_by_id,
|
||||
)
|
||||
from application.agents.tool_executor import ToolExecutor # noqa: E402
|
||||
from application.agents.tools.tool_manager import ToolManager # noqa: E402
|
||||
from application.storage.db.repositories.schedules import ( # noqa: E402
|
||||
SchedulesRepository,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_scheduler_sessions(pg_conn):
|
||||
"""Redirect scheduler tool db session helpers to ``pg_conn``."""
|
||||
|
||||
@contextmanager
|
||||
def _ctx():
|
||||
yield pg_conn
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.scheduler.db_session", _ctx,
|
||||
), patch(
|
||||
"application.agents.tools.scheduler.db_readonly", _ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _make_agent(conn, *, user_id="alice", agent_tools=None) -> dict:
|
||||
"""Insert an agents row whose tools JSONB carries agent_tools."""
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO agents (user_id, name, status, key, tools)
|
||||
VALUES (:u, 'sched-agent', 'active', :k, CAST(:t AS jsonb))
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"u": user_id,
|
||||
"k": f"sk-{uuid.uuid4()}",
|
||||
"t": json.dumps(list(agent_tools or [])),
|
||||
},
|
||||
).fetchone()
|
||||
return dict(row._mapping)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAddToolCatalogHidesScheduler:
|
||||
def test_tool_manager_walks_skip_internal_scheduler(self):
|
||||
tm = ToolManager(config={})
|
||||
assert "scheduler" not in tm.tools
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentPickerExposesScheduler:
|
||||
def test_scheduler_is_listed_in_builtin_agent_tools(self):
|
||||
rows = builtin_agent_tools_for_management()
|
||||
assert any(r["name"] == "scheduler" for r in rows)
|
||||
assert "scheduler" in BUILTIN_AGENT_TOOLS
|
||||
|
||||
def test_scheduler_row_is_flagged_builtin_not_default(self):
|
||||
scheduler_row = next(
|
||||
r for r in builtin_agent_tools_for_management()
|
||||
if r["name"] == "scheduler"
|
||||
)
|
||||
assert scheduler_row["builtin"] is True
|
||||
assert scheduler_row["default"] is False
|
||||
|
||||
def test_synthetic_id_resolves_to_row_with_schedule_task_action(self):
|
||||
synthetic_id = default_tool_id("scheduler")
|
||||
row = resolve_tool_by_id(synthetic_id, "alice")
|
||||
assert row is not None
|
||||
assert row["name"] == "scheduler"
|
||||
action_names = {a["name"] for a in row.get("actions") or []}
|
||||
assert "schedule_task" in action_names
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDualRegistration:
|
||||
"""``scheduler`` is in both registries; same uuid5 resolves either way."""
|
||||
|
||||
def test_scheduler_in_both_registries(self):
|
||||
from application.agents.default_tools import (
|
||||
BUILTIN_AGENT_TOOLS as BUILTINS,
|
||||
settings,
|
||||
)
|
||||
assert "scheduler" in BUILTINS
|
||||
assert "scheduler" in settings.DEFAULT_CHAT_TOOLS
|
||||
|
||||
def test_same_synthetic_id_in_both_paths(self):
|
||||
from application.agents.default_tools import (
|
||||
builtin_agent_tool_ids,
|
||||
default_tool_ids,
|
||||
)
|
||||
via_default = default_tool_ids().get("scheduler")
|
||||
via_builtin = builtin_agent_tool_ids().get("scheduler")
|
||||
assert via_default == via_builtin
|
||||
assert via_default is not None
|
||||
|
||||
def test_synthesized_default_tools_includes_scheduler(self):
|
||||
"""Agentless chats see scheduler in the default-tools synthesis."""
|
||||
from application.agents.default_tools import synthesized_default_tools
|
||||
|
||||
rows = synthesized_default_tools(None)
|
||||
names = {r["name"] for r in rows}
|
||||
assert "scheduler" in names
|
||||
|
||||
def test_synthesized_builtin_agent_tools_includes_scheduler(self):
|
||||
"""Agent picker still sees scheduler via the builtin registry."""
|
||||
from application.agents.default_tools import (
|
||||
builtin_agent_tools_for_management,
|
||||
)
|
||||
|
||||
rows = builtin_agent_tools_for_management()
|
||||
names = {r["name"] for r in rows}
|
||||
assert "scheduler" in names
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEndToEndAgentPickerToLLMSchema:
|
||||
def test_agent_with_scheduler_in_tools_exposes_schedule_task_to_llm(
|
||||
self, pg_conn,
|
||||
):
|
||||
scheduler_id = default_tool_id("scheduler")
|
||||
agent = _make_agent(pg_conn, agent_tools=[scheduler_id])
|
||||
|
||||
@contextmanager
|
||||
def _use_conn():
|
||||
yield pg_conn
|
||||
|
||||
with patch("application.agents.tool_executor.db_readonly", _use_conn):
|
||||
executor = ToolExecutor(
|
||||
user_api_key=agent["key"], user="alice",
|
||||
agent_id=str(agent["id"]),
|
||||
)
|
||||
tools_dict = executor.get_tools()
|
||||
|
||||
assert scheduler_id in tools_dict
|
||||
row = tools_dict[scheduler_id]
|
||||
assert row["name"] == "scheduler"
|
||||
|
||||
schema = executor.prepare_tools_for_llm(tools_dict)
|
||||
function_names = {entry["function"]["name"] for entry in schema}
|
||||
assert "schedule_task" in function_names
|
||||
|
||||
def test_executing_schedule_task_creates_one_time_schedule(
|
||||
self, pg_conn, patch_scheduler_sessions,
|
||||
):
|
||||
agent = _make_agent(pg_conn)
|
||||
agent_id = str(agent["id"])
|
||||
user_id = "alice"
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
"scheduler",
|
||||
tool_config={"agent_id": agent_id, "conversation_id": None},
|
||||
user_id=user_id,
|
||||
)
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="ping me later", delay="1h",
|
||||
)
|
||||
parsed = json.loads(out)
|
||||
assert "task_id" in parsed
|
||||
|
||||
row = SchedulesRepository(pg_conn).get(parsed["task_id"], user_id)
|
||||
assert row is not None
|
||||
assert row["trigger_type"] == "once"
|
||||
assert row["status"] == "active"
|
||||
assert row["created_via"] == "chat"
|
||||
146
tests/agents/test_scheduler_utils.py
Normal file
146
tests/agents/test_scheduler_utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for scheduler_utils (cron / DST / delay / horizon)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.scheduler_utils import (
|
||||
ScheduleValidationError,
|
||||
clamp_once_horizon,
|
||||
cron_interval_seconds,
|
||||
next_cron_run,
|
||||
parse_cron,
|
||||
parse_delay,
|
||||
parse_run_at,
|
||||
resolve_timezone,
|
||||
)
|
||||
|
||||
|
||||
class TestParseCron:
|
||||
def test_valid(self):
|
||||
parse_cron("0 9 * * 1")
|
||||
|
||||
def test_invalid(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
parse_cron("not a cron")
|
||||
|
||||
def test_wrong_field_count(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
parse_cron("0 9 * *")
|
||||
|
||||
|
||||
class TestNextCronRunDST:
|
||||
def test_daily_9am_warsaw_across_spring_forward(self):
|
||||
tz = ZoneInfo("Europe/Warsaw")
|
||||
before_dst = datetime(2026, 3, 28, 9, 30, tzinfo=tz)
|
||||
nxt = next_cron_run("0 9 * * *", "Europe/Warsaw", after=before_dst)
|
||||
assert nxt.astimezone(tz) == datetime(2026, 3, 29, 9, 0, tzinfo=tz)
|
||||
|
||||
def test_daily_9am_warsaw_across_fall_back(self):
|
||||
tz = ZoneInfo("Europe/Warsaw")
|
||||
before_dst = datetime(2026, 10, 24, 9, 30, tzinfo=tz)
|
||||
nxt = next_cron_run("0 9 * * *", "Europe/Warsaw", after=before_dst)
|
||||
assert nxt.astimezone(tz) == datetime(2026, 10, 25, 9, 0, tzinfo=tz)
|
||||
|
||||
def test_utc_default(self):
|
||||
anchor = datetime(2026, 5, 19, 12, 0, tzinfo=timezone.utc)
|
||||
nxt = next_cron_run("0 * * * *", None, after=anchor)
|
||||
assert nxt > anchor
|
||||
assert nxt.tzinfo is not None
|
||||
|
||||
def test_returned_value_is_utc(self):
|
||||
anchor = datetime(2026, 5, 19, 12, 0, tzinfo=timezone.utc)
|
||||
nxt = next_cron_run("0 9 * * *", "Europe/Warsaw", after=anchor)
|
||||
assert nxt.tzinfo is not None
|
||||
assert nxt.utcoffset() == timedelta(0)
|
||||
|
||||
|
||||
class TestResolveTimezone:
|
||||
def test_unknown(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
resolve_timezone("Atlantis/Nowhere")
|
||||
|
||||
def test_blank_defaults_utc(self):
|
||||
assert resolve_timezone("").key == "UTC"
|
||||
assert resolve_timezone(None).key == "UTC"
|
||||
|
||||
|
||||
class TestParseDelay:
|
||||
@pytest.mark.parametrize(
|
||||
"raw,seconds",
|
||||
[("30s", 30), ("15m", 900), ("2h", 7200), ("1d", 86_400)],
|
||||
)
|
||||
def test_units(self, raw, seconds):
|
||||
assert parse_delay(raw).total_seconds() == seconds
|
||||
|
||||
def test_uppercase(self):
|
||||
assert parse_delay("2H").total_seconds() == 7200
|
||||
|
||||
def test_zero_rejected(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
parse_delay("0m")
|
||||
|
||||
def test_garbage(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
parse_delay("two hours")
|
||||
|
||||
|
||||
class TestParseRunAt:
|
||||
def test_iso_utc(self):
|
||||
parsed = parse_run_at("2026-05-19T12:00:00Z")
|
||||
assert parsed.tzinfo is not None
|
||||
assert parsed == datetime(2026, 5, 19, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
def test_iso_with_offset(self):
|
||||
parsed = parse_run_at("2026-05-19T14:00:00+02:00")
|
||||
assert parsed == datetime(2026, 5, 19, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
def test_naive_uses_tz(self):
|
||||
parsed = parse_run_at("2026-05-19T14:00:00", "Europe/Warsaw")
|
||||
assert parsed == datetime(2026, 5, 19, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
def test_invalid(self):
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
parse_run_at("not a date")
|
||||
|
||||
|
||||
class TestCronIntervalSeconds:
|
||||
def test_every_minute_returns_60s(self):
|
||||
assert cron_interval_seconds("* * * * *", None) == 60
|
||||
|
||||
def test_hourly_returns_3600s(self):
|
||||
assert cron_interval_seconds("0 * * * *", None) == 3600
|
||||
|
||||
def test_bursty_cron_returns_smallest_gap(self):
|
||||
# '* 9 * * *' has 60s gaps inside the 9 AM burst; sampling two adjacent
|
||||
# ticks at random can miss them — the rolling window must catch the 60.
|
||||
assert cron_interval_seconds("* 9 * * *", None) == 60
|
||||
|
||||
def test_bursty_cron_rejected_when_floor_above_burst(self):
|
||||
from application.core.settings import settings as app_settings
|
||||
burst = "* 9 * * *"
|
||||
cadence = cron_interval_seconds(burst, None)
|
||||
floor = max(0, int(app_settings.SCHEDULE_MIN_INTERVAL))
|
||||
assert cadence < floor, (
|
||||
f"bursty cron {burst!r} cadence {cadence}s must be below the "
|
||||
f"configured SCHEDULE_MIN_INTERVAL floor ({floor}s)"
|
||||
)
|
||||
|
||||
|
||||
class TestClampOnceHorizon:
|
||||
def test_rejects_past(self):
|
||||
past = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
clamp_once_horizon(past, max_horizon_seconds=3600)
|
||||
|
||||
def test_rejects_beyond_horizon(self):
|
||||
far = datetime.now(timezone.utc) + timedelta(days=400)
|
||||
with pytest.raises(ScheduleValidationError):
|
||||
clamp_once_horizon(far, max_horizon_seconds=365 * 86_400)
|
||||
|
||||
def test_accepts_in_range(self):
|
||||
soon = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
clamp_once_horizon(soon, max_horizon_seconds=86_400)
|
||||
@@ -13,16 +13,21 @@ class TestToolExecutorInit:
|
||||
executor = ToolExecutor()
|
||||
assert executor.user_api_key is None
|
||||
assert executor.user is None
|
||||
assert executor.agent_id is None
|
||||
assert executor.tool_calls == []
|
||||
assert executor._loaded_tools == {}
|
||||
assert executor.conversation_id is None
|
||||
|
||||
def test_init_with_params(self):
|
||||
executor = ToolExecutor(
|
||||
user_api_key="key", user="alice", decoded_token={"sub": "alice"}
|
||||
user_api_key="key",
|
||||
user="alice",
|
||||
decoded_token={"sub": "alice"},
|
||||
agent_id="agent-1",
|
||||
)
|
||||
assert executor.user_api_key == "key"
|
||||
assert executor.user == "alice"
|
||||
assert executor.agent_id == "agent-1"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -61,7 +66,8 @@ class TestToolExecutorGetTools:
|
||||
assert str(tool["id"]) in tools
|
||||
assert tools[str(tool["id"])]["id"] == tool["id"]
|
||||
|
||||
def test_get_tools_uses_user_when_no_api_key(self, pg_conn, monkeypatch):
|
||||
def test_agentless_chat_synthesizes_defaults(self, pg_conn, monkeypatch):
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
|
||||
UserToolsRepository(pg_conn).create(
|
||||
@@ -72,15 +78,148 @@ class TestToolExecutorGetTools:
|
||||
executor = ToolExecutor(user="alice")
|
||||
tools = executor.get_tools()
|
||||
assert isinstance(tools, dict)
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 1 + len(loaded_default_tools())
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert "tool1" in names
|
||||
assert "memory" in names
|
||||
|
||||
def test_agent_bound_chat_via_user_path_excludes_defaults(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
"""``agent_id`` forces ``agents.tools``-only; no defaults synthesized."""
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
|
||||
UserToolsRepository(pg_conn).create(
|
||||
user_id="alice", name="tool1", status=True
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(user="alice", agent_id="agent-x")
|
||||
tools = executor.get_tools()
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert "tool1" in names
|
||||
assert not (set(loaded_default_tools()) & names)
|
||||
|
||||
def test_get_tools_defaults_to_local(self, pg_conn, monkeypatch):
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor()
|
||||
tools = executor.get_tools()
|
||||
assert isinstance(tools, dict)
|
||||
assert tools == {}
|
||||
assert len(tools) == len(loaded_default_tools())
|
||||
assert {t["name"] for t in tools.values()} == set(loaded_default_tools())
|
||||
|
||||
def test_api_key_path_excludes_defaults(self, pg_conn, monkeypatch):
|
||||
"""Agent-bound resolution returns exactly ``agents.tools``."""
|
||||
from application.agents.default_tools import loaded_default_tools
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
|
||||
tool = UserToolsRepository(pg_conn).create(user_id="alice", name="tool1")
|
||||
AgentsRepository(pg_conn).create(
|
||||
user_id="alice",
|
||||
name="a",
|
||||
status="active",
|
||||
key="key-agentbound",
|
||||
tools=[str(tool["id"])],
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(user_api_key="key-agentbound", user="alice")
|
||||
tools = executor.get_tools()
|
||||
names = {t["name"] for t in tools.values()}
|
||||
assert names == {"tool1"}
|
||||
assert not (set(loaded_default_tools()) & names)
|
||||
|
||||
def test_api_key_path_empty_agent_tools_gets_nothing(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
"""Empty ``agents.tools`` invoked via API key yields no tools."""
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
AgentsRepository(pg_conn).create(
|
||||
user_id="bob",
|
||||
name="a",
|
||||
status="active",
|
||||
key="key-empty",
|
||||
tools=[],
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(user_api_key="key-empty", user="bob")
|
||||
assert executor.get_tools() == {}
|
||||
|
||||
def test_api_key_path_only_synthesizes_author_added_defaults(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
"""Only ``read_webpage`` in ``agents.tools`` -> exactly that; no other defaults bolted on."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
read_webpage_id = default_tool_id("read_webpage")
|
||||
memory_id = default_tool_id("memory")
|
||||
AgentsRepository(pg_conn).create(
|
||||
user_id="erin",
|
||||
name="a",
|
||||
status="active",
|
||||
key="key-only-read",
|
||||
tools=[read_webpage_id],
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(
|
||||
user_api_key="key-only-read", user="erin", agent_id="erin-agent"
|
||||
)
|
||||
tools = executor.get_tools()
|
||||
assert set(tools) == {read_webpage_id}
|
||||
assert tools[read_webpage_id]["name"] == "read_webpage"
|
||||
assert memory_id not in tools
|
||||
assert "memory" not in {t["name"] for t in tools.values()}
|
||||
|
||||
def test_explicit_default_on_agent_resolves(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
"""A default tool added explicitly to ``agents.tools`` resolves for every caller."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
memory_id = default_tool_id("memory")
|
||||
AgentsRepository(pg_conn).create(
|
||||
user_id="erin",
|
||||
name="a",
|
||||
status="active",
|
||||
key="key-explicit-default",
|
||||
tools=[memory_id],
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(
|
||||
user_api_key="key-explicit-default", user="erin"
|
||||
)
|
||||
tools = executor.get_tools()
|
||||
assert set(tools) == {memory_id}
|
||||
assert tools[memory_id]["name"] == "memory"
|
||||
|
||||
def test_no_dedup_between_explicit_and_default_memory(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
|
||||
# Explicit ``memory`` row and the default ``memory`` coexist (separate stores).
|
||||
UserToolsRepository(pg_conn).create(
|
||||
user_id="dave", name="memory", status=True
|
||||
)
|
||||
self._patch_conn(monkeypatch, pg_conn)
|
||||
|
||||
executor = ToolExecutor(user="dave")
|
||||
tools = executor.get_tools()
|
||||
memory_entries = [t for t in tools.values() if t["name"] == "memory"]
|
||||
assert len(memory_entries) == 2
|
||||
ids = {t["id"] for t in memory_entries}
|
||||
assert len(ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
205
tests/agents/test_tool_executor_headless.py
Normal file
205
tests/agents/test_tool_executor_headless.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Headless mode + tool allowlist enforcement on ToolExecutor.check_pause."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
|
||||
|
||||
def _call(name: str, args: dict | None = None, call_id: str = "c1"):
|
||||
import json
|
||||
return SimpleNamespace(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=json.dumps(args or {}),
|
||||
thought_signature=None,
|
||||
)
|
||||
|
||||
|
||||
def _executor(*, headless=False, allowlist=None):
|
||||
ex = ToolExecutor(headless=headless, tool_allowlist=allowlist or [])
|
||||
ex._name_to_tool = {
|
||||
"send": ("tool-a", "send"),
|
||||
"freecall": ("tool-b", "freecall"),
|
||||
"client_only": ("ct0", "client_only"),
|
||||
}
|
||||
return ex
|
||||
|
||||
|
||||
def _tools_dict():
|
||||
return {
|
||||
"tool-a": {
|
||||
"id": "tool-a",
|
||||
"name": "telegram",
|
||||
"actions": [
|
||||
{"name": "send", "require_approval": True},
|
||||
],
|
||||
},
|
||||
"tool-b": {
|
||||
"id": "tool-b",
|
||||
"name": "noop",
|
||||
"actions": [
|
||||
{"name": "freecall", "require_approval": False},
|
||||
],
|
||||
},
|
||||
"ct0": {
|
||||
"name": "client_only",
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{"name": "client_only"},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestHeadlessApproval:
|
||||
def test_denied_when_not_in_allowlist(self):
|
||||
ex = _executor(headless=True, allowlist=[])
|
||||
result = ex.check_pause(_tools_dict(), _call("send"), "MockLLM")
|
||||
assert result is not None
|
||||
assert result["pause_type"] == "headless_denied"
|
||||
assert result["error_type"] == "tool_not_allowed"
|
||||
|
||||
def test_allowed_when_in_allowlist(self):
|
||||
ex = _executor(headless=True, allowlist=["tool-a"])
|
||||
assert ex.check_pause(_tools_dict(), _call("send"), "MockLLM") is None
|
||||
|
||||
def test_non_approval_tool_runs_freely(self):
|
||||
ex = _executor(headless=True, allowlist=[])
|
||||
assert ex.check_pause(_tools_dict(), _call("freecall"), "MockLLM") is None
|
||||
|
||||
|
||||
class TestHeadlessClientSide:
|
||||
def test_client_side_always_denied_in_headless(self):
|
||||
# Client-side ignores the allowlist; no headless answer is possible.
|
||||
ex = _executor(headless=True, allowlist=["ct0"])
|
||||
result = ex.check_pause(_tools_dict(), _call("client_only"), "MockLLM")
|
||||
assert result is not None
|
||||
assert result["pause_type"] == "headless_denied"
|
||||
|
||||
|
||||
class TestNormalModeUnchanged:
|
||||
def test_approval_still_pauses_without_headless(self):
|
||||
ex = _executor(headless=False)
|
||||
result = ex.check_pause(_tools_dict(), _call("send"), "MockLLM")
|
||||
assert result["pause_type"] == "awaiting_approval"
|
||||
|
||||
def test_client_side_still_pauses_without_headless(self):
|
||||
ex = _executor(headless=False)
|
||||
result = ex.check_pause(_tools_dict(), _call("client_only"), "MockLLM")
|
||||
assert result["pause_type"] == "requires_client_execution"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler exclusion in headless runs — chat-only tool must not appear in
|
||||
# the toolset when a scheduled / webhook LLM runs, else it could re-schedule.
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestHeadlessSchedulerExclusion:
|
||||
def test_synthesized_default_tools_drops_scheduler_in_headless(self):
|
||||
from application.agents.default_tools import (
|
||||
loaded_default_tools,
|
||||
synthesized_default_tools,
|
||||
)
|
||||
|
||||
# Sanity: scheduler is on for normal chats…
|
||||
names_chat = {r["name"] for r in synthesized_default_tools(None)}
|
||||
if "scheduler" in loaded_default_tools():
|
||||
assert "scheduler" in names_chat
|
||||
# …and silently absent for headless runs.
|
||||
names_headless = {
|
||||
r["name"]
|
||||
for r in synthesized_default_tools(None, headless=True)
|
||||
}
|
||||
assert "scheduler" not in names_headless
|
||||
|
||||
def test_get_user_tools_filters_scheduler_when_headless(
|
||||
self, monkeypatch,
|
||||
):
|
||||
from application.agents import tool_executor as te_module
|
||||
from application.agents.default_tools import (
|
||||
default_tool_id,
|
||||
loaded_default_tools,
|
||||
)
|
||||
|
||||
if "scheduler" not in loaded_default_tools():
|
||||
import pytest as _pytest # local alias to keep top-of-module noise low
|
||||
_pytest.skip("scheduler not loaded in this env")
|
||||
|
||||
# Stub the DB layer: no explicit user_tools so the synthesized
|
||||
# defaults are the only ``scheduler`` source — that path is what
|
||||
# this test pins.
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def _fake_readonly():
|
||||
yield object()
|
||||
|
||||
monkeypatch.setattr(te_module, "db_readonly", _fake_readonly)
|
||||
monkeypatch.setattr(
|
||||
te_module, "UserToolsRepository",
|
||||
lambda _c: type("R", (), {
|
||||
"list_active_for_user": lambda _self, _u: [],
|
||||
})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
te_module, "UsersRepository",
|
||||
lambda _c: type("R", (), {
|
||||
"get": lambda _self, _u: None,
|
||||
})(),
|
||||
)
|
||||
|
||||
sched_id = default_tool_id("scheduler")
|
||||
|
||||
ex_chat = te_module.ToolExecutor(headless=False)
|
||||
tools_chat = ex_chat._get_user_tools("u-test")
|
||||
assert sched_id in tools_chat
|
||||
|
||||
ex_headless = te_module.ToolExecutor(headless=True)
|
||||
tools_headless = ex_headless._get_user_tools("u-test")
|
||||
assert sched_id not in tools_headless
|
||||
|
||||
def test_get_tools_by_api_key_drops_scheduler_when_headless(
|
||||
self, monkeypatch,
|
||||
):
|
||||
"""An agent-bound headless run (e.g. webhook) skips scheduler even if
|
||||
the author added the synthetic id to ``agents.tools``."""
|
||||
from application.agents import tool_executor as te_module
|
||||
from application.agents.default_tools import default_tool_id
|
||||
|
||||
sched_id = default_tool_id("scheduler")
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def _fake_readonly():
|
||||
yield object()
|
||||
|
||||
class _AgentsRepo:
|
||||
def __init__(self, _conn):
|
||||
pass
|
||||
|
||||
def find_by_key(self, _k):
|
||||
return {"user_id": "u1", "tools": [sched_id]}
|
||||
|
||||
class _UTRepo:
|
||||
def __init__(self, _conn):
|
||||
pass
|
||||
|
||||
def get_any(self, _t, _u):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(te_module, "db_readonly", _fake_readonly)
|
||||
monkeypatch.setattr(te_module, "AgentsRepository", _AgentsRepo)
|
||||
monkeypatch.setattr(te_module, "UserToolsRepository", _UTRepo)
|
||||
|
||||
ex_normal = te_module.ToolExecutor(
|
||||
user_api_key="k", headless=False, agent_id="a",
|
||||
)
|
||||
tools_normal = ex_normal._get_tools_by_api_key("k")
|
||||
assert sched_id in tools_normal
|
||||
|
||||
ex_headless = te_module.ToolExecutor(
|
||||
user_api_key="k", headless=True, agent_id="a",
|
||||
)
|
||||
tools_headless = ex_headless._get_tools_by_api_key("k")
|
||||
assert sched_id not in tools_headless
|
||||
@@ -250,3 +250,39 @@ class TestRepository:
|
||||
row = _select_attempt(pg_conn, "c-y")
|
||||
assert row["status"] == "failed"
|
||||
assert row["error"] == "kaboom"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultToolJournaling:
|
||||
"""A default tool's synthetic id round-trips through execute/journal."""
|
||||
|
||||
def test_synthetic_tool_id_is_journaled(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
from application.agents.default_tools import synthesize_default_tool
|
||||
|
||||
memory_row = synthesize_default_tool("memory")
|
||||
assert memory_row is not None
|
||||
tools_dict = {memory_row["id"]: memory_row}
|
||||
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(
|
||||
return_value=(memory_row["id"], "view", {"path": "/"})
|
||||
)
|
||||
),
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
events, result = _drain(
|
||||
executor.execute(tools_dict, _make_call(call_id="c-def"), "MockLLM")
|
||||
)
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
row = _select_attempt(pg_conn, "c-def")
|
||||
assert row is not None
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["tool_name"] == "memory"
|
||||
assert str(row["tool_id"]) == memory_row["id"]
|
||||
|
||||
442
tests/agents/tools/test_scheduler.py
Normal file
442
tests/agents/tools/test_scheduler.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""Tests for the SchedulerTool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
# Pre-import to stabilise the ToolManager.load_tools walk's import order
|
||||
# (avoids the mcp_tool ↔ application.api.user circular when ToolManager
|
||||
# instantiation is the first reachable importer in a test process).
|
||||
import application.api.user.tools.mcp # noqa: F401
|
||||
|
||||
from application.agents.tools.scheduler import SchedulerTool # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_sessions(pg_conn):
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def _ctx():
|
||||
yield pg_conn
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.scheduler.db_session", _ctx,
|
||||
), patch(
|
||||
"application.agents.tools.scheduler.db_readonly", _ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _make_agent(conn, user_id: str = "u1") -> str:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()
|
||||
return str(row[0])
|
||||
|
||||
|
||||
def _make_tool(name="scheduler", *, user_id="u1", agent_id=None, conversation_id=None):
|
||||
return SchedulerTool(
|
||||
tool_config={
|
||||
"agent_id": agent_id,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TestGuards:
|
||||
def test_requires_user_id(self):
|
||||
tool = SchedulerTool(tool_config={"agent_id": str(uuid.uuid4())})
|
||||
assert "user_id" in tool.execute_action("schedule_task", instruction="x")
|
||||
|
||||
def test_rejects_invalid_agent_id(self):
|
||||
tool = _make_tool(user_id="u1", agent_id="not-a-uuid")
|
||||
assert "invalid agent_id" in tool.execute_action(
|
||||
"schedule_task", instruction="x"
|
||||
)
|
||||
|
||||
def test_requires_agent_or_conversation(self):
|
||||
# Neither agent_id nor conversation_id → hard error (webhook caller
|
||||
# outside any chat); scheduler can't operate without a conversation home.
|
||||
tool = _make_tool(user_id="u1", agent_id=None, conversation_id=None)
|
||||
out = tool.execute_action("schedule_task", instruction="x")
|
||||
assert "conversation_id" in out or "conversation home" in out
|
||||
|
||||
|
||||
class TestScheduleTask:
|
||||
def test_creates_with_delay(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id, conversation_id=None)
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="say hi", delay="2h",
|
||||
)
|
||||
parsed = json.loads(out)
|
||||
assert "task_id" in parsed
|
||||
assert "resolved_run_at" in parsed
|
||||
row = SchedulesRepository(pg_conn).get(parsed["task_id"], "u1")
|
||||
assert row is not None
|
||||
assert row["trigger_type"] == "once"
|
||||
assert row["created_via"] == "chat"
|
||||
fire = datetime.fromisoformat(parsed["resolved_run_at"].replace("Z", "+00:00"))
|
||||
delta = fire - datetime.now(timezone.utc)
|
||||
assert timedelta(minutes=119) <= delta <= timedelta(minutes=121)
|
||||
|
||||
def test_creates_with_run_at(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
fire = (datetime.now(timezone.utc) + timedelta(hours=3)).isoformat()
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="x", run_at=fire,
|
||||
)
|
||||
parsed = json.loads(out)
|
||||
assert "task_id" in parsed
|
||||
|
||||
def test_rejects_both_delay_and_run_at(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="x", delay="30m",
|
||||
run_at="2030-01-01T00:00:00Z",
|
||||
)
|
||||
assert "only one" in out
|
||||
|
||||
def test_rejects_past_run_at(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()
|
||||
out = tool.execute_action("schedule_task", instruction="x", run_at=past)
|
||||
assert "past" in out
|
||||
|
||||
def test_rejects_beyond_horizon(
|
||||
self, pg_conn, patch_sessions, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(settings, "SCHEDULE_ONCE_MAX_HORIZON", 3600)
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
far = (datetime.now(timezone.utc) + timedelta(hours=10)).isoformat()
|
||||
out = tool.execute_action("schedule_task", instruction="x", run_at=far)
|
||||
assert "horizon" in out
|
||||
|
||||
|
||||
class TestQuota:
|
||||
def test_quota_enforced(self, pg_conn, patch_sessions, monkeypatch):
|
||||
monkeypatch.setattr(settings, "SCHEDULE_MAX_PER_USER", 2)
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
for _ in range(2):
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="x", delay="1h",
|
||||
)
|
||||
assert "task_id" in out
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="x", delay="1h",
|
||||
)
|
||||
assert "maximum" in out
|
||||
|
||||
|
||||
class TestListAndCancel:
|
||||
def test_list_returns_pending(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
for _ in range(3):
|
||||
tool.execute_action(
|
||||
"schedule_task", instruction="x", delay="1h",
|
||||
)
|
||||
listed = json.loads(tool.execute_action("list_scheduled_tasks"))
|
||||
assert len(listed["tasks"]) == 3
|
||||
assert all(t["status"] == "active" for t in listed["tasks"])
|
||||
|
||||
def test_cancel_flips_status(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
created = json.loads(
|
||||
tool.execute_action("schedule_task", instruction="x", delay="1h")
|
||||
)
|
||||
out = tool.execute_action(
|
||||
"cancel_scheduled_task", task_id=created["task_id"]
|
||||
)
|
||||
assert "cancelled" in out
|
||||
row = SchedulesRepository(pg_conn).get(created["task_id"], "u1")
|
||||
assert row["status"] == "cancelled"
|
||||
|
||||
def test_cancel_unknown_id_rejected(self, pg_conn, patch_sessions):
|
||||
agent_id = _make_agent(pg_conn)
|
||||
tool = _make_tool(user_id="u1", agent_id=agent_id)
|
||||
out = tool.execute_action(
|
||||
"cancel_scheduled_task", task_id="not-a-uuid",
|
||||
)
|
||||
assert "valid id" in out
|
||||
|
||||
|
||||
class TestActionsMetadata:
|
||||
def test_actions_listed(self):
|
||||
tool = SchedulerTool()
|
||||
names = {a["name"] for a in tool.get_actions_metadata()}
|
||||
assert names == {
|
||||
"schedule_task", "list_scheduled_tasks", "cancel_scheduled_task",
|
||||
}
|
||||
|
||||
|
||||
class TestAgentlessInvocation:
|
||||
def test_agentless_creates_schedule_with_null_agent_id(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""Agentless chat → scheduler.schedule_task → row with NULL agent_id."""
|
||||
conv_id = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'origin') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
tool = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv_id),
|
||||
)
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="ping me later", delay="1h",
|
||||
)
|
||||
parsed = json.loads(out)
|
||||
assert "task_id" in parsed
|
||||
row = SchedulesRepository(pg_conn).get(parsed["task_id"], "u1")
|
||||
assert row is not None
|
||||
assert row["agent_id"] is None
|
||||
assert row["trigger_type"] == "once"
|
||||
assert row["created_via"] == "chat"
|
||||
assert str(row["origin_conversation_id"]) == str(conv_id)
|
||||
|
||||
def test_agentless_list_scoped_to_conversation(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""Agentless list_scheduled_tasks scopes to user + origin conversation."""
|
||||
conv_a = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'a') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
conv_b = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'b') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
tool_a = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv_a),
|
||||
)
|
||||
tool_b = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv_b),
|
||||
)
|
||||
tool_a.execute_action(
|
||||
"schedule_task", instruction="in-a", delay="1h",
|
||||
)
|
||||
tool_a.execute_action(
|
||||
"schedule_task", instruction="in-a-2", delay="2h",
|
||||
)
|
||||
tool_b.execute_action(
|
||||
"schedule_task", instruction="in-b", delay="3h",
|
||||
)
|
||||
listed_a = json.loads(tool_a.execute_action("list_scheduled_tasks"))
|
||||
listed_b = json.loads(tool_b.execute_action("list_scheduled_tasks"))
|
||||
assert len(listed_a["tasks"]) == 2
|
||||
assert len(listed_b["tasks"]) == 1
|
||||
assert all(t["status"] == "active" for t in listed_a["tasks"])
|
||||
|
||||
def test_agentless_cancel_blocked_for_other_conversation(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""A user can't cancel tasks created in another agentless chat."""
|
||||
conv_a = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'a') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
conv_b = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'b') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
tool_a = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv_a),
|
||||
)
|
||||
tool_b = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv_b),
|
||||
)
|
||||
created = json.loads(
|
||||
tool_a.execute_action(
|
||||
"schedule_task", instruction="x", delay="1h",
|
||||
)
|
||||
)
|
||||
out = tool_b.execute_action(
|
||||
"cancel_scheduled_task", task_id=created["task_id"],
|
||||
)
|
||||
assert "not found" in out
|
||||
|
||||
def test_agentless_cancel_succeeds_in_own_conversation(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
conv = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'a') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
tool = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv),
|
||||
)
|
||||
created = json.loads(
|
||||
tool.execute_action("schedule_task", instruction="x", delay="1h")
|
||||
)
|
||||
out = tool.execute_action(
|
||||
"cancel_scheduled_task", task_id=created["task_id"],
|
||||
)
|
||||
assert "cancelled" in out
|
||||
|
||||
def test_agentless_snapshot_allowlist_lists_user_tools(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""Agentless schedule captures the user's non-approval tools at fire-time."""
|
||||
from application.agents.tools.scheduler import _safe_default_allowlist
|
||||
from application.storage.db.repositories.user_tools import (
|
||||
UserToolsRepository,
|
||||
)
|
||||
|
||||
# Seed an explicit non-approval user tool.
|
||||
user_tool = UserToolsRepository(pg_conn).create(
|
||||
"u1", "read_webpage", config={}, actions=[
|
||||
{"name": "fetch", "active": True, "require_approval": False},
|
||||
], status=True,
|
||||
)
|
||||
conv = pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'a') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
tool = _make_tool(
|
||||
user_id="u1", agent_id=None, conversation_id=str(conv),
|
||||
)
|
||||
out = tool.execute_action(
|
||||
"schedule_task", instruction="x", delay="1h",
|
||||
)
|
||||
parsed = json.loads(out)
|
||||
row = SchedulesRepository(pg_conn).get(parsed["task_id"], "u1")
|
||||
# The explicit user_tools row is in the snapshot (approval=False).
|
||||
assert str(user_tool["id"]) in (row["tool_allowlist"] or [])
|
||||
# Direct allowlist call returns the same set.
|
||||
ids = _safe_default_allowlist(None, "u1")
|
||||
assert str(user_tool["id"]) in ids
|
||||
|
||||
|
||||
class TestAllowlistSnapshotSemantics:
|
||||
"""The schedule's ``tool_allowlist`` is a **pre-auth snapshot**, not a
|
||||
visibility cap. The LLM sees the user's *current* tools at fire time
|
||||
(via ``ToolExecutor._get_user_tools``); the snapshot only governs
|
||||
whether an approval-gated tool can run unattended."""
|
||||
|
||||
def test_tool_added_after_creation_is_visible_at_fire_time(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""Schedule captures the allowlist at creation; a tool added later is
|
||||
visible at fire time (resolver re-queries) but isn't in the snapshot."""
|
||||
from application.agents.tools.scheduler import _safe_default_allowlist
|
||||
from application.storage.db.repositories.user_tools import (
|
||||
UserToolsRepository,
|
||||
)
|
||||
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'snap-add') RETURNING id"
|
||||
)
|
||||
).fetchone()
|
||||
# Snapshot the allowlist BEFORE adding the new tool.
|
||||
snapshot_before = _safe_default_allowlist(None, "u1")
|
||||
|
||||
# User adds an approval-gated tool AFTER schedule creation.
|
||||
added = UserToolsRepository(pg_conn).create(
|
||||
"u1", "telegram",
|
||||
config={}, actions=[
|
||||
{"name": "send", "active": True, "require_approval": True},
|
||||
], status=True,
|
||||
)
|
||||
|
||||
# The snapshot does NOT include the post-creation tool.
|
||||
assert str(added["id"]) not in snapshot_before
|
||||
# …but the LLM sees it at fire time (current resolver state).
|
||||
snapshot_after = _safe_default_allowlist(None, "u1")
|
||||
# An approval-gated tool is excluded from the snapshot regardless,
|
||||
# but it IS in ``list_active_for_user`` (what the LLM's tool_executor
|
||||
# uses) — make that explicit:
|
||||
ids_now = {
|
||||
str(r["id"]) for r in
|
||||
UserToolsRepository(pg_conn).list_active_for_user("u1")
|
||||
}
|
||||
assert str(added["id"]) in ids_now
|
||||
# And approval-gated still skipped from the safe allowlist.
|
||||
assert str(added["id"]) not in snapshot_after
|
||||
|
||||
def test_tool_deleted_between_creation_and_fire_is_invisible(
|
||||
self, pg_conn, patch_sessions,
|
||||
):
|
||||
"""A tool deleted between schedule creation and fire is gone for the
|
||||
LLM at fire time (the resolver lists the current state)."""
|
||||
from application.agents.tools.scheduler import _safe_default_allowlist
|
||||
from application.storage.db.repositories.user_tools import (
|
||||
UserToolsRepository,
|
||||
)
|
||||
|
||||
repo = UserToolsRepository(pg_conn)
|
||||
existing = repo.create(
|
||||
"u1", "read_webpage",
|
||||
config={}, actions=[
|
||||
{"name": "fetch", "active": True, "require_approval": False},
|
||||
], status=True,
|
||||
)
|
||||
# Snapshot at creation includes it (non-approval).
|
||||
snapshot = _safe_default_allowlist(None, "u1")
|
||||
assert str(existing["id"]) in snapshot
|
||||
|
||||
# User deletes it; fire-time resolver no longer surfaces it.
|
||||
repo.delete(str(existing["id"]), "u1")
|
||||
ids_now = {r["id"] for r in repo.list_active_for_user("u1")}
|
||||
assert str(existing["id"]) not in ids_now
|
||||
# And the freshly-recomputed allowlist drops it too.
|
||||
snapshot_after = _safe_default_allowlist(None, "u1")
|
||||
assert str(existing["id"]) not in snapshot_after
|
||||
|
||||
|
||||
class TestInternalFlag:
|
||||
def test_internal_true(self):
|
||||
assert SchedulerTool.internal is True
|
||||
|
||||
def test_not_in_tool_manager_auto_load(self):
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
tm = ToolManager(config={})
|
||||
assert "scheduler" not in tm.tools
|
||||
|
||||
def test_load_tool_special_case_still_works(self):
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
"scheduler",
|
||||
tool_config={"agent_id": str(uuid.uuid4())},
|
||||
user_id="u1",
|
||||
)
|
||||
assert isinstance(tool, SchedulerTool)
|
||||
assert tool.user_id == "u1"
|
||||
@@ -246,6 +246,68 @@ class TestCompleteStreamMethod:
|
||||
mock_reserve.assert_called_once()
|
||||
mock_finalize.assert_called_once()
|
||||
|
||||
def test_tool_executor_conversation_id_set_after_reserve(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""Regression: ``save_user_question`` may mint a fresh
|
||||
``conversation_id`` (first turn). The propagation MUST land on
|
||||
``agent.tool_executor.conversation_id`` BEFORE ``agent.gen`` runs,
|
||||
so tools needing a conversation home (``scheduler`` in an agentless
|
||||
chat) see it on the very first call.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
fresh_conv_id = str(uuid.uuid4())
|
||||
seen_conv_id_on_gen: dict = {}
|
||||
|
||||
mock_agent = MagicMock()
|
||||
tool_executor = MagicMock()
|
||||
# Start with no conversation_id — the propagation must set it.
|
||||
tool_executor.conversation_id = None
|
||||
mock_agent.tool_executor = tool_executor
|
||||
|
||||
def _gen(**_kwargs):
|
||||
# Capture the executor's id at the exact moment gen runs;
|
||||
# this is what tools see when called from the agent loop.
|
||||
seen_conv_id_on_gen["value"] = (
|
||||
mock_agent.tool_executor.conversation_id
|
||||
)
|
||||
yield {"answer": "ok"}
|
||||
|
||||
mock_agent.gen.side_effect = _gen
|
||||
mock_agent.gen.return_value = None # use side_effect instead
|
||||
|
||||
with patch.object(
|
||||
resource.conversation_service, "save_user_question"
|
||||
) as mock_reserve, patch.object(
|
||||
resource.conversation_service, "finalize_message",
|
||||
return_value=True,
|
||||
):
|
||||
mock_reserve.return_value = {
|
||||
"conversation_id": fresh_conv_id,
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"request_id": "req-prop",
|
||||
}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="schedule something",
|
||||
agent=mock_agent,
|
||||
conversation_id=None, # caller had no conv yet
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "user-prop"},
|
||||
should_save_conversation=True,
|
||||
)
|
||||
)
|
||||
|
||||
# The fresh id reserved by save_user_question must reach the
|
||||
# tool_executor before agent.gen consumes it.
|
||||
assert seen_conv_id_on_gen["value"] == fresh_conv_id
|
||||
assert tool_executor.conversation_id == fresh_conv_id
|
||||
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
397
tests/api/answer/services/test_continuation_service.py
Normal file
397
tests/api/answer/services/test_continuation_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Unit tests for application/api/answer/services/continuation_service.py.
|
||||
|
||||
Covers:
|
||||
- _make_serializable: ObjectId, dict, list, bytes conversions
|
||||
- ContinuationService.__init__: index creation
|
||||
- save_state: upserts document with correct shape
|
||||
- load_state: returns doc or None
|
||||
- delete_state: removes doc and returns bool
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeSerializable:
|
||||
|
||||
def test_converts_objectid_to_string(self):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
_make_serializable,
|
||||
)
|
||||
|
||||
oid = ObjectId()
|
||||
result = _make_serializable(oid)
|
||||
assert result == str(oid)
|
||||
|
||||
def test_converts_nested_dict(self):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
_make_serializable,
|
||||
)
|
||||
|
||||
oid = ObjectId()
|
||||
data = {"key": oid, "nested": {"inner": oid}}
|
||||
result = _make_serializable(data)
|
||||
assert result == {"key": str(oid), "nested": {"inner": str(oid)}}
|
||||
|
||||
def test_converts_list_with_objectids(self):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
_make_serializable,
|
||||
)
|
||||
|
||||
oid = ObjectId()
|
||||
data = [oid, "plain", 42]
|
||||
result = _make_serializable(data)
|
||||
assert result == [str(oid), "plain", 42]
|
||||
|
||||
def test_converts_bytes_to_string(self):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
_make_serializable,
|
||||
)
|
||||
|
||||
result = _make_serializable(b"hello world")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_passes_through_primitives(self):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
_make_serializable,
|
||||
)
|
||||
|
||||
assert _make_serializable("hello") == "hello"
|
||||
assert _make_serializable(42) == 42
|
||||
assert _make_serializable(3.14) == 3.14
|
||||
assert _make_serializable(None) is None
|
||||
assert _make_serializable(True) is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContinuationServiceInit:
|
||||
|
||||
def test_initializes_collection(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
assert service.collection is not None
|
||||
|
||||
def test_ensure_indexes_tolerates_existing(self, mock_mongo_db):
|
||||
"""Second init should not raise even if indexes already exist."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
ContinuationService()
|
||||
ContinuationService() # Should not raise
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContinuationServiceSaveState:
|
||||
|
||||
def test_save_state_creates_document(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
state_id = service.save_state(
|
||||
conversation_id="conv_abc",
|
||||
user="user_123",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
pending_tool_calls=[{"id": "call_1", "function": {"name": "search"}}],
|
||||
tools_dict={"search": {"type": "function"}},
|
||||
tool_schemas=[{"name": "search", "description": "search tool"}],
|
||||
agent_config={"model_id": "gpt-4", "llm_name": "openai"},
|
||||
)
|
||||
|
||||
assert state_id is not None
|
||||
doc = collection.find_one({"conversation_id": "conv_abc", "user": "user_123"})
|
||||
assert doc is not None
|
||||
assert doc["messages"] == [{"role": "user", "content": "hello"}]
|
||||
assert len(doc["pending_tool_calls"]) == 1
|
||||
assert doc["tools_dict"] == {"search": {"type": "function"}}
|
||||
|
||||
def test_save_state_upserts_existing(self, mock_mongo_db):
|
||||
"""Second save for same conversation replaces first."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_abc",
|
||||
user="user_123",
|
||||
messages=[{"role": "user", "content": "first"}],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
service.save_state(
|
||||
conversation_id="conv_abc",
|
||||
user="user_123",
|
||||
messages=[{"role": "user", "content": "second"}],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
|
||||
docs = list(collection.find({"conversation_id": "conv_abc"}))
|
||||
assert len(docs) == 1
|
||||
assert docs[0]["messages"][0]["content"] == "second"
|
||||
|
||||
def test_save_state_with_client_tools(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_tools",
|
||||
user="user_123",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
client_tools=[{"name": "my_tool", "description": "A client tool"}],
|
||||
)
|
||||
|
||||
doc = collection.find_one({"conversation_id": "conv_tools"})
|
||||
assert doc["client_tools"] == [{"name": "my_tool", "description": "A client tool"}]
|
||||
|
||||
def test_save_state_no_client_tools_stores_none(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_notool",
|
||||
user="user_123",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
client_tools=None,
|
||||
)
|
||||
|
||||
doc = collection.find_one({"conversation_id": "conv_notool"})
|
||||
assert doc["client_tools"] is None
|
||||
|
||||
def test_save_state_sets_expires_at(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
service.save_state(
|
||||
conversation_id="conv_ttl",
|
||||
user="user_123",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
doc = collection.find_one({"conversation_id": "conv_ttl"})
|
||||
# expires_at should be roughly TTL seconds after save
|
||||
assert doc["expires_at"] is not None
|
||||
|
||||
def test_save_state_serializes_objectids(self, mock_mongo_db):
|
||||
"""ObjectIds in messages are converted to strings."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
oid = ObjectId()
|
||||
service.save_state(
|
||||
conversation_id="conv_oid",
|
||||
user="user_123",
|
||||
messages=[{"role": "user", "content": str(oid), "_id": oid}],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={"oid_key": oid},
|
||||
)
|
||||
|
||||
doc = collection.find_one({"conversation_id": "conv_oid"})
|
||||
# The oid in agent_config should be serialized
|
||||
assert doc["agent_config"]["oid_key"] == str(oid)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContinuationServiceLoadState:
|
||||
|
||||
def test_load_state_returns_document(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
service.save_state(
|
||||
conversation_id="conv_load",
|
||||
user="user_123",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
pending_tool_calls=[{"id": "call_1"}],
|
||||
tools_dict={"t": "v"},
|
||||
tool_schemas=[],
|
||||
agent_config={"model_id": "gpt-4"},
|
||||
)
|
||||
|
||||
result = service.load_state("conv_load", "user_123")
|
||||
|
||||
assert result is not None
|
||||
assert result["conversation_id"] == "conv_load"
|
||||
assert result["user"] == "user_123"
|
||||
assert result["messages"] == [{"role": "user", "content": "test"}]
|
||||
assert isinstance(result["_id"], str) # ObjectId converted to string
|
||||
|
||||
def test_load_state_returns_none_when_not_found(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
result = service.load_state("nonexistent_conv", "user_123")
|
||||
assert result is None
|
||||
|
||||
def test_load_state_is_user_scoped(self, mock_mongo_db):
|
||||
"""State for user A should not be accessible by user B."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
service.save_state(
|
||||
conversation_id="conv_scoped",
|
||||
user="user_A",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
|
||||
result = service.load_state("conv_scoped", "user_B")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContinuationServiceDeleteState:
|
||||
|
||||
def test_delete_state_removes_document(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_del",
|
||||
user="user_123",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
|
||||
result = service.delete_state("conv_del", "user_123")
|
||||
|
||||
assert result is True
|
||||
doc = collection.find_one({"conversation_id": "conv_del"})
|
||||
assert doc is None
|
||||
|
||||
def test_delete_state_returns_false_when_not_found(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
result = service.delete_state("nonexistent_conv", "user_123")
|
||||
assert result is False
|
||||
|
||||
def test_delete_state_is_user_scoped(self, mock_mongo_db):
|
||||
"""Deleting with wrong user should not remove the state."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ContinuationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["pending_tool_state"]
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_scoped_del",
|
||||
user="user_A",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
|
||||
# Wrong user — should not delete
|
||||
result = service.delete_state("conv_scoped_del", "user_B")
|
||||
assert result is False
|
||||
|
||||
# State should still exist
|
||||
doc = collection.find_one({"conversation_id": "conv_scoped_del"})
|
||||
assert doc is not None
|
||||
|
||||
def test_save_load_delete_round_trip(self, mock_mongo_db):
|
||||
"""Full round trip: save → load → delete → load returns None."""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
service = ContinuationService()
|
||||
|
||||
service.save_state(
|
||||
conversation_id="conv_rt",
|
||||
user="user_rt",
|
||||
messages=[{"role": "assistant", "content": "thinking..."}],
|
||||
pending_tool_calls=[{"id": "call_rt", "function": {"name": "search", "arguments": "{}"}}],
|
||||
tools_dict={"search": {"type": "function", "name": "search"}},
|
||||
tool_schemas=[{"name": "search"}],
|
||||
agent_config={"model_id": "gpt-4", "llm_name": "openai"},
|
||||
)
|
||||
|
||||
loaded = service.load_state("conv_rt", "user_rt")
|
||||
assert loaded is not None
|
||||
assert loaded["pending_tool_calls"][0]["id"] == "call_rt"
|
||||
|
||||
deleted = service.delete_state("conv_rt", "user_rt")
|
||||
assert deleted is True
|
||||
|
||||
none_result = service.load_state("conv_rt", "user_rt")
|
||||
assert none_result is None
|
||||
@@ -429,7 +429,7 @@ class TestConfigureRetriever:
|
||||
assert sp.retriever_config["retriever_name"] == "hybrid_search"
|
||||
assert sp.retriever_config["chunks"] == 5
|
||||
|
||||
def test_request_overrides_agent(self):
|
||||
def test_agent_wins_over_request_on_agent_bound(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
@@ -438,9 +438,33 @@ class TestConfigureRetriever:
|
||||
)
|
||||
sp._agent_data = {"retriever": "hybrid_search", "chunks": 5}
|
||||
sp._configure_retriever()
|
||||
assert sp.retriever_config["retriever_name"] == "hybrid_search"
|
||||
assert sp.retriever_config["chunks"] == 5
|
||||
|
||||
def test_body_wins_on_agentless(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
sp = StreamProcessor(
|
||||
{"retriever": "duckdb", "chunks": 7}, {"sub": "u"},
|
||||
)
|
||||
sp._configure_retriever()
|
||||
assert sp.retriever_config["retriever_name"] == "duckdb"
|
||||
assert sp.retriever_config["chunks"] == 7
|
||||
|
||||
def test_agent_bound_drops_body_chunks_and_retriever(self):
|
||||
# Missing agent values fall back to system defaults, not body's.
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
sp = StreamProcessor(
|
||||
{"retriever": "duckdb", "chunks": 7}, {"sub": "u"},
|
||||
)
|
||||
sp._agent_data = {}
|
||||
sp._configure_retriever()
|
||||
assert sp.retriever_config["retriever_name"] == "classic"
|
||||
assert sp.retriever_config["chunks"] == 2
|
||||
|
||||
def test_invalid_agent_chunks_falls_back(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
@@ -569,10 +593,11 @@ class TestPreFetchTools:
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is None
|
||||
|
||||
def test_no_user_tools_returns_none(self, pg_conn):
|
||||
def test_no_template_skips_default_tool_prefetch(self, pg_conn):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
|
||||
sp = StreamProcessor({}, {"sub": "no-tools-user"})
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.settings.ENABLE_TOOL_PREFETCH",
|
||||
@@ -580,3 +605,347 @@ class TestPreFetchTools:
|
||||
):
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is None
|
||||
|
||||
def test_no_template_skips_only_default_rows_not_explicit(self, pg_conn):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import (
|
||||
UserToolsRepository,
|
||||
)
|
||||
|
||||
UserToolsRepository(pg_conn).create(
|
||||
user_id="u-explicit-prefetch", name="read_webpage", status=True
|
||||
)
|
||||
sp = StreamProcessor({}, {"sub": "u-explicit-prefetch"})
|
||||
fetched = []
|
||||
|
||||
def _fake_fetch(tool_doc, required_actions):
|
||||
fetched.append(tool_doc)
|
||||
return {"ok": True}
|
||||
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.settings.ENABLE_TOOL_PREFETCH",
|
||||
True,
|
||||
), patch.object(sp, "_fetch_tool_data", _fake_fetch):
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is not None
|
||||
assert "read_webpage" in got
|
||||
assert all(not d.get("default") for d in fetched)
|
||||
assert any(d.get("name") == "read_webpage" for d in fetched)
|
||||
|
||||
def test_default_tool_prefetched_when_template_references_it(
|
||||
self, pg_conn
|
||||
):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
|
||||
sp = StreamProcessor({}, {"sub": "u-tpl-default"})
|
||||
sp._required_tool_actions = {"read_webpage": {None}}
|
||||
fetched = []
|
||||
|
||||
def _fake_fetch(tool_doc, required_actions):
|
||||
fetched.append(tool_doc)
|
||||
return {"ok": True}
|
||||
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.settings.ENABLE_TOOL_PREFETCH",
|
||||
True,
|
||||
), patch.object(sp, "_fetch_tool_data", _fake_fetch):
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is not None
|
||||
assert any(
|
||||
d.get("name") == "read_webpage" and d.get("default")
|
||||
for d in fetched
|
||||
)
|
||||
# Defaults are reachable by synthetic id only — not by name.
|
||||
assert default_tool_id("read_webpage") in got
|
||||
|
||||
def test_agent_bound_invocation_omits_default_tool_prefetch(self, pg_conn):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
|
||||
sp = StreamProcessor({"agent_id": "agent-xyz"}, {"sub": "u-ag"})
|
||||
sp._required_tool_actions = {"read_webpage": {None}}
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.settings.ENABLE_TOOL_PREFETCH",
|
||||
True,
|
||||
):
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is None
|
||||
|
||||
def test_template_name_key_favors_explicit_over_default(self, pg_conn):
|
||||
"""An explicit row and the synthesized default of the same name
|
||||
coexist: name key stays on the explicit, default reachable by
|
||||
synthetic id only."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import (
|
||||
UserToolsRepository,
|
||||
)
|
||||
|
||||
user = "u-collision"
|
||||
explicit = UserToolsRepository(pg_conn).create(
|
||||
user_id=user, name="read_webpage", status=True,
|
||||
)
|
||||
explicit_id = str(explicit["id"])
|
||||
default_id = default_tool_id("read_webpage")
|
||||
|
||||
sp = StreamProcessor({}, {"sub": user})
|
||||
sp._required_tool_actions = {"read_webpage": {None}}
|
||||
|
||||
def _fake_fetch(tool_doc, required_actions):
|
||||
return {
|
||||
"is_default": bool(tool_doc.get("default")),
|
||||
"id": str(tool_doc.get("_id") or tool_doc.get("id")),
|
||||
}
|
||||
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.settings.ENABLE_TOOL_PREFETCH",
|
||||
True,
|
||||
), patch.object(sp, "_fetch_tool_data", _fake_fetch):
|
||||
got = sp.pre_fetch_tools()
|
||||
assert got is not None
|
||||
assert got["read_webpage"]["is_default"] is False
|
||||
assert got["read_webpage"]["id"] == explicit_id
|
||||
assert got[explicit_id]["is_default"] is False
|
||||
assert got[default_id]["is_default"] is True
|
||||
|
||||
|
||||
class TestValidateAndSetModelAgentAuthority:
|
||||
"""Agent-bound chats: agent's ``default_model_id`` is authoritative."""
|
||||
|
||||
def test_agent_bound_ignores_body_model_id(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
sp = StreamProcessor({"model_id": "body-model"}, {"sub": "caller"})
|
||||
sp._agent_data = {"user": "owner"}
|
||||
sp.agent_config = {
|
||||
"default_model_id": "agent-model",
|
||||
"user_id": "owner",
|
||||
}
|
||||
captured: list = []
|
||||
|
||||
def _fake_validate(model_id, user_id=None):
|
||||
captured.append((model_id, user_id))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.stream_processor.validate_model_id",
|
||||
side_effect=_fake_validate,
|
||||
), patch(
|
||||
"application.api.answer.services.stream_processor.get_default_model_id",
|
||||
return_value="global-default",
|
||||
):
|
||||
sp._validate_and_set_model()
|
||||
assert sp.model_id == "agent-model"
|
||||
# Resolved under the agent owner, not the caller.
|
||||
assert sp.model_user_id == "owner"
|
||||
assert ("agent-model", "owner") in captured
|
||||
|
||||
def test_agent_bound_no_default_falls_back_to_system(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
sp = StreamProcessor({"model_id": "body-model"}, {"sub": "u"})
|
||||
sp._agent_data = {"user": "u"}
|
||||
sp.agent_config = {"default_model_id": "", "user_id": "u"}
|
||||
with patch(
|
||||
"application.api.answer.services.stream_processor.validate_model_id",
|
||||
return_value=False,
|
||||
), patch(
|
||||
"application.api.answer.services.stream_processor.get_default_model_id",
|
||||
return_value="global-default",
|
||||
):
|
||||
sp._validate_and_set_model()
|
||||
assert sp.model_id == "global-default"
|
||||
assert sp.model_user_id is None
|
||||
|
||||
def test_agentless_body_model_still_wins(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
sp = StreamProcessor({"model_id": "body-model"}, {"sub": "u"})
|
||||
sp._agent_data = None
|
||||
with patch(
|
||||
"application.api.answer.services.stream_processor.validate_model_id",
|
||||
return_value=True,
|
||||
):
|
||||
sp._validate_and_set_model()
|
||||
assert sp.model_id == "body-model"
|
||||
assert sp.model_user_id == "u"
|
||||
|
||||
|
||||
class TestGetDataFromApiKeySourceUnion:
|
||||
"""`_get_data_from_api_key`: primary ∪ extras, deduplicated, primary first."""
|
||||
|
||||
def _make_sp(self):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
return StreamProcessor({}, {"sub": "u"})
|
||||
|
||||
def test_union_primary_and_extras(self, pg_conn):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
owner = "u-merge-both"
|
||||
sources_repo = SourcesRepository(pg_conn)
|
||||
primary = sources_repo.create(name="primary", user_id=owner)
|
||||
extra1 = sources_repo.create(name="extra1", user_id=owner)
|
||||
extra2 = sources_repo.create(name="extra2", user_id=owner)
|
||||
|
||||
agent = AgentsRepository(pg_conn).create(
|
||||
owner, "agent-merge", "published",
|
||||
key="merge-key",
|
||||
source_id=str(primary["id"]),
|
||||
extra_source_ids=[str(extra1["id"]), str(extra2["id"])],
|
||||
retriever="hybrid",
|
||||
chunks=5,
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
sp = self._make_sp()
|
||||
with _patch_db(pg_conn):
|
||||
data = sp._get_data_from_api_key("merge-key")
|
||||
ids = [s["id"] for s in data["sources"]]
|
||||
assert ids == [
|
||||
str(primary["id"]),
|
||||
str(extra1["id"]),
|
||||
str(extra2["id"]),
|
||||
]
|
||||
assert data["source"] == str(primary["id"])
|
||||
|
||||
def test_only_primary(self, pg_conn):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
owner = "u-merge-primary-only"
|
||||
primary = SourcesRepository(pg_conn).create(
|
||||
name="primary", user_id=owner,
|
||||
)
|
||||
|
||||
AgentsRepository(pg_conn).create(
|
||||
owner, "primary-only", "published",
|
||||
key="primary-only-key",
|
||||
source_id=str(primary["id"]),
|
||||
extra_source_ids=[],
|
||||
)
|
||||
|
||||
sp = self._make_sp()
|
||||
with _patch_db(pg_conn):
|
||||
data = sp._get_data_from_api_key("primary-only-key")
|
||||
assert [s["id"] for s in data["sources"]] == [str(primary["id"])]
|
||||
assert data["source"] == str(primary["id"])
|
||||
|
||||
def test_only_extras(self, pg_conn):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
owner = "u-merge-extras-only"
|
||||
e1 = SourcesRepository(pg_conn).create(name="e1", user_id=owner)
|
||||
e2 = SourcesRepository(pg_conn).create(name="e2", user_id=owner)
|
||||
|
||||
AgentsRepository(pg_conn).create(
|
||||
owner, "extras-only", "published",
|
||||
key="extras-only-key",
|
||||
extra_source_ids=[str(e1["id"]), str(e2["id"])],
|
||||
)
|
||||
|
||||
sp = self._make_sp()
|
||||
with _patch_db(pg_conn):
|
||||
data = sp._get_data_from_api_key("extras-only-key")
|
||||
assert [s["id"] for s in data["sources"]] == [
|
||||
str(e1["id"]), str(e2["id"]),
|
||||
]
|
||||
assert data["source"] is None
|
||||
|
||||
def test_dedupe_primary_repeated_in_extras(self, pg_conn):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
owner = "u-merge-dedupe"
|
||||
primary = SourcesRepository(pg_conn).create(
|
||||
name="dup-primary", user_id=owner,
|
||||
)
|
||||
extra = SourcesRepository(pg_conn).create(
|
||||
name="dup-extra", user_id=owner,
|
||||
)
|
||||
|
||||
AgentsRepository(pg_conn).create(
|
||||
owner, "dedupe", "published",
|
||||
key="dedupe-key",
|
||||
source_id=str(primary["id"]),
|
||||
extra_source_ids=[str(primary["id"]), str(extra["id"])],
|
||||
)
|
||||
|
||||
sp = self._make_sp()
|
||||
with _patch_db(pg_conn):
|
||||
data = sp._get_data_from_api_key("dedupe-key")
|
||||
ids = [s["id"] for s in data["sources"]]
|
||||
assert ids == [str(primary["id"]), str(extra["id"])]
|
||||
|
||||
|
||||
class TestAgentBoundFieldsAuthoritative:
|
||||
"""End-to-end regression: agent's source/model/chunks/retriever win."""
|
||||
|
||||
def test_agent_values_win_over_body(self, pg_conn):
|
||||
from application.api.answer.services.stream_processor import (
|
||||
StreamProcessor,
|
||||
)
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
owner = "u-regr-agent-authority"
|
||||
primary = SourcesRepository(pg_conn).create(
|
||||
name="primary", user_id=owner,
|
||||
)
|
||||
extra = SourcesRepository(pg_conn).create(
|
||||
name="extra", user_id=owner,
|
||||
)
|
||||
AgentsRepository(pg_conn).create(
|
||||
owner, "authoritative", "published",
|
||||
key="auth-key",
|
||||
source_id=str(primary["id"]),
|
||||
extra_source_ids=[str(extra["id"])],
|
||||
default_model_id="model-A",
|
||||
retriever="hybrid",
|
||||
chunks=5,
|
||||
)
|
||||
|
||||
# Body sends different values for every field; all must be ignored.
|
||||
body = {
|
||||
"api_key": "auth-key",
|
||||
"model_id": "body-model-Z",
|
||||
"retriever": "duckdb",
|
||||
"chunks": 99,
|
||||
"active_docs": "body-source-id",
|
||||
}
|
||||
sp = StreamProcessor(body, {"sub": owner})
|
||||
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.answer.services.stream_processor.validate_model_id",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"application.api.answer.services.stream_processor.get_default_model_id",
|
||||
return_value="system-default",
|
||||
):
|
||||
sp._configure_agent()
|
||||
sp._validate_and_set_model()
|
||||
sp._configure_source()
|
||||
sp._configure_retriever()
|
||||
|
||||
assert sp.model_id == "model-A"
|
||||
assert sp.model_user_id == owner
|
||||
assert sp.agent_config["default_model_id"] == "model-A"
|
||||
assert sp.retriever_config["chunks"] == 5
|
||||
assert sp.retriever_config["retriever_name"] == "hybrid"
|
||||
assert sp.source == {
|
||||
"active_docs": [str(primary["id"]), str(extra["id"])],
|
||||
}
|
||||
|
||||
63
tests/api/user/test_base_extended.py
Normal file
63
tests/api/user/test_base_extended.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Additional tests for application/api/user/base.py to cover remaining branches.
|
||||
|
||||
Target missing lines:
|
||||
- 131: when "pinned" is missing from existing agent_preferences
|
||||
- 157-158: invalid ObjectId in resolve_tool_details (continue branch)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEnsureUserDocMissingPinnedBranch:
|
||||
"""Cover line 131: when existing doc has shared_with_me but no pinned."""
|
||||
|
||||
def test_adds_missing_pinned_field(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "user_missing_pinned"
|
||||
|
||||
# Insert a user that only has shared_with_me but not pinned
|
||||
users_collection.insert_one(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"agent_preferences": {
|
||||
"shared_with_me": ["agent-x"],
|
||||
# "pinned" intentionally absent
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert "pinned" in result["agent_preferences"]
|
||||
assert result["agent_preferences"]["pinned"] == []
|
||||
assert result["agent_preferences"]["shared_with_me"] == ["agent-x"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveToolDetailsInvalidIds:
|
||||
"""Cover lines 157-158: invalid ObjectId strings are skipped silently."""
|
||||
|
||||
def test_skips_invalid_object_ids(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
valid_id = ObjectId()
|
||||
user_tools.insert_one({"_id": valid_id, "name": "valid_tool"})
|
||||
|
||||
result = resolve_tool_details(["not-an-objectid", str(valid_id), "also-invalid"])
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == str(valid_id)
|
||||
assert result[0]["name"] == "valid_tool"
|
||||
|
||||
def test_all_invalid_ids_returns_empty(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
|
||||
result = resolve_tool_details(["bad-id-1", "bad-id-2"])
|
||||
assert result == []
|
||||
401
tests/api/user/test_scheduler_dispatcher.py
Normal file
401
tests/api/user/test_scheduler_dispatcher.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""Tests for the scheduler dispatcher (engine-level, no Celery worker)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.api.user.scheduler_dispatcher import dispatch_due_runs
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _make_agent(conn, user_id: str = "u1") -> str:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()
|
||||
return str(row[0])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_engine(pg_engine, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_dispatcher.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
yield pg_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stub_enqueue(monkeypatch):
|
||||
"""Capture every execute_scheduled_run.apply_async."""
|
||||
enqueued: list[str] = []
|
||||
|
||||
class _Task:
|
||||
@staticmethod
|
||||
def apply_async(args=None, **kwargs):
|
||||
if args:
|
||||
enqueued.append(args[0])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.tasks.execute_scheduled_run", _Task
|
||||
)
|
||||
return enqueued
|
||||
|
||||
|
||||
def _create_schedule(engine, **kwargs):
|
||||
with engine.begin() as conn:
|
||||
return SchedulesRepository(conn).create(**kwargs)
|
||||
|
||||
|
||||
def _set_postgres_uri(monkeypatch, pg_engine):
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_dispatcher.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_MISFIRE_GRACE": 60,
|
||||
})(),
|
||||
)
|
||||
|
||||
|
||||
class TestDispatcherBasic:
|
||||
def test_due_recurring_enqueues_once_and_advances(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 1
|
||||
assert len(stub_enqueue) == 1
|
||||
with pg_engine.connect() as conn:
|
||||
row = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["next_run_at"] is not None
|
||||
|
||||
def test_once_dispatch_nulls_next_run_at_keeps_active(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
"""Once: dispatcher nulls next_run_at but leaves status='active' for the worker."""
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="i", run_at=_now() + timedelta(seconds=1),
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 1
|
||||
with pg_engine.connect() as conn:
|
||||
row = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "active"
|
||||
assert row["next_run_at"] is None
|
||||
|
||||
|
||||
class TestDedupConstraint:
|
||||
def test_double_dispatch_only_one_run(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="*/5 * * * *",
|
||||
next_run_at=_now() - timedelta(seconds=2),
|
||||
)
|
||||
# Pre-claim simulates a racing dispatcher tick.
|
||||
with pg_engine.begin() as conn:
|
||||
row = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]),
|
||||
"u1",
|
||||
str(row["agent_id"]),
|
||||
row["next_run_at"],
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 0
|
||||
assert stub_enqueue == []
|
||||
|
||||
|
||||
class TestMisfireGrace:
|
||||
def test_stale_tick_recorded_skipped(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_dispatcher.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_MISFIRE_GRACE": 30,
|
||||
})(),
|
||||
)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="*/5 * * * *",
|
||||
next_run_at=_now() - timedelta(hours=2),
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 0
|
||||
assert counts["skipped"] >= 1
|
||||
with pg_engine.connect() as conn:
|
||||
runs = ScheduleRunsRepository(conn).list_runs(
|
||||
str(schedule["id"]), "u1",
|
||||
)
|
||||
assert any(r["error_type"] == "missed" for r in runs)
|
||||
|
||||
|
||||
class TestOverlap:
|
||||
def test_active_run_blocks_dispatch(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="*/5 * * * *",
|
||||
next_run_at=_now() - timedelta(seconds=2),
|
||||
)
|
||||
# Pre-create a running run with a different scheduled_for so overlap fires.
|
||||
with pg_engine.begin() as conn:
|
||||
row = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]),
|
||||
"u1",
|
||||
str(agent_id),
|
||||
_now() - timedelta(minutes=10),
|
||||
)
|
||||
ScheduleRunsRepository(conn).mark_running(row["id"], "t1")
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 0
|
||||
|
||||
def test_once_overlap_clears_next_run_at(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
"""Once + overlap nulls next_run_at so the dispatcher stops re-picking."""
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="i", run_at=_now() + timedelta(seconds=30),
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
)
|
||||
with pg_engine.begin() as conn:
|
||||
existing = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]),
|
||||
"u1",
|
||||
str(agent_id),
|
||||
_now() - timedelta(minutes=10),
|
||||
)
|
||||
ScheduleRunsRepository(conn).mark_running(existing["id"], "t-prev")
|
||||
dispatch_due_runs()
|
||||
with pg_engine.connect() as conn:
|
||||
row = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "active"
|
||||
assert row["next_run_at"] is None
|
||||
|
||||
|
||||
class TestAgentlessSchedules:
|
||||
def test_dispatcher_claims_and_enqueues_agentless_once(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
"""``agent_id IS NULL`` rows are claimed like any other once-schedule."""
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u-agentless", agent_id=None, trigger_type="once",
|
||||
instruction="agentless ping",
|
||||
run_at=_now() + timedelta(seconds=30),
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
origin_conversation_id=None,
|
||||
created_via="chat",
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 1
|
||||
assert len(stub_enqueue) == 1
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
run_row = conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedule_runs "
|
||||
"WHERE schedule_id = CAST(:s AS uuid)"
|
||||
),
|
||||
{"s": str(schedule["id"])},
|
||||
).fetchone()
|
||||
# Once: dispatcher nulled next_run_at, schedule still active.
|
||||
assert sched["status"] == "active"
|
||||
assert sched["next_run_at"] is None
|
||||
# The pending run carries NULL agent_id (matches the parent).
|
||||
assert run_row._mapping["agent_id"] is None
|
||||
assert run_row._mapping["user_id"] == "u-agentless"
|
||||
|
||||
|
||||
class TestAgentlessRoundTrip:
|
||||
"""Agentless chat → tool → dispatcher → headless run → message appended."""
|
||||
|
||||
def test_agentless_dispatch_executes_and_appends_message(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
from unittest.mock import patch
|
||||
|
||||
from application.api.user.scheduler_worker import (
|
||||
execute_scheduled_run_body,
|
||||
)
|
||||
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_AUTOPAUSE_FAILURES": 3,
|
||||
})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.publish_user_event",
|
||||
lambda *a, **k: "1-0",
|
||||
)
|
||||
|
||||
with pg_engine.begin() as conn:
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u-e2e', 'agentless-chat') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u-e2e", agent_id=None, trigger_type="once",
|
||||
instruction="ping later",
|
||||
run_at=_now() + timedelta(seconds=30),
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
origin_conversation_id=str(conv_id),
|
||||
created_via="chat",
|
||||
)
|
||||
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 1
|
||||
run_id = stub_enqueue[0]
|
||||
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "agentless e2e done",
|
||||
"tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 1, "generated_tokens": 1,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
},
|
||||
):
|
||||
result = execute_scheduled_run_body(run_id, "celery-e2e")
|
||||
assert result["status"] == "success"
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
run = ScheduleRunsRepository(conn).get_internal(run_id)
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
messages = conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:c AS uuid)"
|
||||
),
|
||||
{"c": str(conv_id)},
|
||||
).fetchall()
|
||||
assert run["status"] == "success"
|
||||
assert run["agent_id"] is None
|
||||
assert sched["status"] == "completed"
|
||||
assert sched["agent_id"] is None
|
||||
assert len(messages) == 1
|
||||
meta = messages[0]._mapping["message_metadata"]
|
||||
assert meta.get("scheduled") is True
|
||||
|
||||
|
||||
class TestOnceRoundTrip:
|
||||
"""End-to-end: chat-driven once-schedule executes and the schedule completes."""
|
||||
|
||||
def test_once_dispatch_executes_and_completes_schedule(
|
||||
self, pg_engine, patched_engine, stub_enqueue, monkeypatch,
|
||||
):
|
||||
from unittest.mock import patch
|
||||
|
||||
from application.api.user.scheduler_worker import (
|
||||
execute_scheduled_run_body,
|
||||
)
|
||||
|
||||
_set_postgres_uri(monkeypatch, pg_engine)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_AUTOPAUSE_FAILURES": 3,
|
||||
})(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.publish_user_event",
|
||||
lambda *a, **k: "1-0",
|
||||
)
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = _create_schedule(
|
||||
pg_engine,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="follow up", run_at=_now() + timedelta(seconds=1),
|
||||
next_run_at=_now() - timedelta(seconds=5),
|
||||
)
|
||||
counts = dispatch_due_runs()
|
||||
assert counts["enqueued"] == 1
|
||||
assert len(stub_enqueue) == 1
|
||||
run_id = stub_enqueue[0]
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert sched["status"] == "active"
|
||||
assert sched["next_run_at"] is None
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "done",
|
||||
"tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 1, "generated_tokens": 1,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
},
|
||||
):
|
||||
result = execute_scheduled_run_body(run_id, "celery-c1")
|
||||
assert result["status"] == "success"
|
||||
with pg_engine.connect() as conn:
|
||||
run = ScheduleRunsRepository(conn).get_internal(run_id)
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert run["status"] == "success"
|
||||
assert run["output"] == "done"
|
||||
assert sched["status"] == "completed"
|
||||
assert sched["next_run_at"] is None
|
||||
282
tests/api/user/test_scheduler_reconcile.py
Normal file
282
tests/api/user/test_scheduler_reconcile.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Tests for the scheduler reconciliation sweep + cleanup task."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _make_pending_run(conn, *, user_id="u1"):
|
||||
agent_id = str(
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()[0]
|
||||
)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id=user_id, agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="*/5 * * * *",
|
||||
next_run_at=_now() + timedelta(minutes=5),
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), user_id, agent_id, _now(),
|
||||
)
|
||||
return schedule, run, agent_id
|
||||
|
||||
|
||||
def _make_once_pending_run(conn, *, user_id="u1"):
|
||||
"""Once-schedule + pending run variant of _make_pending_run."""
|
||||
agent_id = str(
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()[0]
|
||||
)
|
||||
fire = _now() + timedelta(minutes=5)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id=user_id, agent_id=agent_id, trigger_type="once",
|
||||
instruction="do once", run_at=fire,
|
||||
next_run_at=fire,
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), user_id, agent_id, fire,
|
||||
)
|
||||
return schedule, run, agent_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_engine(pg_engine, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.reconciliation.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.reconciliation.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_RUN_TIMEOUT": 60,
|
||||
})(),
|
||||
)
|
||||
yield pg_engine
|
||||
|
||||
|
||||
class TestReconciler:
|
||||
def test_stuck_running_flipped_to_timeout(self, pg_engine, patched_engine):
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
ScheduleRunsRepository(conn).mark_running(run["id"], "t1")
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET started_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
summary = run_reconciliation()
|
||||
assert summary["schedule_runs_failed"] >= 1
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(run["id"])
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "timeout"
|
||||
assert row["error_type"] == "timeout"
|
||||
assert sched["consecutive_failure_count"] == 1
|
||||
|
||||
def test_stuck_pending_flipped_to_failed(self, pg_engine, patched_engine):
|
||||
"""A pending run whose worker never started reconciles to 'failed'."""
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET created_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
summary = run_reconciliation()
|
||||
assert summary["schedule_runs_failed"] >= 1
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(run["id"])
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "failed"
|
||||
assert row["error_type"] == "internal"
|
||||
assert "worker_never_started" in (row["error"] or "")
|
||||
assert sched["consecutive_failure_count"] == 1
|
||||
|
||||
def test_once_schedule_with_stuck_running_run_marked_completed(
|
||||
self, pg_engine, patched_engine,
|
||||
):
|
||||
"""Once + stuck 'running' run -> parent flipped to 'completed'."""
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_once_pending_run(conn)
|
||||
ScheduleRunsRepository(conn).mark_running(run["id"], "t-once")
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET started_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
run_reconciliation()
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
row = ScheduleRunsRepository(conn).get_internal(run["id"])
|
||||
assert row["status"] == "timeout"
|
||||
assert sched["status"] == "completed", (
|
||||
"stuck once-run must terminal-flip the parent schedule"
|
||||
)
|
||||
assert sched["next_run_at"] is None
|
||||
|
||||
def test_once_schedule_with_stuck_pending_run_marked_completed(
|
||||
self, pg_engine, patched_engine,
|
||||
):
|
||||
"""Once + stuck 'pending' run -> parent flipped to 'completed'."""
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_once_pending_run(conn)
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET created_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
run_reconciliation()
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
row = ScheduleRunsRepository(conn).get_internal(run["id"])
|
||||
assert row["status"] == "failed"
|
||||
assert sched["status"] == "completed", (
|
||||
"stuck pending once-run must terminal-flip the parent schedule"
|
||||
)
|
||||
assert sched["next_run_at"] is None
|
||||
|
||||
def test_agentless_once_stuck_running_marked_completed(
|
||||
self, pg_engine, patched_engine,
|
||||
):
|
||||
"""Stuck-run terminal flip works for agentless once-schedules."""
|
||||
with pg_engine.begin() as conn:
|
||||
fire = _now() + timedelta(minutes=5)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u-agentless", agent_id=None, trigger_type="once",
|
||||
instruction="agentless go", run_at=fire,
|
||||
next_run_at=fire,
|
||||
created_via="chat",
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u-agentless", None, fire,
|
||||
)
|
||||
ScheduleRunsRepository(conn).mark_running(run["id"], "t-stuck")
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET started_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
run_reconciliation()
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
row = ScheduleRunsRepository(conn).get_internal(run["id"])
|
||||
assert row["status"] == "timeout"
|
||||
assert sched["status"] == "completed"
|
||||
assert sched["next_run_at"] is None
|
||||
|
||||
def test_recurring_schedule_with_stuck_run_stays_active(
|
||||
self, pg_engine, patched_engine,
|
||||
):
|
||||
"""Recurring keeps firing; only the run flips, not the parent."""
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
ScheduleRunsRepository(conn).mark_running(run["id"], "t-rec")
|
||||
conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET started_at = now() - interval '120 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
|
||||
run_reconciliation()
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert sched["status"] == "active"
|
||||
assert sched["consecutive_failure_count"] == 1
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_cleanup_schedule_runs_trims_old_rows(self, pg_engine, monkeypatch):
|
||||
from application.api.user.tasks import cleanup_schedule_runs as _task
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.storage.db.engine.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
|
||||
class S:
|
||||
POSTGRES_URI = str(pg_engine.url)
|
||||
SCHEDULE_RUN_OUTPUT_RETENTION_DAYS = 30
|
||||
monkeypatch.setattr("application.api.user.tasks.settings", S, raising=False)
|
||||
monkeypatch.setattr(
|
||||
"application.core.settings.settings.POSTGRES_URI",
|
||||
str(pg_engine.url),
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.settings.settings.SCHEDULE_RUN_OUTPUT_RETENTION_DAYS",
|
||||
30,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, _, _ = _make_pending_run(conn)
|
||||
for i in range(60):
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO schedule_runs (
|
||||
schedule_id, user_id, agent_id, status,
|
||||
scheduled_for, created_at
|
||||
) VALUES (
|
||||
CAST(:s AS uuid), 'u1',
|
||||
CAST(:a AS uuid), 'success',
|
||||
now() - interval '100 days' - (:i * interval '1 second'),
|
||||
now() - interval '100 days'
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"s": str(schedule["id"]),
|
||||
"a": str(schedule["agent_id"]),
|
||||
"i": i},
|
||||
)
|
||||
|
||||
result = _task.run()
|
||||
assert isinstance(result.get("deleted"), int)
|
||||
assert result["deleted"] >= 1
|
||||
428
tests/api/user/test_scheduler_worker.py
Normal file
428
tests/api/user/test_scheduler_worker.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Tests for execute_scheduled_run_body (mocked agent run)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.api.user.scheduler_worker import execute_scheduled_run_body
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _make_agent(conn, user_id: str = "u1") -> str:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status, default_model_id) "
|
||||
"VALUES (:u, 'a', 'draft', '') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()
|
||||
return str(row[0])
|
||||
|
||||
|
||||
def _make_pending_run(conn, *, user_id="u1"):
|
||||
agent_id = _make_agent(conn, user_id)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id=user_id, agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="hello", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(minutes=5),
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]),
|
||||
user_id,
|
||||
agent_id,
|
||||
_now(),
|
||||
)
|
||||
return schedule, run, agent_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_engine(pg_engine, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.get_engine",
|
||||
lambda: pg_engine,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.settings",
|
||||
type("S", (), {
|
||||
"POSTGRES_URI": str(pg_engine.url),
|
||||
"SCHEDULE_AUTOPAUSE_FAILURES": 2,
|
||||
})(),
|
||||
)
|
||||
yield pg_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stub_events(monkeypatch):
|
||||
captured: list[tuple] = []
|
||||
|
||||
def _fake_publish(user_id, event_type, payload, *, scope=None):
|
||||
captured.append((event_type, payload, scope))
|
||||
return "1-0"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.scheduler_worker.publish_user_event",
|
||||
_fake_publish,
|
||||
)
|
||||
return captured
|
||||
|
||||
|
||||
class TestExecuteScheduledRunBody:
|
||||
def test_success_flow(self, pg_engine, patched_engine, stub_events):
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "all done",
|
||||
"tool_calls": [],
|
||||
"sources": [],
|
||||
"thought": "",
|
||||
"prompt_tokens": 10,
|
||||
"generated_tokens": 5,
|
||||
"denied": [],
|
||||
"error_type": None,
|
||||
"model_id": "fake-model",
|
||||
},
|
||||
):
|
||||
result = execute_scheduled_run_body(str(run["id"]), "celery-1")
|
||||
assert result["status"] == "success"
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(str(run["id"]))
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "success"
|
||||
assert row["output"] == "all done"
|
||||
assert row["prompt_tokens"] == 10
|
||||
assert sched["consecutive_failure_count"] == 0
|
||||
event_types = [e[0] for e in stub_events]
|
||||
assert "schedule.run.completed" in event_types
|
||||
|
||||
def test_agent_exception_marks_failed_and_bumps(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
result = execute_scheduled_run_body(str(run["id"]), "celery-2")
|
||||
assert result["status"] == "failed"
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(str(run["id"]))
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert row["status"] == "failed"
|
||||
assert row["error_type"] == "agent_error"
|
||||
assert sched["consecutive_failure_count"] == 1
|
||||
assert "schedule.run.failed" in {e[0] for e in stub_events}
|
||||
|
||||
def test_autopause_after_threshold(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, agent_id = _make_pending_run(conn)
|
||||
SchedulesRepository(conn).bump_failure_count(str(schedule["id"]))
|
||||
another_run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]),
|
||||
"u1",
|
||||
agent_id,
|
||||
_now() + timedelta(seconds=1),
|
||||
)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
execute_scheduled_run_body(str(another_run["id"]), "celery-3")
|
||||
with pg_engine.connect() as conn:
|
||||
sched = SchedulesRepository(conn).get_internal(str(schedule["id"]))
|
||||
assert sched["status"] == "paused"
|
||||
assert "schedule.autopaused" in {e[0] for e in stub_events}
|
||||
|
||||
def test_denied_with_empty_output_marks_tool_not_allowed(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
with pg_engine.begin() as conn:
|
||||
schedule, run, _ = _make_pending_run(conn)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "",
|
||||
"tool_calls": [],
|
||||
"sources": [],
|
||||
"thought": "",
|
||||
"prompt_tokens": 1,
|
||||
"generated_tokens": 0,
|
||||
"denied": [{"tool_name": "telegram"}],
|
||||
"error_type": "tool_not_allowed",
|
||||
"model_id": "fake",
|
||||
},
|
||||
):
|
||||
execute_scheduled_run_body(str(run["id"]), "celery-4")
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(str(run["id"]))
|
||||
assert row["status"] == "failed"
|
||||
assert row["error_type"] == "tool_not_allowed"
|
||||
|
||||
def test_one_time_loads_chat_history(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="follow up", run_at=_now() + timedelta(seconds=5),
|
||||
next_run_at=_now(),
|
||||
)
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, agent_id, name) "
|
||||
"VALUES ('u1', CAST(:a AS uuid), 'origin') RETURNING id"
|
||||
),
|
||||
{"a": agent_id},
|
||||
).fetchone()[0]
|
||||
SchedulesRepository(conn).update_internal(
|
||||
str(schedule["id"]),
|
||||
{"origin_conversation_id": str(conv_id)},
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO conversation_messages
|
||||
(conversation_id, position, prompt, response, user_id)
|
||||
VALUES (CAST(:c AS uuid), 0, 'hello', 'hi', 'u1')
|
||||
"""
|
||||
),
|
||||
{"c": str(conv_id)},
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u1", agent_id, _now(),
|
||||
)
|
||||
captured: dict = {}
|
||||
def _fake_run(agent_config, query, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return {
|
||||
"answer": "follow-up answer",
|
||||
"tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 1, "generated_tokens": 1,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
}
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless", _fake_run,
|
||||
):
|
||||
execute_scheduled_run_body(str(run["id"]), "celery-h")
|
||||
assert len(captured.get("chat_history", [])) == 1
|
||||
assert captured["chat_history"][0]["prompt"] == "hello"
|
||||
|
||||
def test_agentless_schedule_uses_system_defaults_and_appends(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
"""Agentless ``once`` schedule → ephemeral classic agent → message appended."""
|
||||
with pg_engine.begin() as conn:
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'agentless-origin') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u1", agent_id=None, trigger_type="once",
|
||||
instruction="follow up agentless",
|
||||
run_at=_now() + timedelta(seconds=5),
|
||||
next_run_at=_now(),
|
||||
origin_conversation_id=str(conv_id),
|
||||
created_via="chat",
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u1", None, _now(),
|
||||
)
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_run(agent_config, query, **kwargs):
|
||||
captured["agent_config"] = agent_config
|
||||
captured["kwargs"] = kwargs
|
||||
return {
|
||||
"answer": "agentless ran",
|
||||
"tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 4, "generated_tokens": 6,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
_fake_run,
|
||||
):
|
||||
result = execute_scheduled_run_body(str(run["id"]), "celery-agentless")
|
||||
assert result["status"] == "success"
|
||||
# Ephemeral classic config: no source, default retriever, no agent id.
|
||||
cfg = captured["agent_config"]
|
||||
assert cfg["id"] is None
|
||||
assert cfg["user_id"] == "u1"
|
||||
assert cfg["agent_type"] == "classic"
|
||||
assert cfg["retriever"] == "classic"
|
||||
assert cfg["prompt_id"] == "default"
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(str(run["id"]))
|
||||
messages = conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:c AS uuid)"
|
||||
),
|
||||
{"c": str(conv_id)},
|
||||
).fetchall()
|
||||
assert row["status"] == "success"
|
||||
assert row["output"] == "agentless ran"
|
||||
assert row["conversation_id"] is not None
|
||||
assert len(messages) == 1
|
||||
# The published event payload tolerates a NULL agent_id.
|
||||
appended_events = [e for e in stub_events if e[0] == "schedule.message.appended"]
|
||||
assert appended_events
|
||||
|
||||
def test_agentless_ephemeral_config_omits_tools_snapshot(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
"""Dead ``tools`` snapshot dropped — toolset is rebuilt at fire time."""
|
||||
with pg_engine.begin() as conn:
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'no-tools-snap') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u1", agent_id=None, trigger_type="once",
|
||||
instruction="x", run_at=_now() + timedelta(seconds=5),
|
||||
next_run_at=_now(),
|
||||
origin_conversation_id=str(conv_id),
|
||||
created_via="chat",
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u1", None, _now(),
|
||||
)
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_run(agent_config, query, **kwargs):
|
||||
captured["agent_config"] = agent_config
|
||||
return {
|
||||
"answer": "ok", "tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 1, "generated_tokens": 1,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
_fake_run,
|
||||
):
|
||||
execute_scheduled_run_body(str(run["id"]), "celery-no-snap")
|
||||
cfg = captured["agent_config"]
|
||||
# ``tools`` MUST NOT be in the ephemeral shape — the runtime
|
||||
# toolset is rebuilt by ``ToolExecutor`` (which honours headless
|
||||
# filtering for chat-only tools like ``scheduler``).
|
||||
assert "tools" not in cfg
|
||||
|
||||
def test_agentless_token_usage_row_has_null_agent_id(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
"""token_usage row for an agentless run carries ``agent_id IS NULL``."""
|
||||
with pg_engine.begin() as conn:
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, name) "
|
||||
"VALUES ('u1', 'agentless-tu') RETURNING id"
|
||||
)
|
||||
).fetchone()[0]
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u1", agent_id=None, trigger_type="once",
|
||||
instruction="tu", run_at=_now() + timedelta(seconds=5),
|
||||
next_run_at=_now(),
|
||||
origin_conversation_id=str(conv_id),
|
||||
created_via="chat",
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u1", None, _now(),
|
||||
)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "yes",
|
||||
"tool_calls": [], "sources": [], "thought": "",
|
||||
"prompt_tokens": 11, "generated_tokens": 7,
|
||||
"denied": [], "error_type": None, "model_id": "fake",
|
||||
},
|
||||
):
|
||||
execute_scheduled_run_body(str(run["id"]), "celery-tu")
|
||||
with pg_engine.connect() as conn:
|
||||
tu_row = conn.execute(
|
||||
text(
|
||||
"SELECT * FROM token_usage "
|
||||
"WHERE request_id = :r"
|
||||
),
|
||||
{"r": str(run["id"])},
|
||||
).fetchone()
|
||||
assert tu_row is not None
|
||||
assert tu_row._mapping["agent_id"] is None
|
||||
assert tu_row._mapping["source"] == "schedule"
|
||||
|
||||
def test_one_time_appends_message(
|
||||
self, pg_engine, patched_engine, stub_events,
|
||||
):
|
||||
with pg_engine.begin() as conn:
|
||||
agent_id = _make_agent(conn)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="hello", run_at=_now() + timedelta(seconds=5),
|
||||
next_run_at=_now(),
|
||||
)
|
||||
conv_id = conn.execute(
|
||||
text(
|
||||
"INSERT INTO conversations (user_id, agent_id, name) "
|
||||
"VALUES ('u1', CAST(:a AS uuid), 'origin') RETURNING id"
|
||||
),
|
||||
{"a": agent_id},
|
||||
).fetchone()[0]
|
||||
SchedulesRepository(conn).update_internal(
|
||||
str(schedule["id"]),
|
||||
{"origin_conversation_id": str(conv_id)},
|
||||
)
|
||||
run = ScheduleRunsRepository(conn).record_pending(
|
||||
str(schedule["id"]), "u1", agent_id, _now(),
|
||||
)
|
||||
with patch(
|
||||
"application.api.user.scheduler_worker.run_agent_headless",
|
||||
return_value={
|
||||
"answer": "scheduled answer",
|
||||
"tool_calls": [],
|
||||
"sources": [],
|
||||
"thought": "",
|
||||
"prompt_tokens": 2,
|
||||
"generated_tokens": 3,
|
||||
"denied": [],
|
||||
"error_type": None,
|
||||
"model_id": "fake",
|
||||
},
|
||||
):
|
||||
execute_scheduled_run_body(str(run["id"]), "celery-5")
|
||||
with pg_engine.connect() as conn:
|
||||
row = ScheduleRunsRepository(conn).get_internal(str(run["id"]))
|
||||
messages = conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:c AS uuid)"
|
||||
),
|
||||
{"c": str(conv_id)},
|
||||
).fetchall()
|
||||
assert row["conversation_id"] is not None
|
||||
assert row["message_id"] is not None
|
||||
assert len(messages) == 1
|
||||
meta = messages[0]._mapping["message_metadata"]
|
||||
assert meta.get("scheduled") is True
|
||||
assert "schedule.message.appended" in {e[0] for e in stub_events}
|
||||
474
tests/api/user/test_schedules_routes.py
Normal file
474
tests/api/user/test_schedules_routes.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""Tests for the schedules REST API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_db(conn):
|
||||
@contextmanager
|
||||
def _yield():
|
||||
yield conn
|
||||
|
||||
with patch(
|
||||
"application.api.user.schedules.routes.db_session", _yield,
|
||||
), patch(
|
||||
"application.api.user.schedules.routes.db_readonly", _yield,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _make_agent(conn, user_id: str = "u1") -> str:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()
|
||||
return str(row[0])
|
||||
|
||||
|
||||
class TestCreateRecurring:
|
||||
def test_unauthorized(self, app):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/x/schedules", method="POST", json={},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = None
|
||||
resp = AgentSchedules().post("x")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_agent_not_found(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
"/api/agents/00000000-0000-0000-0000-000000000000/schedules",
|
||||
method="POST", json={"instruction": "x", "cron": "* * * * *"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_invalid_cron(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={"instruction": "x", "cron": "not a cron"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_create_success(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={
|
||||
"instruction": "weekly digest",
|
||||
"cron": "0 9 * * 1",
|
||||
"timezone": "Europe/Warsaw",
|
||||
"tool_allowlist": [],
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 201
|
||||
body = resp.get_json()
|
||||
assert body["schedule"]["cron"] == "0 9 * * 1"
|
||||
assert body["schedule"]["timezone"] == "Europe/Warsaw"
|
||||
|
||||
|
||||
class TestCreateOnce:
|
||||
def test_creates_once_with_run_at(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
run_at = (_now() + timedelta(hours=2)).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={
|
||||
"instruction": "remind me",
|
||||
"trigger_type": "once",
|
||||
"run_at": run_at,
|
||||
"timezone": "UTC",
|
||||
"tool_allowlist": [],
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 201
|
||||
body = resp.get_json()
|
||||
assert body["schedule"]["trigger_type"] == "once"
|
||||
assert body["schedule"]["run_at"] is not None
|
||||
|
||||
def test_once_requires_run_at(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={
|
||||
"instruction": "remind me",
|
||||
"trigger_type": "once",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_once_rejects_past_run_at(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
past = (_now() - timedelta(hours=1)).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={
|
||||
"instruction": "x",
|
||||
"trigger_type": "once",
|
||||
"run_at": past,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_recurring_default_when_trigger_type_omitted(self, app, pg_conn):
|
||||
"""Backwards compat: a payload with cron but no trigger_type still works."""
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={
|
||||
"instruction": "hourly",
|
||||
"cron": "0 * * * *",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 201
|
||||
assert resp.get_json()["schedule"]["trigger_type"] == "recurring"
|
||||
|
||||
|
||||
class TestListForAgent:
|
||||
def test_list(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules", method="GET",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().get(agent_id)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.get_json()["schedules"]) == 1
|
||||
|
||||
|
||||
class TestGetEditPatchDelete:
|
||||
def _make(self, conn, **kwargs):
|
||||
return SchedulesRepository(conn).create(**kwargs)
|
||||
|
||||
def test_get_owner_scoped(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = self._make(
|
||||
pg_conn,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="GET",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u2"}
|
||||
resp = ScheduleResource().get(str(s["id"]))
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_pause_then_resume(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = self._make(
|
||||
pg_conn,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}",
|
||||
method="PATCH",
|
||||
json={"action": "pause"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().patch(str(s["id"]))
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["schedule"]["status"] == "paused"
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}",
|
||||
method="PATCH",
|
||||
json={"action": "resume"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().patch(str(s["id"]))
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body["schedule"]["status"] == "active"
|
||||
assert body["schedule"]["next_run_at"] is not None
|
||||
|
||||
def test_delete_owner_scoped(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = self._make(
|
||||
pg_conn,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="DELETE",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u2"}
|
||||
resp = ScheduleResource().delete(str(s["id"]))
|
||||
assert resp.status_code == 404
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="DELETE",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().delete(str(s["id"]))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_put_invalid_cron(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = self._make(
|
||||
pg_conn,
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}",
|
||||
method="PUT", json={"cron": "bad"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().put(str(s["id"]))
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
class TestRunNow:
|
||||
def test_runs_returns_202(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleRunNow
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.tasks.execute_scheduled_run",
|
||||
type("T", (), {"apply_async": staticmethod(lambda **k: None)}),
|
||||
), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}/run", method="POST",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleRunNow().post(str(s["id"]))
|
||||
assert resp.status_code == 202
|
||||
|
||||
def test_second_run_blocked_by_active(self, app, pg_conn):
|
||||
"""Run-Now serializes via FOR UPDATE + has_active_run; second 409s."""
|
||||
from application.api.user.schedules.routes import ScheduleRunNow
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="0 9 * * 1",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), patch(
|
||||
"application.api.user.tasks.execute_scheduled_run",
|
||||
type("T", (), {"apply_async": staticmethod(lambda **k: None)}),
|
||||
), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}/run", method="POST",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
first = ScheduleRunNow().post(str(s["id"]))
|
||||
assert first.status_code == 202
|
||||
second = ScheduleRunNow().post(str(s["id"]))
|
||||
assert second.status_code == 409
|
||||
|
||||
|
||||
class TestMinInterval:
|
||||
def test_create_rejects_below_min_interval(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import AgentSchedules
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/agents/{agent_id}/schedules",
|
||||
method="POST",
|
||||
json={"instruction": "x", "cron": "* * * * *"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = AgentSchedules().post(agent_id)
|
||||
assert resp.status_code == 400
|
||||
assert "minimum interval" in resp.get_json()["message"]
|
||||
|
||||
def test_put_rejects_below_min_interval(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="0 9 * * 1",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="PUT",
|
||||
json={"cron": "*/5 * * * *"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().put(str(s["id"]))
|
||||
assert resp.status_code == 400
|
||||
assert "minimum interval" in resp.get_json()["message"]
|
||||
|
||||
|
||||
class TestResumeOnceStale:
|
||||
def test_stale_run_at_returns_clear_409(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="i", run_at=_now() + timedelta(hours=1),
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
status="paused",
|
||||
)
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET run_at = now() - interval '1 day' "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(s["id"])},
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="PATCH",
|
||||
json={"action": "resume"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().patch(str(s["id"]))
|
||||
assert resp.status_code == 409
|
||||
assert "elapsed" in resp.get_json()["message"]
|
||||
|
||||
def test_resume_accepts_new_run_at(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleResource
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="once",
|
||||
instruction="i", run_at=_now() + timedelta(hours=1),
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
status="paused",
|
||||
)
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET run_at = now() - interval '1 day' "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(s["id"])},
|
||||
)
|
||||
new_run_at = (_now() + timedelta(hours=3)).isoformat()
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}", method="PATCH",
|
||||
json={"action": "resume", "run_at": new_run_at},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
resp = ScheduleResource().patch(str(s["id"]))
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()["schedule"]
|
||||
assert body["status"] == "active"
|
||||
|
||||
|
||||
class TestRunList:
|
||||
def test_list_owner_scoped(self, app, pg_conn):
|
||||
from application.api.user.schedules.routes import ScheduleRunList
|
||||
|
||||
agent_id = _make_agent(pg_conn)
|
||||
s = SchedulesRepository(pg_conn).create(
|
||||
user_id="u1", agent_id=agent_id, trigger_type="recurring",
|
||||
instruction="i", cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(hours=1),
|
||||
)
|
||||
with _patch_db(pg_conn), app.test_request_context(
|
||||
f"/api/schedules/{s['id']}/runs", method="GET",
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": "u2"}
|
||||
resp = ScheduleRunList().get(str(s["id"]))
|
||||
assert resp.status_code == 404
|
||||
@@ -222,7 +222,7 @@ class TestSetupPeriodicTasks:
|
||||
|
||||
setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 8
|
||||
assert sender.add_periodic_task.call_count == 11
|
||||
|
||||
calls = sender.add_periodic_task.call_args_list
|
||||
|
||||
@@ -246,6 +246,14 @@ class TestSetupPeriodicTasks:
|
||||
# message_events retention sweep (24h)
|
||||
assert calls[7][0][0] == timedelta(hours=24)
|
||||
assert calls[7][1].get("name") == "cleanup-message-events"
|
||||
# orphan memories sweep (24h)
|
||||
assert calls[8][0][0] == timedelta(hours=24)
|
||||
assert calls[8][1].get("name") == "cleanup-orphan-memories"
|
||||
# scheduler dispatcher
|
||||
assert calls[9][1].get("name") == "dispatch-scheduled-runs"
|
||||
# schedule runs cleanup (24h)
|
||||
assert calls[10][0][0] == timedelta(hours=24)
|
||||
assert calls[10][1].get("name") == "cleanup-schedule-runs"
|
||||
|
||||
|
||||
class TestMcpOauthTask:
|
||||
@@ -296,6 +304,7 @@ class TestDurableTaskRetryPolicy:
|
||||
"cleanup_pending_tool_state",
|
||||
"reconciliation_task",
|
||||
"version_check_task",
|
||||
"cleanup_orphan_memories",
|
||||
],
|
||||
)
|
||||
def test_short_periodic_tasks_have_no_retry_config(self, task_name):
|
||||
@@ -515,6 +524,72 @@ class TestCleanupMessageEventsTask:
|
||||
assert [r["sequence_no"] for r in rows] == [1]
|
||||
|
||||
|
||||
class TestCleanupOrphanMemoriesTask:
|
||||
"""Sweeps orphan memories from the FK-to-trigger orphan window."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_when_postgres_uri_missing(self, monkeypatch):
|
||||
from application.api.user.tasks import cleanup_orphan_memories
|
||||
from application.core.settings import settings
|
||||
|
||||
monkeypatch.setattr(settings, "POSTGRES_URI", None, raising=False)
|
||||
|
||||
result = cleanup_orphan_memories.run()
|
||||
assert result == {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deletes_orphan_keeps_synthetic_and_live(
|
||||
self, pg_conn, monkeypatch
|
||||
):
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import text as _text
|
||||
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tasks import cleanup_orphan_memories
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.memories import (
|
||||
MemoriesRepository,
|
||||
)
|
||||
|
||||
repo = MemoriesRepository(pg_conn)
|
||||
synthetic_id = default_tool_id("memory")
|
||||
live_id = str(
|
||||
pg_conn.execute(
|
||||
_text(
|
||||
"INSERT INTO user_tools (user_id, name) "
|
||||
"VALUES ('u-task-mem', 'memory') RETURNING id"
|
||||
)
|
||||
).scalar()
|
||||
)
|
||||
orphan_id = str(uuid.uuid4())
|
||||
repo.upsert("u-task-mem", synthetic_id, "/syn.txt", "keep")
|
||||
repo.upsert("u-task-mem", live_id, "/live.txt", "keep")
|
||||
repo.upsert("u-task-mem", orphan_id, "/orphan.txt", "drop")
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings, "POSTGRES_URI", "postgresql://stub", raising=False
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _fake_begin():
|
||||
yield pg_conn
|
||||
|
||||
fake_engine = MagicMock()
|
||||
fake_engine.begin = _fake_begin
|
||||
|
||||
with patch(
|
||||
"application.storage.db.engine.get_engine",
|
||||
return_value=fake_engine,
|
||||
):
|
||||
result = cleanup_orphan_memories.run()
|
||||
|
||||
assert result == {"deleted": 1}
|
||||
assert repo.get_by_path("u-task-mem", synthetic_id, "/syn.txt")
|
||||
assert repo.get_by_path("u-task-mem", live_id, "/live.txt")
|
||||
assert repo.get_by_path("u-task-mem", orphan_id, "/orphan.txt") is None
|
||||
|
||||
|
||||
class TestIngestIdempotency:
|
||||
"""Same short-circuit applies to the ingest task path."""
|
||||
|
||||
|
||||
@@ -1066,6 +1066,11 @@ def _seed_tool(pg_conn, user="u-tools", name="read_webpage", config=None):
|
||||
|
||||
class TestGetToolsHappy:
|
||||
def test_returns_user_tools(self, app, pg_conn):
|
||||
from application.agents.default_tools import (
|
||||
BUILTIN_AGENT_TOOLS,
|
||||
loaded_builtin_agent_tools,
|
||||
loaded_default_tools,
|
||||
)
|
||||
from application.api.user.tools.routes import GetTools
|
||||
|
||||
user = "u-get-tools"
|
||||
@@ -1079,7 +1084,26 @@ class TestGetToolsHappy:
|
||||
response = GetTools().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert len(response.json["tools"]) == 1
|
||||
tools = response.json["tools"]
|
||||
# Response shape: 1 explicit + every default + builtins not already
|
||||
# surfaced as a default (dual-registered ``scheduler`` is dedup'd).
|
||||
defaults_count = len(loaded_default_tools())
|
||||
builtins_count = len(loaded_builtin_agent_tools())
|
||||
dual = sum(
|
||||
1 for name in loaded_default_tools() if name in BUILTIN_AGENT_TOOLS
|
||||
)
|
||||
assert len(tools) == 1 + defaults_count + (builtins_count - dual)
|
||||
explicit = [
|
||||
t for t in tools
|
||||
if not t.get("default") and not t.get("builtin")
|
||||
]
|
||||
defaults = [t for t in tools if t.get("default")]
|
||||
builtins = [t for t in tools if t.get("builtin")]
|
||||
assert len(explicit) == 1
|
||||
assert len(defaults) == defaults_count
|
||||
# Dual-registered tools (scheduler) appear once with both flags;
|
||||
# ``builtins`` here counts them via ``builtin=True``.
|
||||
assert len(builtins) == (builtins_count - dual) + dual
|
||||
|
||||
def test_db_error_returns_400(self, app):
|
||||
from application.api.user.tools.routes import GetTools
|
||||
@@ -1379,8 +1403,291 @@ class TestGetArtifactHappy:
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default chat tools — synthetic-id branching in the tool endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDefaultToolsRoutes:
|
||||
def test_get_tools_flags_defaults(self, app, pg_conn):
|
||||
from application.agents.default_tools import (
|
||||
default_tool_id,
|
||||
loaded_default_tools,
|
||||
)
|
||||
from application.api.user.tools.routes import GetTools
|
||||
|
||||
user = "u-def-get"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context("/api/get_tools"):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = GetTools().get()
|
||||
assert response.status_code == 200
|
||||
defaults = [t for t in response.json["tools"] if t.get("default")]
|
||||
names = {t["name"] for t in defaults}
|
||||
assert names == set(loaded_default_tools())
|
||||
for tool in defaults:
|
||||
assert tool["id"] == default_tool_id(tool["name"])
|
||||
assert tool["status"] is True
|
||||
|
||||
def test_get_tools_surfaces_scheduler_with_both_flags(self, app, pg_conn):
|
||||
"""Dual-registered scheduler appears once with default+builtin flags."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import GetTools
|
||||
|
||||
user = "u-sched-dual"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context("/api/get_tools"):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = GetTools().get()
|
||||
assert response.status_code == 200
|
||||
scheduler_id = default_tool_id("scheduler")
|
||||
scheduler_rows = [
|
||||
t for t in response.json["tools"] if t["id"] == scheduler_id
|
||||
]
|
||||
assert len(scheduler_rows) == 1 # dedup at the routes layer
|
||||
row = scheduler_rows[0]
|
||||
assert row["default"] is True
|
||||
assert row["builtin"] is True
|
||||
assert row["name"] == "scheduler"
|
||||
|
||||
def test_get_tools_status_reflects_opt_out(self, app, pg_conn):
|
||||
from application.api.user.tools.routes import GetTools
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-def-optout"
|
||||
UsersRepository(pg_conn).set_default_tool_enabled(
|
||||
user, "read_webpage", False
|
||||
)
|
||||
with _patch_tools_db(pg_conn), app.test_request_context("/api/get_tools"):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = GetTools().get()
|
||||
by_name = {
|
||||
t["name"]: t
|
||||
for t in response.json["tools"]
|
||||
if t.get("default")
|
||||
}
|
||||
assert by_name["read_webpage"]["status"] is False
|
||||
assert by_name["memory"]["status"] is True
|
||||
|
||||
def test_update_tool_status_toggles_default_off(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolStatus
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-def-toggle"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_status",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory"), "status": False},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolStatus().post()
|
||||
assert response.status_code == 200
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert user_doc["tool_preferences"]["disabled_default_tools"] == ["memory"]
|
||||
|
||||
def test_update_tool_status_toggles_default_back_on(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolStatus
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-def-on"
|
||||
UsersRepository(pg_conn).set_default_tool_enabled(user, "memory", False)
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_status",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory"), "status": True},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolStatus().post()
|
||||
assert response.status_code == 200
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert user_doc["tool_preferences"]["disabled_default_tools"] == []
|
||||
|
||||
def test_update_tool_toggles_default_via_status(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateTool
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-def-updtool"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("read_webpage"), "status": False},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateTool().post()
|
||||
assert response.status_code == 200
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert user_doc["tool_preferences"]["disabled_default_tools"] == [
|
||||
"read_webpage"
|
||||
]
|
||||
|
||||
def test_delete_tool_rejects_default(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import DeleteTool
|
||||
|
||||
user = "u-def-del"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/delete_tool",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory")},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = DeleteTool().post()
|
||||
assert response.status_code == 400
|
||||
assert response.json["success"] is False
|
||||
|
||||
def test_update_tool_default_without_status_is_rejected(
|
||||
self, app, pg_conn
|
||||
):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateTool
|
||||
|
||||
user = "u-def-noedit"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory"), "displayName": "Renamed"},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateTool().post()
|
||||
assert response.status_code == 400
|
||||
assert response.json["success"] is False
|
||||
assert "not editable" in response.json["message"]
|
||||
|
||||
def test_update_tool_config_rejects_default(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolConfig
|
||||
|
||||
user = "u-def-cfg"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_config",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory"), "config": {"x": 1}},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolConfig().post()
|
||||
assert response.status_code == 400
|
||||
assert response.json["success"] is False
|
||||
assert "config-free" in response.json["message"]
|
||||
|
||||
def test_update_tool_actions_rejects_default(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolActions
|
||||
|
||||
user = "u-def-act"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_actions",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("memory"), "actions": []},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolActions().post()
|
||||
assert response.status_code == 400
|
||||
assert response.json["success"] is False
|
||||
assert "not editable" in response.json["message"]
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dual-registered tools (scheduler) — toggle MUST hit the default-tool path,
|
||||
# not the builtin "not editable" rejection. Regression for the iter-6 issue
|
||||
# where ``is_builtin_agent_tool_id`` was checked first, silently dropping the
|
||||
# write on a dual-registered uuid5.
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDualRegisteredToggle:
|
||||
def test_update_tool_status_off_writes_disabled_default(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolStatus
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-sched-off"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_status",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("scheduler"), "status": False},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolStatus().post()
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert "scheduler" in (
|
||||
user_doc["tool_preferences"]["disabled_default_tools"]
|
||||
)
|
||||
|
||||
def test_update_tool_status_on_removes_disabled_default(self, app, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolStatus
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-sched-on"
|
||||
UsersRepository(pg_conn).set_default_tool_enabled(
|
||||
user, "scheduler", False,
|
||||
)
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_status",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("scheduler"), "status": True},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolStatus().post()
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert "scheduler" not in (
|
||||
user_doc["tool_preferences"]["disabled_default_tools"]
|
||||
)
|
||||
|
||||
def test_update_tool_status_round_trip(self, app, pg_conn):
|
||||
"""Off → on returns to the empty-list baseline."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateToolStatus
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-sched-rt"
|
||||
scheduler_id = default_tool_id("scheduler")
|
||||
for status in (False, True):
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool_status",
|
||||
method="POST",
|
||||
json={"id": scheduler_id, "status": status},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateToolStatus().post()
|
||||
assert response.status_code == 200
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert user_doc["tool_preferences"]["disabled_default_tools"] == []
|
||||
|
||||
def test_update_tool_with_status_writes_disabled_default(self, app, pg_conn):
|
||||
"""The /api/update_tool route also honours the default branch first."""
|
||||
from application.agents.default_tools import default_tool_id
|
||||
from application.api.user.tools.routes import UpdateTool
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
user = "u-sched-upd"
|
||||
with _patch_tools_db(pg_conn), app.test_request_context(
|
||||
"/api/update_tool",
|
||||
method="POST",
|
||||
json={"id": default_tool_id("scheduler"), "status": False},
|
||||
):
|
||||
from flask import request
|
||||
request.decoded_token = {"sub": user}
|
||||
response = UpdateTool().post()
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
user_doc = UsersRepository(pg_conn).get(user)
|
||||
assert "scheduler" in (
|
||||
user_doc["tool_preferences"]["disabled_default_tools"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -572,6 +572,7 @@ class TestSharedAgentResolvesOwnerBYOM:
|
||||
sp = StreamProcessor.__new__(StreamProcessor)
|
||||
sp.data = {}
|
||||
sp.initial_user_id = "caller"
|
||||
sp._agent_data = {"_id": "shared-agent", "user": "owner"}
|
||||
sp.agent_config = {
|
||||
"user_id": "owner",
|
||||
"default_model_id": owner_model["id"],
|
||||
|
||||
@@ -127,3 +127,96 @@ class TestUpdatePath:
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
assert repo.update_path("u", tool_id, "/nope", "/new") is False
|
||||
|
||||
|
||||
class TestDefaultToolMemories:
|
||||
"""Synthetic-id memory writes work; real-tool delete still cascades via trigger."""
|
||||
|
||||
def test_synthetic_tool_id_memory_write_succeeds(self, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
synthetic_id = default_tool_id("memory")
|
||||
doc = repo.upsert("u-syn-mem", synthetic_id, "/note.txt", "built-in")
|
||||
assert doc["content"] == "built-in"
|
||||
got = repo.get_by_path("u-syn-mem", synthetic_id, "/note.txt")
|
||||
assert got is not None and got["content"] == "built-in"
|
||||
|
||||
def test_built_in_and_explicit_memory_are_separate_stores(self, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
synthetic_id = default_tool_id("memory")
|
||||
explicit_id = _make_tool(pg_conn, user_id="u-two-mem", name="memory")
|
||||
repo.upsert("u-two-mem", synthetic_id, "/x.txt", "from built-in")
|
||||
repo.upsert("u-two-mem", explicit_id, "/x.txt", "from explicit")
|
||||
assert (
|
||||
repo.get_by_path("u-two-mem", synthetic_id, "/x.txt")["content"]
|
||||
== "from built-in"
|
||||
)
|
||||
assert (
|
||||
repo.get_by_path("u-two-mem", explicit_id, "/x.txt")["content"]
|
||||
== "from explicit"
|
||||
)
|
||||
|
||||
def test_deleting_real_tool_purges_its_memories(self, pg_conn):
|
||||
from sqlalchemy import text
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn, user_id="u-del-mem", name="memory")
|
||||
repo.upsert("u-del-mem", tool_id, "/keep.txt", "data")
|
||||
pg_conn.execute(
|
||||
text("DELETE FROM user_tools WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": tool_id},
|
||||
)
|
||||
assert repo.get_by_path("u-del-mem", tool_id, "/keep.txt") is None
|
||||
|
||||
|
||||
class TestDeleteOrphans:
|
||||
"""``delete_orphans`` sweeps the FK-to-trigger orphan window."""
|
||||
|
||||
def test_removes_orphan_with_no_user_tools_row(self, pg_conn):
|
||||
import uuid
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
orphan_tool_id = str(uuid.uuid4())
|
||||
repo.upsert("u-orphan", orphan_tool_id, "/x.txt", "stale")
|
||||
deleted = repo.delete_orphans()
|
||||
assert deleted == 1
|
||||
assert repo.get_by_path("u-orphan", orphan_tool_id, "/x.txt") is None
|
||||
|
||||
def test_keeps_memory_of_a_live_tool(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn, user_id="u-live", name="memory")
|
||||
repo.upsert("u-live", tool_id, "/keep.txt", "data")
|
||||
assert repo.delete_orphans() == 0
|
||||
assert repo.get_by_path("u-live", tool_id, "/keep.txt") is not None
|
||||
|
||||
def test_keeps_synthetic_default_tool_memory(self, pg_conn):
|
||||
from application.agents.default_tools import default_tool_id
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
synthetic_id = default_tool_id("memory")
|
||||
repo.upsert("u-syn", synthetic_id, "/note.txt", "built-in")
|
||||
deleted = repo.delete_orphans(keep_tool_ids=[synthetic_id])
|
||||
assert deleted == 0
|
||||
assert repo.get_by_path("u-syn", synthetic_id, "/note.txt") is not None
|
||||
|
||||
def test_sweeps_orphan_but_spares_synthetic_and_live(self, pg_conn):
|
||||
import uuid
|
||||
|
||||
from application.agents.default_tools import default_tool_id
|
||||
|
||||
repo = _repo(pg_conn)
|
||||
synthetic_id = default_tool_id("memory")
|
||||
live_id = _make_tool(pg_conn, user_id="u-mix", name="memory")
|
||||
orphan_id = str(uuid.uuid4())
|
||||
repo.upsert("u-mix", synthetic_id, "/syn.txt", "keep-syn")
|
||||
repo.upsert("u-mix", live_id, "/live.txt", "keep-live")
|
||||
repo.upsert("u-mix", orphan_id, "/orphan.txt", "drop")
|
||||
|
||||
deleted = repo.delete_orphans(keep_tool_ids=[synthetic_id])
|
||||
assert deleted == 1
|
||||
assert repo.get_by_path("u-mix", synthetic_id, "/syn.txt") is not None
|
||||
assert repo.get_by_path("u-mix", live_id, "/live.txt") is not None
|
||||
assert repo.get_by_path("u-mix", orphan_id, "/orphan.txt") is None
|
||||
|
||||
198
tests/storage/db/repositories/test_schedule_runs.py
Normal file
198
tests/storage/db/repositories/test_schedule_runs.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Tests for ScheduleRunsRepository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _make_schedule(conn, *, user_id: str = "u1") -> tuple[str, str]:
|
||||
agent_id = str(
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO agents (user_id, name, status) "
|
||||
"VALUES (:u, 'a', 'draft') RETURNING id"
|
||||
),
|
||||
{"u": user_id},
|
||||
).fetchone()[0]
|
||||
)
|
||||
schedule = SchedulesRepository(conn).create(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
trigger_type="recurring",
|
||||
instruction="i",
|
||||
cron="* * * * *",
|
||||
next_run_at=_now() + timedelta(minutes=5),
|
||||
)
|
||||
return str(schedule["id"]), agent_id
|
||||
|
||||
|
||||
class TestRecordPending:
|
||||
def test_first_insert_wins(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
scheduled_for = _now().replace(microsecond=0)
|
||||
first = repo.record_pending(
|
||||
schedule_id, "u1", agent_id, scheduled_for,
|
||||
)
|
||||
assert first is not None
|
||||
assert first["status"] == "pending"
|
||||
|
||||
def test_conflict_returns_none(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
scheduled_for = _now().replace(microsecond=0)
|
||||
first = repo.record_pending(
|
||||
schedule_id, "u1", agent_id, scheduled_for,
|
||||
)
|
||||
second = repo.record_pending(
|
||||
schedule_id, "u1", agent_id, scheduled_for,
|
||||
)
|
||||
assert first is not None
|
||||
assert second is None
|
||||
|
||||
def test_different_scheduled_for_both_succeed(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
first = repo.record_pending(
|
||||
schedule_id, "u1", agent_id, _now(),
|
||||
)
|
||||
second = repo.record_pending(
|
||||
schedule_id, "u1", agent_id, _now() + timedelta(seconds=1),
|
||||
)
|
||||
assert first is not None
|
||||
assert second is not None
|
||||
assert first["id"] != second["id"]
|
||||
|
||||
|
||||
class TestAgentlessRuns:
|
||||
"""Agentless schedules (NULL agent_id) write runs with NULL agent_id."""
|
||||
|
||||
def test_record_pending_with_null_agent_id(self, pg_conn):
|
||||
schedule = SchedulesRepository(pg_conn).create(
|
||||
user_id="u-agentless",
|
||||
agent_id=None,
|
||||
trigger_type="once",
|
||||
instruction="ping",
|
||||
run_at=_now() + timedelta(minutes=5),
|
||||
next_run_at=_now() + timedelta(minutes=5),
|
||||
)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
run = repo.record_pending(
|
||||
str(schedule["id"]), "u-agentless", None,
|
||||
_now().replace(microsecond=0),
|
||||
)
|
||||
assert run is not None
|
||||
assert run["agent_id"] is None
|
||||
assert run["user_id"] == "u-agentless"
|
||||
|
||||
def test_record_skipped_with_null_agent_id(self, pg_conn):
|
||||
schedule = SchedulesRepository(pg_conn).create(
|
||||
user_id="u-agentless",
|
||||
agent_id=None,
|
||||
trigger_type="once",
|
||||
instruction="ping",
|
||||
run_at=_now() + timedelta(minutes=5),
|
||||
next_run_at=_now() + timedelta(minutes=5),
|
||||
)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
row = repo.record_skipped(
|
||||
str(schedule["id"]), "u-agentless", None, _now(),
|
||||
error_type="missed", error="agentless miss",
|
||||
)
|
||||
assert row is not None
|
||||
assert row["agent_id"] is None
|
||||
assert row["status"] == "skipped"
|
||||
|
||||
|
||||
class TestSkippedAndActive:
|
||||
def test_record_skipped(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
row = repo.record_skipped(
|
||||
schedule_id, "u1", agent_id, _now(),
|
||||
error_type="missed", error="worker down",
|
||||
)
|
||||
assert row["status"] == "skipped"
|
||||
assert row["error_type"] == "missed"
|
||||
|
||||
def test_has_active_run(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
assert repo.has_active_run(schedule_id) is False
|
||||
run = repo.record_pending(schedule_id, "u1", agent_id, _now())
|
||||
assert repo.has_active_run(schedule_id) is True
|
||||
repo.update(run["id"], {"status": "success", "finished_at": _now()})
|
||||
assert repo.has_active_run(schedule_id) is False
|
||||
|
||||
|
||||
class TestUpdateAndList:
|
||||
def test_mark_running_only_from_pending(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
run = repo.record_pending(schedule_id, "u1", agent_id, _now())
|
||||
assert repo.mark_running(run["id"], "task-1") is True
|
||||
assert repo.mark_running(run["id"], "task-2") is False
|
||||
|
||||
def test_list_runs_owner_scoped(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
for i in range(3):
|
||||
repo.record_pending(
|
||||
schedule_id, "u1", agent_id,
|
||||
_now() + timedelta(seconds=i),
|
||||
)
|
||||
rows = repo.list_runs(schedule_id, "u1")
|
||||
assert len(rows) == 3
|
||||
assert repo.list_runs(schedule_id, "u2") == []
|
||||
|
||||
def test_list_stuck_running(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
run = repo.record_pending(schedule_id, "u1", agent_id, _now())
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"UPDATE schedule_runs "
|
||||
"SET status = 'running', started_at = now() - interval '30 minutes' "
|
||||
"WHERE id = CAST(:i AS uuid)"
|
||||
),
|
||||
{"i": run["id"]},
|
||||
)
|
||||
stuck = repo.list_stuck_running(age_minutes=15)
|
||||
assert any(r["id"] == run["id"] for r in stuck)
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_cleanup_older_than_keeps_recent(self, pg_conn):
|
||||
schedule_id, agent_id = _make_schedule(pg_conn)
|
||||
repo = ScheduleRunsRepository(pg_conn)
|
||||
ids = []
|
||||
for i in range(5):
|
||||
row = repo.record_pending(
|
||||
schedule_id, "u1", agent_id,
|
||||
_now() + timedelta(seconds=i),
|
||||
)
|
||||
ids.append(row["id"])
|
||||
pg_conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE schedule_runs
|
||||
SET created_at = now() - interval '120 days',
|
||||
scheduled_for = scheduled_for - interval '120 days'
|
||||
WHERE id = ANY(CAST(:ids AS uuid[]))
|
||||
"""
|
||||
),
|
||||
{"ids": "{" + ",".join(ids[:3]) + "}"},
|
||||
)
|
||||
deleted = repo.cleanup_older_than(90, keep_recent_per_schedule=2)
|
||||
assert deleted >= 1
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user