mirror of https://github.com/arc53/DocsGPT.git, synced 2026-05-07 06:30:03 +00:00
Compare commits
32 Commits
SHA1: 80a3d424d6, d8670b65a4, c7266bf507, 402fa18819, e103799f81, 66541e934b, a7ef7b4402, cf513371ec, 0ebd326fb0, 5e32defe12, ca0c943559, 681485c21f, f61095b32e, 928e8588ca, 318acf548a, c10d474156, 6325d5f044, 1bbd2b8b65, f7964dcc8d, ef7ff1613b, ff0f02c2f0, 4248e4fcc7, 0ec8208427, 13c21e5a7e, a1ceb1ea8a, 2aa1bc04b0, aa880d0b31, ea4bd608a8, b35572c0fc, fa0b358f42, 571f132f3e, e7f1e2a376
.github/workflows/vale.yml (vendored): 22 changed lines
@@ -11,7 +11,6 @@ on:
permissions:
  contents: read
  pull-requests: write

jobs:
  vale:

@@ -20,11 +19,16 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Vale linter
        uses: errata-ai/vale-action@v2
        with:
          files: docs
          fail_on_error: false
          version: 3.0.5
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Install Vale
        run: |
          curl -fsSL -o vale.tar.gz \
            https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
          tar -xzf vale.tar.gz
          sudo mv vale /usr/local/bin/vale
          vale --version

      - name: Sync Vale packages
        run: vale sync

      - name: Run Vale
        run: vale --minAlertLevel=error docs
.gitignore (vendored): 8 changed lines
@@ -186,3 +186,11 @@ node_modules/
.vscode/sftp.json
/models/
model/

# E2E test artifacts
.e2e-tmp/
/tmp/docsgpt-e2e/
tests/e2e/node_modules/
tests/e2e/playwright-report/
tests/e2e/test-results/
tests/e2e/.e2e-last-run.json
AGENTS.md: 10 changed lines
@@ -10,9 +10,15 @@
For feature work, do **not** assume the environment needs to be recreated.

- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether MongoDB is already running.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.

> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.

## Normal local development commands
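A minimal preflight check in the spirit of the `POSTGRES_URI` guidance above. This is an illustrative sketch, not code from the PR: the helper name and the use of SQLAlchemy here are assumptions; only the `POSTGRES_URI` variable name comes from AGENTS.md.

import os
from sqlalchemy import create_engine, text

def postgres_reachable() -> bool:
    """Return True if POSTGRES_URI points at a live Postgres instance."""
    uri = os.environ.get("POSTGRES_URI")  # canonical user-data store per AGENTS.md
    if not uri:
        return False
    try:
        engine = create_engine(uri, pool_pre_ping=True)
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))  # cheap liveness probe
        return True
    except Exception:
        return False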
@@ -3,13 +3,12 @@ import uuid
from collections import Counter
from typing import Dict, List, Optional, Tuple

from bson.objectid import ObjectId

from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly

logger = logging.getLogger(__name__)

@@ -51,30 +50,28 @@ class ToolExecutor:
        return tools

    def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        agents_collection = db["agents"]
        tools_collection = db["user_tools"]

        agent_data = agents_collection.find_one({"key": api_key})
        tool_ids = agent_data.get("tools", []) if agent_data else []

        tools = (
            tools_collection.find(
                {"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
            )
            if tool_ids
            else []
        )
        tools = list(tools)
        return {str(tool["_id"]): tool for tool in tools} if tools else {}
        # Per-operation session: the answer pipeline spans a long-lived
        # generator; wrapping it in a single connection would pin a PG
        # conn for the whole stream. Open, fetch, close.
        with db_readonly() as conn:
            agent_data = AgentsRepository(conn).find_by_key(api_key)
            tool_ids = agent_data.get("tools", []) if agent_data else []
            if not tool_ids:
                return {}
            tools_repo = UserToolsRepository(conn)
            tools: List[Dict] = []
            owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
            for tid in tool_ids:
                row = None
                if owner:
                    row = tools_repo.get_any(str(tid), owner)
                if row is not None:
                    tools.append(row)
        return {str(tool["id"]): tool for tool in tools} if tools else {}

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        user_tools_collection = db["user_tools"]
        user_tools = user_tools_collection.find({"user": user, "status": True})
        user_tools = list(user_tools)
        with db_readonly() as conn:
            user_tools = UserToolsRepository(conn).list_active_for_user(user)
        return {str(i): tool for i, tool in enumerate(user_tools)}

    def merge_client_tools(

@@ -354,6 +351,17 @@ class ToolExecutor:
                headers=headers, query_params=query_params,
            )

            if tool is None:
                error_message = (
                    f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
                    "missing 'id' on tool row."
                )
                logger.error(error_message)
                tool_call_data["result"] = error_message
                yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
                self.tool_calls.append(tool_call_data)
                return error_message, call_id

            resolved_arguments = (
                {"query_params": query_params, "headers": headers, "body": body}
                if tool_data["name"] == "api_tool"

@@ -440,7 +448,16 @@ class ToolExecutor:
            tool_config.update(decrypted)
            tool_config["auth_credentials"] = decrypted
            tool_config.pop("encrypted_credentials", None)
        tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
        row_id = tool_data.get("id")
        if not row_id:
            logger.error(
                "Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
                "skipping load to avoid binding a non-UUID downstream.",
                tool_data.get("name"),
                tool_id,
            )
            return None
        tool_config["tool_id"] = str(row_id)
        if self.conversation_id:
            tool_config["conversation_id"] = self.conversation_id
        if tool_data["name"] == "mcp_tool":
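The "per-operation session" comment above relies on `db_readonly()` behaving as a borrow-and-release context manager. A rough sketch of that contract follows; it is an assumption about the helper in application/storage/db/session.py (which may be implemented differently), and the connection URI is a placeholder.

from contextlib import contextmanager
from sqlalchemy import create_engine

# Placeholder URI for illustration only.
_engine = create_engine("postgresql+psycopg2://localhost/docsgpt", pool_pre_ping=True)

@contextmanager
def db_readonly():
    """Borrow a pooled connection for one short read, then release it."""
    conn = _engine.connect()
    try:
        yield conn
    finally:
        # Returned to the pool immediately, so the long-lived answer
        # generator never pins a Postgres connection between yields.
        conn.close()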
@@ -48,7 +48,7 @@ class InternalSearchTool(Tool):
        return self._retriever

    def _get_directory_structure(self) -> Optional[Dict]:
        """Load directory structure from MongoDB for the configured sources."""
        """Load directory structure from Postgres for the configured sources."""
        if self._dir_structure_loaded:
            return self._directory_structure

@@ -59,35 +59,39 @@
            return None

        try:
            from bson.objectid import ObjectId
            from application.core.mongo_db import MongoDB

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            sources_collection = db["sources"]
            # Per-operation session: this tool runs inside the answer
            # generator hot path, so we open a short-lived read
            # connection for the batch lookup and release immediately.
            from application.storage.db.repositories.sources import (
                SourcesRepository,
            )
            from application.storage.db.session import db_readonly

            if isinstance(active_docs, str):
                active_docs = [active_docs]

            decoded_token = self.config.get("decoded_token") or {}
            user_id = decoded_token.get("sub") if decoded_token else None

            merged_structure = {}
            for doc_id in active_docs:
                try:
                    source_doc = sources_collection.find_one(
                        {"_id": ObjectId(doc_id)}
                    )
                    if not source_doc:
                        continue
                    dir_str = source_doc.get("directory_structure")
                    if dir_str:
                        if isinstance(dir_str, str):
                            dir_str = json.loads(dir_str)
                        source_name = source_doc.get("name", doc_id)
                        if len(active_docs) > 1:
                            merged_structure[source_name] = dir_str
                        else:
                            merged_structure = dir_str
                except Exception as e:
                    logger.debug(f"Could not load dir structure for {doc_id}: {e}")
            with db_readonly() as conn:
                repo = SourcesRepository(conn)
                for doc_id in active_docs:
                    try:
                        source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
                        if not source_doc:
                            continue
                        dir_str = source_doc.get("directory_structure")
                        if dir_str:
                            if isinstance(dir_str, str):
                                dir_str = json.loads(dir_str)
                            source_name = source_doc.get("name", doc_id)
                            if len(active_docs) > 1:
                                merged_structure[source_name] = dir_str
                            else:
                                merged_structure = dir_str
                    except Exception as e:
                        logger.debug(f"Could not load dir structure for {doc_id}: {e}")

            self._directory_structure = merged_structure if merged_structure else None
        except Exception as e:

@@ -357,32 +361,48 @@ INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)


def sources_have_directory_structure(source: Dict) -> bool:
    """Check if any of the active sources have directory_structure in MongoDB."""
    """Check if any of the active sources have a ``directory_structure`` row."""
    active_docs = source.get("active_docs", [])
    if not active_docs:
        return False

    try:
        from bson.objectid import ObjectId
        from application.core.mongo_db import MongoDB
        # TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
        # scoping, but callers in the agent build path don't always
        # thread the decoded token through here. Use a direct
        # short-lived SQL lookup instead of the repo until the call
        # sites are updated to propagate user context.
        from sqlalchemy import text as _text

        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        sources_collection = db["sources"]
        from application.storage.db.session import db_readonly

        if isinstance(active_docs, str):
            active_docs = [active_docs]

        for doc_id in active_docs:
            try:
                source_doc = sources_collection.find_one(
                    {"_id": ObjectId(doc_id)},
                    {"directory_structure": 1},
                )
                if source_doc and source_doc.get("directory_structure"):
                    return True
            except Exception:
                continue
        with db_readonly() as conn:
            for doc_id in active_docs:
                try:
                    value = str(doc_id)
                    if len(value) == 36 and "-" in value:
                        row = conn.execute(
                            _text(
                                "SELECT directory_structure FROM sources "
                                "WHERE id = CAST(:id AS uuid)"
                            ),
                            {"id": value},
                        ).fetchone()
                    else:
                        row = conn.execute(
                            _text(
                                "SELECT directory_structure FROM sources "
                                "WHERE legacy_mongo_id = :lid"
                            ),
                            {"lid": value},
                        ).fetchone()
                    if row is not None and row[0]:
                        return True
                except Exception:
                    continue
    except Exception as e:
        logger.debug(f"Could not check directory structure: {e}")
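The branch above keys off "looks like a UUID" (36 characters with dashes) versus a legacy Mongo id, and the tool classes further down import a shared `looks_like_uuid` helper from application.storage.db.base_repository for the same decision. A plausible sketch of such a helper, an assumption about its implementation shown only to make the dispatch concrete:

import uuid

def looks_like_uuid(value) -> bool:
    """Return True if value parses as a UUID (so it can be CAST to uuid in SQL)."""
    try:
        uuid.UUID(str(value))
        return True
    except (ValueError, AttributeError, TypeError):
        return False

# looks_like_uuid("3f2b8a6e-9f2d-4c1a-8b1e-7d2f0a9c4e21") -> True
# looks_like_uuid("64ffb0f1a2b3c4d5e6f70809") -> False  (24-hex legacy ObjectId)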
@@ -22,16 +22,12 @@ from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials

logger = logging.getLogger(__name__)

mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]

_mcp_clients_cache = {}


@@ -161,7 +157,6 @@ class MCPTool(Tool):
                    scopes=self.oauth_scopes,
                    redis_client=redis_client,
                    redirect_uri=self.redirect_uri,
                    db=db,
                    user_id=self.user_id,
                )
            else:

@@ -171,7 +166,6 @@ class MCPTool(Tool):
                    redis_client=redis_client,
                    redirect_uri=self.redirect_uri,
                    task_id=self.oauth_task_id,
                    db=db,
                    user_id=self.user_id,
                )
        elif self.auth_type == "bearer":

@@ -491,7 +485,7 @@ class MCPTool(Tool):

    def _test_oauth_connection(self) -> Dict:
        storage = DBTokenStorage(
            server_url=self.server_url, user_id=self.user_id, db_client=db
            server_url=self.server_url, user_id=self.user_id,
        )
        loop = asyncio.new_event_loop()
        try:

@@ -683,7 +677,6 @@ class DocsGPTOAuth(OAuthClientProvider):
        scopes: str | list[str] | None = None,
        client_name: str = "DocsGPT-MCP",
        user_id=None,
        db=None,
        additional_client_metadata: dict[str, Any] | None = None,
        skip_redirect_validation: bool = False,
    ):

@@ -692,7 +685,6 @@ class DocsGPTOAuth(OAuthClientProvider):
        self.redis_prefix = redis_prefix
        self.task_id = task_id
        self.user_id = user_id
        self.db = db

        parsed_url = urlparse(mcp_url)
        self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

@@ -711,7 +703,6 @@ class DocsGPTOAuth(OAuthClientProvider):
        storage = DBTokenStorage(
            server_url=self.server_base_url,
            user_id=self.user_id,
            db_client=self.db,
            expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
        )

@@ -853,54 +844,95 @@ class DBTokenStorage(TokenStorage):
        self,
        server_url: str,
        user_id: str,
        db_client,
        expected_redirect_uri: Optional[str] = None,
    ):
        self.server_url = server_url
        self.user_id = user_id
        self.db_client = db_client
        self.expected_redirect_uri = expected_redirect_uri
        self.collection = db_client["connector_sessions"]

    @staticmethod
    def get_base_url(url: str) -> str:
        parsed = urlparse(url)
        return f"{parsed.scheme}://{parsed.netloc}"

    def get_db_key(self) -> dict:
        return {
            "server_url": self.get_base_url(self.server_url),
            "user_id": self.user_id,
        }
    def _pg_provider(self) -> str:
        return f"mcp:{self.get_base_url(self.server_url)}"

    def _fetch_session_data(self) -> dict:
        """Read the JSONB ``session_data`` blob for this MCP server row."""
        from application.storage.db.repositories.connector_sessions import (
            ConnectorSessionsRepository,
        )
        from application.storage.db.session import db_readonly

        base_url = self.get_base_url(self.server_url)
        with db_readonly() as conn:
            row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
                self.user_id, base_url,
            )
        if not row:
            return {}
        data = row.get("session_data") or {}
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except ValueError:
                return {}
        return data if isinstance(data, dict) else {}

    async def get_tokens(self) -> OAuthToken | None:
        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not doc or "tokens" not in doc:
        data = await asyncio.to_thread(self._fetch_session_data)
        if not data or "tokens" not in data:
            return None
        try:
            return OAuthToken.model_validate(doc["tokens"])
            return OAuthToken.model_validate(data["tokens"])
        except ValidationError as e:
            logger.error("Could not load tokens: %s", e)
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"tokens": tokens.model_dump()}},
            True,
    def _merge(self, patch: dict) -> None:
        """Shallow-merge ``patch`` into this row's ``session_data``.

        Threads ``server_url`` through to the repository so it lands in
        the scalar column — ``get_by_user_and_server_url`` needs that to
        resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
        """
        from application.storage.db.repositories.connector_sessions import (
            ConnectorSessionsRepository,
        )
        logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
        from application.storage.db.session import db_session

        base_url = self.get_base_url(self.server_url)
        with db_session() as conn:
            ConnectorSessionsRepository(conn).merge_session_data(
                self.user_id, self._pg_provider(), base_url, patch,
            )

    def _delete(self) -> None:
        from application.storage.db.repositories.connector_sessions import (
            ConnectorSessionsRepository,
        )
        from application.storage.db.session import db_session

        with db_session() as conn:
            ConnectorSessionsRepository(conn).delete(
                self.user_id, self._pg_provider(),
            )

    async def set_tokens(self, tokens: OAuthToken) -> None:
        base_url = self.get_base_url(self.server_url)
        token_dump = tokens.model_dump()
        await asyncio.to_thread(self._merge, {"tokens": token_dump})
        logger.info("Saved tokens for %s", base_url)

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not doc or "client_info" not in doc:
            logger.debug(
                "No client_info in DB for %s", self.get_base_url(self.server_url)
            )
        data = await asyncio.to_thread(self._fetch_session_data)
        base_url = self.get_base_url(self.server_url)
        if not data or "client_info" not in data:
            logger.debug("No client_info in DB for %s", base_url)
            return None
        try:
            client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
            client_info = OAuthClientInformationFull.model_validate(data["client_info"])
            if self.expected_redirect_uri:
                stored_uris = [
                    str(uri).rstrip("/") for uri in client_info.redirect_uris

@@ -909,14 +941,16 @@ class DBTokenStorage(TokenStorage):
                if expected_uri not in stored_uris:
                    logger.warning(
                        "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
                        self.get_base_url(self.server_url),
                        base_url,
                        expected_uri,
                        stored_uris,
                    )
                    # Drop ``tokens`` and ``client_info`` from the JSONB
                    # blob via merge_session_data's ``None``-drops-key
                    # semantics — preserves the row + any other keys.
                    await asyncio.to_thread(
                        self.collection.update_one,
                        self.get_db_key(),
                        {"$unset": {"client_info": "", "tokens": ""}},
                        self._merge,
                        {"tokens": None, "client_info": None},
                    )
                    return None
            return client_info

@@ -931,22 +965,37 @@ class DBTokenStorage(TokenStorage):

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        serialized_info = self._serialize_client_info(client_info.model_dump())
        base_url = self.get_base_url(self.server_url)
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"client_info": serialized_info}},
            True,
            self._merge, {"client_info": serialized_info},
        )
        logger.info("Saved client info for %s", self.get_base_url(self.server_url))
        logger.info("Saved client info for %s", base_url)

    async def clear(self) -> None:
        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
        await asyncio.to_thread(self._delete)
        logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))

    @classmethod
    async def clear_all(cls, db_client) -> None:
        collection = db_client["connector_sessions"]
        await asyncio.to_thread(collection.delete_many, {})
    async def clear_all(cls, db_client=None) -> None:
        """Delete every MCP-tagged connector session row.

        ``db_client`` retained for call-site compatibility but unused —
        storage is Postgres-only now.
        """
        from sqlalchemy import text

        from application.storage.db.session import db_session

        def _delete_all() -> None:
            with db_session() as conn:
                conn.execute(
                    text(
                        "DELETE FROM connector_sessions "
                        "WHERE provider LIKE 'mcp:%'"
                    )
                )

        await asyncio.to_thread(_delete_all)
        logger.info("Cleared all OAuth client cache data.")
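The comments above rely on two properties of ConnectorSessionsRepository.merge_session_data: the patch is shallow-merged into the JSONB session_data column, and a None value drops that key while leaving the rest of the row intact. A small sketch of that contract in plain Python; this is an assumption about the repository's behaviour, shown only to make the "None drops the key" semantics concrete.

def merge_session_data(existing: dict, patch: dict) -> dict:
    """Shallow-merge patch into existing session_data; None values remove keys."""
    merged = dict(existing or {})
    for key, value in patch.items():
        if value is None:
            merged.pop(key, None)   # e.g. {"tokens": None} clears stale tokens
        else:
            merged[key] = value     # shallow: nested dicts are replaced wholesale
    return merged

# merge_session_data({"tokens": {"a": 1}, "code_verifier": "x"}, {"tokens": None})
# -> {"code_verifier": "x"}   (row and unrelated keys preserved)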
@@ -1,12 +1,14 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import re
import logging
import uuid

from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.memories import MemoriesRepository
from application.storage.db.session import db_readonly, db_session


logger = logging.getLogger(__name__)


class MemoryTool(Tool):

@@ -27,7 +29,7 @@ class MemoryTool(Tool):
        self.user_id: Optional[str] = user_id

        # Get tool_id from configuration (passed from user_tools._id in production)
        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
        # In production, tool_id is the UUID string from user_tools.id.
        if tool_config and "tool_id" in tool_config:
            self.tool_id = tool_config["tool_id"]
        elif user_id:

@@ -37,8 +39,35 @@ class MemoryTool(Tool):
            # Last resort fallback (shouldn't happen in normal use)
            self.tool_id = str(uuid.uuid4())

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
        self.collection = db["memories"]
    def _pg_enabled(self) -> bool:
        """Return True if this MemoryTool's tool_id is a real ``user_tools.id``.

        The ``memories`` PG table has a UUID foreign key to ``user_tools``.
        The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
        has no row in ``user_tools``, so any storage operation would fail
        the foreign-key check. After the Postgres cutover Postgres is the
        only store, so for the sentinel case there is nowhere to read or
        write — operations become no-ops and the tool returns an
        explanatory error to the caller.
        """
        tool_id = getattr(self, "tool_id", None)
        if not tool_id or not isinstance(tool_id, str):
            return False
        if tool_id.startswith("default_"):
            logger.debug(
                "Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
                tool_id,
            )
            return False
        from application.storage.db.base_repository import looks_like_uuid

        if not looks_like_uuid(tool_id):
            logger.debug(
                "Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
                tool_id,
            )
            return False
        return True

    # -----------------------------
    # Action implementations

@@ -56,6 +85,12 @@ class MemoryTool(Tool):
        if not self.user_id:
            return "Error: MemoryTool requires a valid user_id."

        if not self._pg_enabled():
            return (
                "Error: MemoryTool is not configured with a persistent tool_id; "
                "memory storage is unavailable for this session."
            )

        if action_name == "view":
            return self._view(
                kwargs.get("path", "/"),

@@ -282,14 +317,10 @@ class MemoryTool(Tool):
        # Ensure path ends with / for proper prefix matching
        search_path = path if path.endswith("/") else path + "/"

        # Find all files that start with this directory path
        query = {
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": {"$regex": f"^{re.escape(search_path)}"}
        }

        docs = list(self.collection.find(query, {"path": 1}))
        with db_readonly() as conn:
            docs = MemoriesRepository(conn).list_by_prefix(
                self.user_id, self.tool_id, search_path
            )

        if not docs:
            return f"Directory: {path}\n(empty)"

@@ -310,7 +341,10 @@ class MemoryTool(Tool):

    def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
        """View file contents with optional line range."""
        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
        with db_readonly() as conn:
            doc = MemoriesRepository(conn).get_by_path(
                self.user_id, self.tool_id, path
            )

        if not doc or not doc.get("content"):
            return f"Error: File not found: {path}"

@@ -344,16 +378,10 @@ class MemoryTool(Tool):
        if validated_path == "/" or validated_path.endswith("/"):
            return "Error: Cannot create a file at directory path."

        self.collection.update_one(
            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
            {
                "$set": {
                    "content": file_text,
                    "updated_at": datetime.now()
                }
            },
            upsert=True
        )
        with db_session() as conn:
            MemoriesRepository(conn).upsert(
                self.user_id, self.tool_id, validated_path, file_text
            )

        return f"File created: {validated_path}"

@@ -366,30 +394,29 @@ class MemoryTool(Tool):
        if not old_str:
            return "Error: old_str is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
        with db_session() as conn:
            repo = MemoriesRepository(conn)
            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)

        if not doc or not doc.get("content"):
            return f"Error: File not found: {validated_path}"
            if not doc or not doc.get("content"):
                return f"Error: File not found: {validated_path}"

        current_content = str(doc["content"])
            current_content = str(doc["content"])

        # Check if old_str exists (case-insensitive)
        if old_str.lower() not in current_content.lower():
            return f"Error: String '{old_str}' not found in file."
            # Check if old_str exists (case-insensitive)
            if old_str.lower() not in current_content.lower():
                return f"Error: String '{old_str}' not found in file."

        # Replace the string (case-insensitive)
        import re as regex_module
        updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
            # Case-insensitive replace
            import re as regex_module
            updated_content = regex_module.sub(
                regex_module.escape(old_str),
                new_str,
                current_content,
                flags=regex_module.IGNORECASE,
            )

        self.collection.update_one(
            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
            {
                "$set": {
                    "content": updated_content,
                    "updated_at": datetime.now()
                }
            }
        )
            repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)

        return f"File updated: {validated_path}"

@@ -402,31 +429,25 @@ class MemoryTool(Tool):
        if not insert_text:
            return "Error: insert_text is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
        with db_session() as conn:
            repo = MemoriesRepository(conn)
            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)

        if not doc or not doc.get("content"):
            return f"Error: File not found: {validated_path}"
            if not doc or not doc.get("content"):
                return f"Error: File not found: {validated_path}"

        current_content = str(doc["content"])
        lines = current_content.split("\n")
            current_content = str(doc["content"])
            lines = current_content.split("\n")

        # Convert to 0-indexed
        index = insert_line - 1
        if index < 0 or index > len(lines):
            return f"Error: Invalid line number. File has {len(lines)} lines."
            # Convert to 0-indexed
            index = insert_line - 1
            if index < 0 or index > len(lines):
                return f"Error: Invalid line number. File has {len(lines)} lines."

        lines.insert(index, insert_text)
        updated_content = "\n".join(lines)
            lines.insert(index, insert_text)
            updated_content = "\n".join(lines)

        self.collection.update_one(
            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
            {
                "$set": {
                    "content": updated_content,
                    "updated_at": datetime.now()
                }
            }
        )
            repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)

        return f"Text inserted at line {insert_line} in {validated_path}"

@@ -438,39 +459,36 @@ class MemoryTool(Tool):

        if validated_path == "/":
            # Delete all files for this user and tool
            result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
            return f"Deleted {result.deleted_count} file(s) from memory."
            with db_session() as conn:
                deleted = MemoriesRepository(conn).delete_all(
                    self.user_id, self.tool_id
                )
            return f"Deleted {deleted} file(s) from memory."

        # Check if it's a directory (ends with /)
        if validated_path.endswith("/"):
            # Delete all files in directory
            result = self.collection.delete_many({
                "user_id": self.user_id,
                "tool_id": self.tool_id,
                "path": {"$regex": f"^{re.escape(validated_path)}"}
            })
            return f"Deleted directory and {result.deleted_count} file(s)."
            with db_session() as conn:
                deleted = MemoriesRepository(conn).delete_by_prefix(
                    self.user_id, self.tool_id, validated_path
                )
            return f"Deleted directory and {deleted} file(s)."

        # Try to delete as directory first (without trailing slash)
        # Check if any files start with this path + /
        # Try as directory first (without trailing slash)
        search_path = validated_path + "/"
        directory_result = self.collection.delete_many({
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": {"$regex": f"^{re.escape(search_path)}"}
        })
        with db_session() as conn:
            repo = MemoriesRepository(conn)
            directory_deleted = repo.delete_by_prefix(
                self.user_id, self.tool_id, search_path
            )
            if directory_deleted > 0:
                return f"Deleted directory and {directory_deleted} file(s)."

        if directory_result.deleted_count > 0:
            return f"Deleted directory and {directory_result.deleted_count} file(s)."
            # Otherwise delete a single file
            file_deleted = repo.delete_by_path(
                self.user_id, self.tool_id, validated_path
            )

        # Delete single file
        result = self.collection.delete_one({
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": validated_path
        })

        if result.deleted_count:
        if file_deleted:
            return f"Deleted: {validated_path}"
        return f"Error: File not found: {validated_path}"

@@ -485,62 +503,46 @@ class MemoryTool(Tool):
        if validated_old == "/" or validated_new == "/":
            return "Error: Cannot rename root directory."

        # Check if renaming a directory
        # Directory rename: do all path updates inside one transaction so
        # the rename is atomic from the caller's perspective.
        if validated_old.endswith("/"):
            # Ensure validated_new also ends with / for proper path replacement
            if not validated_new.endswith("/"):
                validated_new = validated_new + "/"

            # Find all files in the old directory
            docs = list(self.collection.find({
                "user_id": self.user_id,
                "tool_id": self.tool_id,
                "path": {"$regex": f"^{re.escape(validated_old)}"}
            }))

            if not docs:
                return f"Error: Directory not found: {validated_old}"

            # Update paths for all files
            for doc in docs:
                old_file_path = doc["path"]
                new_file_path = old_file_path.replace(validated_old, validated_new, 1)

                self.collection.update_one(
                    {"_id": doc["_id"]},
                    {"$set": {"path": new_file_path, "updated_at": datetime.now()}}
            with db_session() as conn:
                repo = MemoriesRepository(conn)
                docs = repo.list_by_prefix(
                    self.user_id, self.tool_id, validated_old
                )

                if not docs:
                    return f"Error: Directory not found: {validated_old}"

                for doc in docs:
                    old_file_path = doc["path"]
                    new_file_path = old_file_path.replace(
                        validated_old, validated_new, 1
                    )
                    repo.update_path(
                        self.user_id, self.tool_id, old_file_path, new_file_path
                    )

            return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"

        # Rename single file
        doc = self.collection.find_one({
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": validated_old
        })
        # Single-file rename: lookup, collision check, and update in one txn.
        with db_session() as conn:
            repo = MemoriesRepository(conn)
            doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
            if not doc:
                return f"Error: File not found: {validated_old}"

        if not doc:
            return f"Error: File not found: {validated_old}"
            existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
            if existing:
                return f"Error: File already exists at {validated_new}"

        # Check if new path already exists
        existing = self.collection.find_one({
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": validated_new
        })

        if existing:
            return f"Error: File already exists at {validated_new}"

        # Delete the old document and create a new one with the new path
        self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
        self.collection.insert_one({
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "path": validated_new,
            "content": doc.get("content", ""),
            "updated_at": datetime.now()
        })
            repo.update_path(
                self.user_id, self.tool_id, validated_old, validated_new
            )

        return f"Renamed: {validated_old} -> {validated_new}"
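The directory-rename comment above ("do all path updates inside one transaction so the rename is atomic") assumes `db_session()` commits on success and rolls back on error. A compact sketch of that write-side counterpart to `db_readonly()` follows; it is an assumption about the helper in application/storage/db/session.py, with a placeholder engine URI.

from contextlib import contextmanager
from sqlalchemy import create_engine

_engine = create_engine("postgresql+psycopg2://localhost/docsgpt")  # placeholder URI

@contextmanager
def db_session():
    """One transaction per operation: commit if the block succeeds, roll back otherwise."""
    conn = _engine.connect()
    trans = conn.begin()
    try:
        yield conn
        trans.commit()     # all path updates land together
    except Exception:
        trans.rollback()   # a half-renamed directory never becomes visible
        raise
    finally:
        conn.close()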
@@ -1,10 +1,16 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid

from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.session import db_readonly, db_session


# Stable synthetic title used in the Postgres ``notes.title`` column.
# The notes tool stores one note per (user_id, tool_id); there is no
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
# constant alongside the actual note body in ``content``.
_NOTE_TITLE = "note"


class NotesTool(Tool):

@@ -25,7 +31,6 @@ class NotesTool(Tool):
        self.user_id: Optional[str] = user_id

        # Get tool_id from configuration (passed from user_tools._id in production)
        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
        if tool_config and "tool_id" in tool_config:
            self.tool_id = tool_config["tool_id"]
        elif user_id:

@@ -35,11 +40,25 @@ class NotesTool(Tool):
            # Last resort fallback (shouldn't happen in normal use)
            self.tool_id = str(uuid.uuid4())

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
        self.collection = db["notes"]

        self._last_artifact_id: Optional[str] = None

    def _pg_enabled(self) -> bool:
        """Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.

        ``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
        ``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
        fallback is neither a UUID nor a ``user_tools`` row, so any DB
        operation would crash. Mirror MemoryTool's guard and no-op.
        """
        tool_id = getattr(self, "tool_id", None)
        if not tool_id or not isinstance(tool_id, str):
            return False
        if tool_id.startswith("default_"):
            return False
        from application.storage.db.base_repository import looks_like_uuid

        return looks_like_uuid(tool_id)

    # -----------------------------
    # Action implementations
    # -----------------------------

@@ -54,7 +73,13 class NotesTool(Tool):
            A human-readable string result.
        """
        if not self.user_id:
            return "Error: NotesTool requires a valid user_id."
            return "Error: NotesTool requires a valid user_id."

        if not self._pg_enabled():
            return (
                "Error: NotesTool is not configured with a persistent "
                "tool_id; note storage is unavailable for this session."
            )

        self._last_artifact_id = None

@@ -135,37 +160,45 @@ class NotesTool(Tool):
    # -----------------------------
    # Internal helpers (single-note)
    # -----------------------------
    def _fetch_note(self) -> Optional[dict]:
        """Read the note row for this (user, tool) from Postgres."""
        with db_readonly() as conn:
            return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)

    def _get_note(self) -> str:
        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
        doc = self._fetch_note()
        # ``content`` is the PG column; expose as ``note`` to callers via the
        # textual return value. Frontends that read the artifact via the
        # repo dict get ``content`` (PG-native) plus the artifact id below.
        body = (doc or {}).get("content")
        if not doc or not body:
            return "No note found."
        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        return str(doc["note"])
        if doc.get("id") is not None:
            self._last_artifact_id = str(doc.get("id"))
        return str(body)

    def _overwrite_note(self, content: str) -> str:
        content = (content or "").strip()
        if not content:
            return "Note content required."
        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": content, "updated_at": datetime.utcnow()}},
            upsert=True,
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        with db_session() as conn:
            row = NotesRepository(conn).upsert(
                self.user_id, self.tool_id, _NOTE_TITLE, content
            )
        if row and row.get("id") is not None:
            self._last_artifact_id = str(row.get("id"))
        return "Note saved."

    def _str_replace(self, old_str: str, new_str: str) -> str:
        if not old_str:
            return "old_str is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
        doc = self._fetch_note()
        existing = (doc or {}).get("content")
        if not doc or not existing:
            return "No note found."

        current_note = str(doc["note"])
        current_note = str(existing)

        # Case-insensitive search
        if old_str.lower() not in current_note.lower():

@@ -175,24 +208,24 @@ class NotesTool(Tool):
        import re
        updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)

        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        with db_session() as conn:
            row = NotesRepository(conn).upsert(
                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
            )
        if row and row.get("id") is not None:
            self._last_artifact_id = str(row.get("id"))
        return "Note updated."

    def _insert(self, line_number: int, text: str) -> str:
        if not text:
            return "Text is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
        doc = self._fetch_note()
        existing = (doc or {}).get("content")
        if not doc or not existing:
            return "No note found."

        current_note = str(doc["note"])
        current_note = str(existing)
        lines = current_note.split("\n")

        # Convert to 0-indexed and validate

@@ -203,21 +236,23 @@ class NotesTool(Tool):
        lines.insert(index, text)
        updated_note = "\n".join(lines)

        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        with db_session() as conn:
            row = NotesRepository(conn).upsert(
                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
            )
        if row and row.get("id") is not None:
            self._last_artifact_id = str(row.get("id"))
        return "Text inserted."

    def _delete_note(self) -> str:
        doc = self.collection.find_one_and_delete(
            {"user_id": self.user_id, "tool_id": self.tool_id}
        )
        if not doc:
        # Capture the id (for artifact tracking) before deleting.
        existing = self._fetch_note()
        if not existing:
            return "No note found to delete."
        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        with db_session() as conn:
            deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
        if not deleted:
            return "No note found to delete."
        if existing.get("id") is not None:
            self._last_artifact_id = str(existing.get("id"))
        return "Note deleted."
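The single-note contract described above (one row per (user_id, tool_id), NOT NULL title, caller needs the row id back for artifact tracking) suggests an upsert of roughly this shape. This is an illustrative sketch only; the table layout, the ON CONFLICT target, and the function name are assumptions, not the PR's actual NotesRepository, and `conn` is assumed to be the connection yielded by `db_session()`.

from sqlalchemy import text

def upsert_note(conn, user_id: str, tool_id: str, title: str, content: str) -> dict:
    """Insert or update the single note for (user_id, tool_id); return the row with its id."""
    row = conn.execute(
        text(
            "INSERT INTO notes (user_id, tool_id, title, content, updated_at) "
            "VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content, NOW()) "
            "ON CONFLICT (user_id, tool_id) "  # assumed unique constraint on the pair
            "DO UPDATE SET title = EXCLUDED.title, content = EXCLUDED.content, updated_at = NOW() "
            "RETURNING id, title, content"
        ),
        {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content},
    ).mappings().fetchone()
    return dict(row)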
@@ -1,10 +1,19 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid

from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.session import db_readonly, db_session


def _status_from_completed(completed: Any) -> str:
    """Translate the PG ``completed`` boolean to the legacy status string.

    The frontend (and prior LLM-facing tool output) expects
    ``"open"`` / ``"completed"``. Keeping that contract at the tool
    boundary insulates callers from the schema change.
    """
    return "completed" if bool(completed) else "open"


class TodoListTool(Tool):

@@ -25,7 +34,6 @@ class TodoListTool(Tool):
        self.user_id: Optional[str] = user_id

        # Get tool_id from configuration (passed from user_tools._id in production)
        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
        if tool_config and "tool_id" in tool_config:
            self.tool_id = tool_config["tool_id"]
        elif user_id:

@@ -35,11 +43,27 @@ class TodoListTool(Tool):
            # Last resort fallback (shouldn't happen in normal use)
            self.tool_id = str(uuid.uuid4())

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
        self.collection = db["todos"]

        self._last_artifact_id: Optional[str] = None

    def _pg_enabled(self) -> bool:
        """Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.

        The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
        the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
        ``default_{uid}`` fallback is neither a UUID nor a row in
        ``user_tools`` — binding it would crash ``invalid input syntax for
        type uuid`` and even if it didn't the FK would reject it. Mirror
        the MemoryTool guard and no-op in that case.
        """
        tool_id = getattr(self, "tool_id", None)
        if not tool_id or not isinstance(tool_id, str):
            return False
        if tool_id.startswith("default_"):
            return False
        from application.storage.db.base_repository import looks_like_uuid

        return looks_like_uuid(tool_id)

    # -----------------------------
    # Action implementations
    # -----------------------------

@@ -56,6 +80,12 @@ class TodoListTool(Tool):
        if not self.user_id:
            return "Error: TodoListTool requires a valid user_id."

        if not self._pg_enabled():
            return (
                "Error: TodoListTool is not configured with a persistent "
                "tool_id; todo storage is unavailable for this session."
            )

        self._last_artifact_id = None

        if action_name == "list":

@@ -191,28 +221,10 @@ class TodoListTool(Tool):

        return None

    def _get_next_todo_id(self) -> int:
        """Get the next sequential todo_id for this user and tool.

        Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
        With 5-10 todos max, scanning is negligible.
        """
        query = {"user_id": self.user_id, "tool_id": self.tool_id}
        todos = list(self.collection.find(query, {"todo_id": 1}))

        # Find the maximum todo_id
        max_id = 0
        for todo in todos:
            todo_id = self._coerce_todo_id(todo.get("todo_id"))
            if todo_id is not None:
                max_id = max(max_id, todo_id)

        return max_id + 1

    def _list(self) -> str:
        """List all todos for the user."""
        query = {"user_id": self.user_id, "tool_id": self.tool_id}
        todos = list(self.collection.find(query))
        with db_readonly() as conn:
            todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)

        if not todos:
            return "No todos found."

@@ -221,7 +233,7 @@ class TodoListTool(Tool):
        for doc in todos:
            todo_id = doc.get("todo_id")
            title = doc.get("title", "Untitled")
            status = doc.get("status", "open")
            status = _status_from_completed(doc.get("completed"))

            line = f"[{todo_id}] {title} ({status})"
            result_lines.append(line)

@@ -229,27 +241,23 @@ class TodoListTool(Tool):
        return "\n".join(result_lines)

    def _create(self, title: str) -> str:
        """Create a new todo item."""
        """Create a new todo item.

        ``TodosRepository.create`` allocates the per-tool monotonic
        ``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
        scoped to ``tool_id``), so we no longer need a separate read-then-
        write step here.
        """
        title = (title or "").strip()
        if not title:
            return "Error: Title is required."

        now = datetime.now()
        todo_id = self._get_next_todo_id()
        with db_session() as conn:
            row = TodosRepository(conn).create(self.user_id, self.tool_id, title)

        doc = {
            "todo_id": todo_id,
            "user_id": self.user_id,
            "tool_id": self.tool_id,
            "title": title,
            "status": "open",
            "created_at": now,
            "updated_at": now,
        }
        insert_result = self.collection.insert_one(doc)
        inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
        if inserted_id is not None:
            self._last_artifact_id = str(inserted_id)
        todo_id = row.get("todo_id")
        if row.get("id") is not None:
            self._last_artifact_id = str(row.get("id"))
        return f"Todo created with ID {todo_id}: {title}"

    def _get(self, todo_id: Optional[Any]) -> str:

@@ -258,21 +266,21 @@ class TodoListTool(Tool):
        if parsed_todo_id is None:
            return "Error: todo_id must be a positive integer."

        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
        doc = self.collection.find_one(query)
        with db_readonly() as conn:
            doc = TodosRepository(conn).get_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id
            )

        if not doc:
            return f"Error: Todo with ID {parsed_todo_id} not found."

        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        if doc.get("id") is not None:
            self._last_artifact_id = str(doc.get("id"))

        title = doc.get("title", "Untitled")
        status = doc.get("status", "open")
        status = _status_from_completed(doc.get("completed"))

        result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"

        return result
        return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"

    def _update(self, todo_id: Optional[Any], title: str) -> str:
        """Update a todo's title by ID."""

@@ -284,16 +292,19 @@ class TodoListTool(Tool):
        if not title:
            return "Error: Title is required."

        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
        doc = self.collection.find_one_and_update(
            query,
            {"$set": {"title": title, "updated_at": datetime.now()}},
        )
        if not doc:
            return f"Error: Todo with ID {parsed_todo_id} not found."
        with db_session() as conn:
            repo = TodosRepository(conn)
            existing = repo.get_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id
            )
            if not existing:
                return f"Error: Todo with ID {parsed_todo_id} not found."
            repo.update_title_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id, title
            )

        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        if existing.get("id") is not None:
            self._last_artifact_id = str(existing.get("id"))

        return f"Todo {parsed_todo_id} updated to: {title}"

@@ -303,16 +314,17 @@ class TodoListTool(Tool):
        if parsed_todo_id is None:
            return "Error: todo_id must be a positive integer."

        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
        doc = self.collection.find_one_and_update(
            query,
            {"$set": {"status": "completed", "updated_at": datetime.now()}},
        )
        if not doc:
            return f"Error: Todo with ID {parsed_todo_id} not found."
        with db_session() as conn:
            repo = TodosRepository(conn)
            existing = repo.get_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id
            )
            if not existing:
                return f"Error: Todo with ID {parsed_todo_id} not found."
            repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)

        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        if existing.get("id") is not None:
            self._last_artifact_id = str(existing.get("id"))

        return f"Todo {parsed_todo_id} marked as completed."

@@ -322,12 +334,18 @@ class TodoListTool(Tool):
        if parsed_todo_id is None:
            return "Error: todo_id must be a positive integer."

        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
        doc = self.collection.find_one_and_delete(query)
        if not doc:
            return f"Error: Todo with ID {parsed_todo_id} not found."
        with db_session() as conn:
            repo = TodosRepository(conn)
            existing = repo.get_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id
            )
            if not existing:
                return f"Error: Todo with ID {parsed_todo_id} not found."
            repo.delete_by_tool_and_todo_id(
                self.user_id, self.tool_id, parsed_todo_id
            )

        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        if existing.get("id") is not None:
            self._last_artifact_id = str(existing.get("id"))

        return f"Todo {parsed_todo_id} deleted."
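The `_create` docstring above pins down the repository contract: the row insert and the `COALESCE(MAX(todo_id), 0) + 1` allocation happen in one transaction, scoped to the tool. A sketch of what such a create could look like; the table and column names and the function name are assumptions, not the PR's actual TodosRepository, and `conn` is assumed to be the connection yielded by `db_session()`.

from sqlalchemy import text

def create_todo(conn, user_id: str, tool_id: str, title: str) -> dict:
    """Insert a todo, allocating the next per-tool todo_id in the same statement."""
    # Aggregate over an empty table still yields one row, so the first todo gets id 1.
    # A real implementation would also need a uniqueness constraint or row lock to
    # guard against concurrent creates; omitted here for brevity.
    row = conn.execute(
        text(
            "INSERT INTO todos (user_id, tool_id, todo_id, title, completed) "
            "SELECT :user_id, CAST(:tool_id AS uuid), "
            "       COALESCE(MAX(t.todo_id), 0) + 1, :title, FALSE "
            "FROM todos t WHERE t.tool_id = CAST(:tool_id AS uuid) "
            "RETURNING id, todo_id, title, completed"
        ),
        {"user_id": user_id, "tool_id": tool_id, "title": title},
    ).mappings().fetchone()
    return dict(row)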
@@ -12,12 +12,13 @@ from application.agents.workflows.schemas import (
    WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
from application.storage.db.dual_write import dual_write
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session

logger = logging.getLogger(__name__)

@@ -106,10 +107,8 @@ class WorkflowAgent(BaseAgent):

    def _load_from_database(self) -> Optional[WorkflowGraph]:
        try:
            from bson.objectid import ObjectId

            if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
                logger.error(f"Invalid workflow ID: {self.workflow_id}")
            if not self.workflow_id:
                logger.error("Missing workflow ID for load")
                return None
            owner_id = self.workflow_owner
            if not owner_id and isinstance(self.decoded_token, dict):

@@ -120,61 +119,61 @@
                )
                return None

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]

            workflows_coll = db["workflows"]
            workflow_nodes_coll = db["workflow_nodes"]
            workflow_edges_coll = db["workflow_edges"]

            workflow_doc = workflows_coll.find_one(
                {"_id": ObjectId(self.workflow_id), "user": owner_id}
            )
            if not workflow_doc:
                logger.error(
                    f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
                )
                return None
            workflow = Workflow(**workflow_doc)
            graph_version = workflow_doc.get("current_graph_version", 1)
            try:
                graph_version = int(graph_version)
                if graph_version <= 0:
            with db_readonly() as conn:
                wf_repo = WorkflowsRepository(conn)
                if looks_like_uuid(self.workflow_id):
                    workflow_row = wf_repo.get(self.workflow_id, owner_id)
                else:
                    workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
                if workflow_row is None:
                    logger.error(
                        f"Workflow {self.workflow_id} not found or inaccessible "
                        f"for user {owner_id}"
                    )
                    return None
                pg_workflow_id = str(workflow_row["id"])
                graph_version = workflow_row.get("current_graph_version", 1)
                try:
                    graph_version = int(graph_version)
                    if graph_version <= 0:
                        graph_version = 1
                except (ValueError, TypeError):
                    graph_version = 1
                except (ValueError, TypeError):
                    graph_version = 1

                nodes_docs = list(
                    workflow_nodes_coll.find(
                        {"workflow_id": self.workflow_id, "graph_version": graph_version}
                node_rows = WorkflowNodesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
                )
                if not nodes_docs and graph_version == 1:
                    nodes_docs = list(
                        workflow_nodes_coll.find(
                            {
                                "workflow_id": self.workflow_id,
                                "graph_version": {"$exists": False},
                            }
                        )
                edge_rows = WorkflowEdgesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
                nodes = [WorkflowNode(**doc) for doc in nodes_docs]

                edges_docs = list(
                    workflow_edges_coll.find(
                        {"workflow_id": self.workflow_id, "graph_version": graph_version}
                    )
                workflow = Workflow(
                    name=workflow_row.get("name"),
                    description=workflow_row.get("description"),
                )
                if not edges_docs and graph_version == 1:
                    edges_docs = list(
                        workflow_edges_coll.find(
                            {
                                "workflow_id": self.workflow_id,
                                "graph_version": {"$exists": False},
                            }
                        )
                nodes = [
                    WorkflowNode(
                        id=n["node_id"],
                        workflow_id=pg_workflow_id,
                        type=n["node_type"],
                        title=n.get("title") or "Node",
                        description=n.get("description"),
                        position=n.get("position") or {"x": 0, "y": 0},
                        config=n.get("config") or {},
                    )
                edges = [WorkflowEdge(**doc) for doc in edges_docs]
                    for n in node_rows
                ]
                edges = [
                    WorkflowEdge(
                        id=e["edge_id"],
                        workflow_id=pg_workflow_id,
                        source=e.get("source_id"),
                        target=e.get("target_id"),
                        sourceHandle=e.get("source_handle"),
                        targetHandle=e.get("target_handle"),
                    )
                    for e in edge_rows
                ]

            return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
        except Exception as e:

@@ -188,10 +187,6 @@ class WorkflowAgent(BaseAgent):
        if not owner_id and isinstance(self.decoded_token, dict):
            owner_id = self.decoded_token.get("sub")
        try:
            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            workflow_runs_coll = db["workflow_runs"]

            run = WorkflowRun(
                workflow_id=self.workflow_id or "unknown",
|
||||
user=owner_id,
|
||||
@@ -203,23 +198,20 @@ class WorkflowAgent(BaseAgent):
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
legacy_mongo_id = (
|
||||
str(result.inserted_id)
|
||||
if getattr(result, "inserted_id", None) is not None
|
||||
else None
|
||||
)
|
||||
|
||||
def _pg_write(repo: WorkflowRunsRepository) -> None:
|
||||
if not self.workflow_id or not owner_id or not legacy_mongo_id:
|
||||
if not self.workflow_id or not owner_id:
|
||||
return
|
||||
with db_session() as conn:
|
||||
wf_repo = WorkflowsRepository(conn)
|
||||
if looks_like_uuid(self.workflow_id):
|
||||
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||
else:
|
||||
workflow_row = wf_repo.get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow_row is None:
|
||||
return
|
||||
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow is None:
|
||||
return
|
||||
repo.create(
|
||||
workflow["id"],
|
||||
WorkflowRunsRepository(conn).create(
|
||||
str(workflow_row["id"]),
|
||||
owner_id,
|
||||
run.status.value,
|
||||
inputs=run.inputs,
|
||||
@@ -227,10 +219,7 @@ class WorkflowAgent(BaseAgent):
|
||||
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||
started_at=run.created_at,
|
||||
ended_at=run.completed_at,
|
||||
legacy_mongo_id=legacy_mongo_id,
|
||||
)
|
||||
|
||||
dual_write(WorkflowRunsRepository, _pg_write)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
@@ -81,24 +80,7 @@ class WorkflowEdgeCreate(BaseModel):
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"source_handle": self.source_handle,
|
||||
"target_handle": self.target_handle,
|
||||
}
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
|
||||
@@ -120,25 +102,7 @@ class WorkflowNodeCreate(BaseModel):
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"type": self.type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"position": self.position.model_dump(),
|
||||
"config": self.config,
|
||||
}
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
|
||||
@@ -149,26 +113,10 @@ class WorkflowCreate(BaseModel):
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
id: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"user": self.user,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
|
||||
workflow: Workflow
|
||||
@@ -209,7 +157,7 @@ class WorkflowRunCreate(BaseModel):
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
id: Optional[str] = None
|
||||
workflow_id: str
|
||||
user: Optional[str] = None
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
@@ -218,25 +166,3 @@ class WorkflowRun(BaseModel):
|
||||
steps: List[NodeExecutionLog] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
doc = {
|
||||
"workflow_id": self.workflow_id,
|
||||
"status": self.status.value,
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"steps": [step.model_dump() for step in self.steps],
|
||||
"created_at": self.created_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
if self.user:
|
||||
doc["user"] = self.user
|
||||
doc["user_id"] = self.user
|
||||
return doc
|
||||
|
||||
@@ -200,6 +200,9 @@ class WorkflowEngine:

node_config = AgentNodeConfig(**node.config.get("config", node.config))

if node_config.sources:
self._retrieve_node_sources(node_config)

if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
@@ -455,6 +458,29 @@ class WorkflowEngine:
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together

def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
"""Retrieve documents from the node's sources for template resolution."""
from application.retriever.retriever_creator import RetrieverCreator

query = self.state.get("query", "")
if not query:
return

try:
retriever = RetrieverCreator.create_retriever(
node_config.retriever or "classic",
source={"active_docs": node_config.sources},
chat_history=[],
prompt="",
chunks=int(node_config.chunks) if node_config.chunks else 2,
decoded_token=self.agent.decoded_token,
)
docs = retriever.search(query)
if docs:
self.agent.retrieved_docs = docs
except Exception:
logger.exception("Failed to retrieve docs for workflow node")

def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(

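For orientation, a minimal sketch (not part of the diff) of the node-config shape that the new `_retrieve_node_sources` path reads; the field names (`sources`, `retriever`, `chunks`, `prompt_template`) come from the hunk above, while the concrete values are placeholders:

# Illustrative only: an agent-node config whose "sources" entry makes the
# engine call _retrieve_node_sources() before formatting the prompt template.
node_config = {
    "config": {
        "sources": ["7b6f3c1e-0000-4000-8000-000000000001"],  # placeholder source UUID
        "retriever": "classic",  # falls back to "classic" when unset
        "chunks": 2,             # cast to int; defaults to 2 when falsy
        "prompt_template": "Summarize the retrieved context for the user.",
    }
}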
@@ -167,14 +167,19 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_tools (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
description TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
status BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -188,7 +193,8 @@ def upgrade() -> None:
|
||||
agent_id UUID,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -204,7 +210,8 @@ def upgrade() -> None:
|
||||
user_id TEXT,
|
||||
endpoint TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
data JSONB
|
||||
data JSONB,
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -220,7 +227,8 @@ def upgrade() -> None:
|
||||
api_key TEXT,
|
||||
query TEXT,
|
||||
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -228,12 +236,14 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agent_folders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -241,13 +251,24 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
language TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
model TEXT,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
retriever TEXT,
|
||||
sync_frequency TEXT,
|
||||
tokens TEXT,
|
||||
file_path TEXT,
|
||||
remote_data JSONB,
|
||||
directory_structure JSONB,
|
||||
file_name_map JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -255,33 +276,38 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ,
|
||||
legacy_mongo_id TEXT
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
image TEXT,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
workflow_id UUID,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token CITEXT UNIQUE,
|
||||
shared_metadata JSONB,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -299,6 +325,11 @@ def upgrade() -> None:
|
||||
upload_path TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
size BIGINT,
|
||||
content TEXT,
|
||||
token_count INTEGER,
|
||||
openai_file_id TEXT,
|
||||
google_file_uri TEXT,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
@@ -313,6 +344,7 @@ def upgrade() -> None:
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
@@ -321,13 +353,16 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE todos (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
todo_id INTEGER,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -335,13 +370,15 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE notes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -349,12 +386,18 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE connector_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
session_data JSONB NOT NULL,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
server_url TEXT,
|
||||
session_token TEXT UNIQUE,
|
||||
user_email TEXT,
|
||||
status TEXT,
|
||||
token_info JSONB,
|
||||
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -454,6 +497,14 @@ def upgrade() -> None:
);
"""
)
# Backfill the agents.workflow_id FK now that workflows exists.
# The column was created without a FK (forward reference to a table
# that hadn't been declared yet); add the constraint here so workflow
# deletion still cascades through to agent unset.
op.execute(
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
)

op.execute(
"""
@@ -539,13 +590,26 @@ def upgrade() -> None:
)

op.execute(
"CREATE UNIQUE INDEX connector_sessions_user_provider_uidx "
"ON connector_sessions (user_id, provider);"
# MCP and OAuth connectors share the ``provider`` slot, so the
# dedup key is ``(user_id, server_url, provider)``: MCP rows
# differentiate by server_url (one per MCP server), OAuth rows
# have server_url = NULL and differentiate by provider alone.
# COALESCE lets NULL server_url participate in the constraint.
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
)
op.execute(
"CREATE INDEX connector_sessions_expiry_idx "
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
)
op.execute(
"CREATE INDEX connector_sessions_server_url_idx "
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)

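A hedged usage sketch (not part of the migration) of an upsert that leans on the expression-based unique index above; for Postgres to pick that index as the conflict arbiter, the ON CONFLICT target has to repeat the COALESCE expression exactly. Table and column names are taken from the DDL above, while the connection URI and values are placeholders:

# Hypothetical writer-side sketch: one connector session per
# (user_id, server_url, provider), with NULL server_url treated as ''
# so OAuth rows dedupe on provider alone.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://localhost/docsgpt")  # placeholder URI
with engine.begin() as conn:
    conn.execute(
        text(
            """
            INSERT INTO connector_sessions (user_id, provider, server_url, session_data)
            VALUES (:user_id, :provider, :server_url, CAST(:session_data AS jsonb))
            ON CONFLICT (user_id, COALESCE(server_url, ''), provider)
            DO UPDATE SET session_data = EXCLUDED.session_data
            """
        ),
        {
            "user_id": "user-123",
            "provider": "mcp",
            "server_url": "https://mcp.example.com",
            "session_data": "{}",
        },
    )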
op.execute(
|
||||
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||
@@ -587,6 +651,10 @@ def upgrade() -> None:
|
||||
|
||||
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
|
||||
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||
@@ -616,20 +684,54 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
|
||||
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
|
||||
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
|
||||
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
|
||||
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
|
||||
|
||||
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
|
||||
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
|
||||
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
|
||||
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
|
||||
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
|
||||
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||
|
||||
|
||||
@@ -14,10 +14,13 @@ from application.core.model_utils import (
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,10 +33,6 @@ class BaseAnswerResource:
|
||||
"""Shared base class for answer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
@@ -91,8 +90,8 @@ class BaseAnswerResource:
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
@@ -113,41 +112,32 @@ class BaseAnswerResource:
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
)
|
||||
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
end_date = datetime.datetime.now()
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
if limited_token_mode or limited_request_mode:
|
||||
with db_readonly() as conn:
|
||||
token_repo = TokenUsageRepository(conn)
|
||||
if limited_token_mode:
|
||||
daily_token_usage = token_repo.sum_tokens_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_repo.count_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
@@ -467,19 +457,18 @@ class BaseAnswerResource:
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
|
||||
dual_write(
|
||||
UserLogsRepository,
|
||||
lambda repo, d=log_data: repo.insert(
|
||||
user_id=d.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=d,
|
||||
),
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
UserLogsRepository(conn).insert(
|
||||
user_id=log_data.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=log_data,
|
||||
)
|
||||
except Exception as log_err:
|
||||
logger.error(
|
||||
f"Failed to persist stream_answer user log: {log_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
@@ -4,11 +4,10 @@ from typing import Any, Dict, List
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,12 +17,6 @@ logger = logging.getLogger(__name__)
|
||||
class SearchResource(Resource):
|
||||
"""Fast search endpoint for retrieving relevant documents"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
|
||||
search_model = answer_ns.model(
|
||||
"SearchModel",
|
||||
{
|
||||
@@ -40,37 +33,23 @@ class SearchResource(Resource):
|
||||
)
|
||||
|
||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||
"""Get source IDs connected to the API key/agent.
|
||||
|
||||
"""
|
||||
agent_data = self.agents_collection.find_one({"key": api_key})
|
||||
"""Get source IDs connected to the API key/agent."""
|
||||
with db_readonly() as conn:
|
||||
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent_data:
|
||||
return []
|
||||
|
||||
source_ids = []
|
||||
source_ids: List[str] = []
|
||||
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
|
||||
extra = agent_data.get("extra_source_ids") or []
|
||||
for src in extra:
|
||||
if src:
|
||||
source_ids.append(str(src))
|
||||
|
||||
# Handle multiple sources (only if non-empty)
|
||||
sources = agent_data.get("sources", [])
|
||||
if sources and isinstance(sources, list) and len(sources) > 0:
|
||||
for source_ref in sources:
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
if source_ref == "default":
|
||||
continue
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
|
||||
# Handle single source (legacy) - check if sources was empty or didn't yield results
|
||||
if not source_ids:
|
||||
source = agent_data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
elif source and source != "default":
|
||||
source_ids.append(source)
|
||||
single = agent_data.get("source_id")
|
||||
if single:
|
||||
source_ids.append(str(single))
|
||||
|
||||
return source_ids
|
||||
|
||||
@@ -161,7 +140,8 @@ class SearchResource(Resource):
|
||||
return make_response({"error": "api_key is required"}, 400)
|
||||
|
||||
# Validate API key
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
|
||||
|
||||
@@ -1,23 +1,20 @@
"""Service for saving and restoring tool-call continuation state.

When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to MongoDB so the client can
the full execution state is persisted to Postgres so the client can
resume later by sending tool_actions.
"""

import datetime
import logging
from typing import Any, Dict, List, Optional
from uuid import UUID

from bson import ObjectId

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
from application.storage.db.session import db_readonly, db_session

logger = logging.getLogger(__name__)

@@ -26,8 +23,13 @@ PENDING_STATE_TTL_SECONDS = 30 * 60  # 30 minutes


def _make_serializable(obj: Any) -> Any:
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
if isinstance(obj, ObjectId):
"""Recursively coerce non-JSON values into JSON-safe forms.

Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
our writers produce them anymore.
"""
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
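Assembled from the hunk above and its docstring, the post-cutover serializer roughly amounts to the sketch below; only the UUID and dict branches appear verbatim in the diff, and the list/bytes/datetime handling is an assumption:

# Sketch, not a verbatim copy of the file: the branches beyond UUID/dict
# are inferred from the docstring ("bytes, and recurses into dicts/lists").
import datetime
from typing import Any
from uuid import UUID


def _make_serializable(obj: Any) -> Any:
    """Recursively coerce non-JSON values into JSON-safe forms."""
    if isinstance(obj, UUID):
        return str(obj)
    if isinstance(obj, dict):
        return {str(k): _make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_make_serializable(v) for v in obj]
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")
    if isinstance(obj, datetime.datetime):
        return obj.isoformat()
    return obj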
@@ -39,25 +41,13 @@ def _make_serializable(obj: Any) -> Any:


class ContinuationService:
"""Manages pending tool-call state in MongoDB."""
"""Manages pending tool-call state in Postgres."""

def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.collection = db["pending_tool_state"]
self._ensure_indexes()

def _ensure_indexes(self):
try:
self.collection.create_index(
"expires_at", expireAfterSeconds=0
)
self.collection.create_index(
[("conversation_id", 1), ("user", 1)], unique=True
)
except Exception:
# Indexes may already exist or mongomock doesn't support TTL
pass
# No-op constructor retained for call-site compatibility. State
# lives in Postgres now; each operation opens its own short-lived
# session rather than holding a connection on the service.
pass

def save_state(
|
||||
self,
|
||||
@@ -72,6 +62,10 @@ class ContinuationService:
|
||||
) -> str:
|
||||
"""Save execution state for later continuation.
|
||||
|
||||
``conversation_id`` may be a Postgres UUID or the legacy Mongo
|
||||
``ObjectId`` string — the latter is resolved via
|
||||
``conversations.legacy_mongo_id`` to find the matching row.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation this state belongs to.
|
||||
user: Owner user ID.
|
||||
@@ -83,45 +77,26 @@ class ContinuationService:
|
||||
client_tools: Client-provided tool schemas for client-side execution.
|
||||
|
||||
Returns:
|
||||
The string ID of the saved state document.
|
||||
The string ID (conversation_id as provided) of the saved state.
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
|
||||
|
||||
doc = {
|
||||
"conversation_id": conversation_id,
|
||||
"user": user,
|
||||
"messages": _make_serializable(messages),
|
||||
"pending_tool_calls": _make_serializable(pending_tool_calls),
|
||||
"tools_dict": _make_serializable(tools_dict),
|
||||
"tool_schemas": _make_serializable(tool_schemas),
|
||||
"agent_config": _make_serializable(agent_config),
|
||||
"client_tools": _make_serializable(client_tools) if client_tools else None,
|
||||
"created_at": now,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
# Upsert — only one pending state per conversation per user
|
||||
result = self.collection.replace_one(
|
||||
{"conversation_id": conversation_id, "user": user},
|
||||
doc,
|
||||
upsert=True,
|
||||
)
|
||||
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
|
||||
logger.info(
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — upsert against the same Mongo conversation
|
||||
# by resolving its UUID via conversations.legacy_mongo_id.
|
||||
def _pg_save(_: PendingToolStateRepository) -> None:
|
||||
conn = _._conn # reuse the existing transaction
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
_.save_state(
|
||||
conv["id"],
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
|
||||
# would raise and poison the save. Surface the mismatch so
|
||||
# the caller can decide (the stream loop in routes/base.py
|
||||
# already wraps this in try/except).
|
||||
raise ValueError(
|
||||
f"Cannot save continuation state: conversation_id "
|
||||
f"{conversation_id!r} is neither a PG UUID nor a "
|
||||
f"backfilled legacy Mongo id."
|
||||
)
|
||||
PendingToolStateRepository(conn).save_state(
|
||||
pg_conv_id,
|
||||
user,
|
||||
messages=_make_serializable(messages),
|
||||
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||
@@ -131,8 +106,11 @@ class ContinuationService:
|
||||
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||
)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_save)
|
||||
return state_id
|
||||
logger.info(
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
return conversation_id
|
||||
|
||||
def load_state(
|
||||
self, conversation_id: str, user: str
|
||||
@@ -142,34 +120,38 @@ class ContinuationService:
|
||||
Returns:
|
||||
The state dict, or None if no pending state exists.
|
||||
"""
|
||||
doc = self.collection.find_one(
|
||||
{"conversation_id": conversation_id, "user": user}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → no state can exist for it.
|
||||
return None
|
||||
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
|
||||
if not doc:
|
||||
return None
|
||||
doc["_id"] = str(doc["_id"])
|
||||
return doc
|
||||
|
||||
def delete_state(self, conversation_id: str, user: str) -> bool:
|
||||
"""Delete pending state after successful resumption.
|
||||
|
||||
Returns:
|
||||
True if a document was deleted.
|
||||
True if a row was deleted.
|
||||
"""
|
||||
result = self.collection.delete_one(
|
||||
{"conversation_id": conversation_id, "user": user}
|
||||
)
|
||||
if result.deleted_count:
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → nothing to delete.
|
||||
return False
|
||||
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — delete the same row.
|
||||
def _pg_delete(repo: PendingToolStateRepository) -> None:
|
||||
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.delete_state(conv["id"], user)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_delete)
|
||||
return result.deleted_count > 0
|
||||
return deleted
|
||||
|
||||
@@ -1,46 +1,51 @@
|
||||
"""Conversation persistence service backed by Postgres.
|
||||
|
||||
Handles create / append / update / compression for conversations during
|
||||
the answer-streaming path. Connections are opened per-operation rather
|
||||
than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from bson import ObjectId
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.conversations_collection = db["conversations"]
|
||||
self.agents_collection = db["agents"]
|
||||
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a conversation with proper access control"""
|
||||
"""Retrieve a conversation with owner-or-shared access control.
|
||||
|
||||
Returns a dict in the legacy Mongo shape — ``queries`` is a list
|
||||
of message dicts (prompt/response/...) — for compatibility with
|
||||
the streaming pipeline that consumes this shape.
|
||||
"""
|
||||
if not conversation_id or not user_id:
|
||||
return None
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||
}
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
conversation["_id"] = str(conversation["_id"])
|
||||
return conversation
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
messages = repo.get_messages(str(conv["id"]))
|
||||
conv["queries"] = messages
|
||||
conv["_id"] = str(conv["id"])
|
||||
return conv
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||
return None
|
||||
@@ -64,7 +69,11 @@ class ConversationService:
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in the database"""
|
||||
"""Save or update a conversation in Postgres.
|
||||
|
||||
Returns the string conversation id (PG UUID as string, or the
|
||||
caller-provided id if it was already a UUID).
|
||||
"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
@@ -72,117 +81,47 @@ class ConversationService:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
# Trim huge inline source text to a reasonable max before persist.
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
message_payload = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
if metadata:
|
||||
message_payload["metadata"] = metadata
|
||||
|
||||
if conversation_id is not None and index is not None:
|
||||
# Update existing conversation with new query
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": sources,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
**(
|
||||
{f"queries.{index}.metadata": metadata}
|
||||
if metadata
|
||||
else {}
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
# Dual-write to Postgres: update the message at :index and
|
||||
# truncate anything after it, mirroring Mongo's $set+$slice.
|
||||
def _pg_update_at_index(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.update_message_at(conv["id"], index, {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
})
|
||||
repo.truncate_after(conv["id"], index)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_update_at_index)
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
repo.update_message_at(conv_pg_id, index, message_payload)
|
||||
repo.truncate_after(conv_pg_id, index)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
# Append new message to existing conversation
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Dual-write to Postgres: append the same message.
|
||||
def _pg_append(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append)
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
# append_message expects 'metadata' key either way; normalise.
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -204,118 +143,67 @@ class ConversationService:
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
query_doc = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
if metadata:
|
||||
query_doc["metadata"] = metadata
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
"name": completion,
|
||||
"queries": [query_doc],
|
||||
}
|
||||
|
||||
resolved_api_key: Optional[str] = None
|
||||
resolved_agent_id: Optional[str] = None
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
inserted_id = str(result.inserted_id)
|
||||
resolved_api_key = agent.get("key")
|
||||
if agent_id:
|
||||
resolved_agent_id = agent_id
|
||||
|
||||
# Dual-write to Postgres: create the conversation row with
|
||||
# legacy_mongo_id and append the first message.
|
||||
def _pg_create(repo: ConversationsRepository) -> None:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=conversation_data.get("agent_id"),
|
||||
api_key=conversation_data.get("api_key"),
|
||||
is_shared_usage=conversation_data.get("is_shared_usage", False),
|
||||
shared_token=conversation_data.get("shared_token"),
|
||||
legacy_mongo_id=inserted_id,
|
||||
agent_id=resolved_agent_id,
|
||||
api_key=resolved_api_key,
|
||||
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
|
||||
shared_token=(
|
||||
shared_token
|
||||
if (resolved_agent_id and is_shared_usage)
|
||||
else None
|
||||
),
|
||||
)
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_create)
|
||||
return inserted_id
|
||||
conv_pg_id = str(conv["id"])
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update conversation with compression metadata.
|
||||
"""Persist compression flags and append a compression point.
|
||||
|
||||
Uses $push with $slice to keep only the most recent compression points,
|
||||
preventing unbounded array growth. Since each compression incorporates
|
||||
previous compressions, older points become redundant.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
compression_metadata: Compression point data
|
||||
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
|
||||
``compression_metadata`` but goes through the PG repo API.
|
||||
"""
|
||||
try:
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$set": {
|
||||
"compression_metadata.is_compressed": True,
|
||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
||||
"timestamp"
|
||||
),
|
||||
},
|
||||
"$push": {
|
||||
"compression_metadata.compression_points": {
|
||||
"$each": [compression_metadata],
|
||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres: mirror $set + $push $slice.
|
||||
def _pg_compression(repo: ConversationsRepository) -> None:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
# conversation_id here comes from the streaming pipeline
|
||||
# which has already resolved it; accept either UUID or
|
||||
# legacy id for safety.
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
conv_pg_id = (
|
||||
str(conv["id"]) if conv is not None else conversation_id
|
||||
)
|
||||
repo.set_compression_flags(
|
||||
conv["id"],
|
||||
conv_pg_id,
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv["id"],
|
||||
conv_pg_id,
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_compression)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
@@ -325,39 +213,22 @@ class ConversationService:
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
"""Append a synthetic compression summary message to the conversation."""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))

self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
timestamp = compression_metadata.get(
"timestamp", datetime.now(timezone.utc)
)

def _pg_append_summary(repo: ConversationsRepository) -> None:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
return
repo.append_message(conv["id"], {
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.append_message(conv_pg_id, {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
@@ -367,9 +238,9 @@ class ConversationService:
"model_id": compression_metadata.get("model_used"),
"timestamp": timestamp,
})

dual_write(ConversationsRepository, _pg_append_summary)
logger.info(f"Appended compression summary to conversation {conversation_id}")
logger.info(
f"Appended compression summary to conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
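# --- Illustrative sketch, not part of this changeset -----------------------
# ``dual_write`` is imported from application.storage.db.dual_write elsewhere
# in this changeset; its real behaviour is not shown in this diff. The call
# sites above suggest a best-effort helper that builds a repository on its
# own session, runs the mirror-write closure, and logs rather than raises on
# failure. The body below is an assumption written to match that usage.
import logging

from application.storage.db.session import db_session

logger = logging.getLogger(__name__)


def dual_write_sketch(repo_cls, fn):
    """Run ``fn`` against a fresh repository; never let a PG failure propagate."""
    try:
        with db_session() as conn:
            fn(repo_cls(conn))
    except Exception:
        logger.warning("Dual-write via %s failed", repo_cls.__name__, exc_info=True)
# ---------------------------------------------------------------------------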
@@ -378,20 +249,30 @@ class ConversationService:
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.

Args:
conversation_id: Conversation ID

Returns:
Compression metadata dict or None
"""
"""Fetch the stored compression metadata JSONB blob for a conversation."""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
# Fallback to UUID lookup without user scoping — the
# caller already holds an authenticated conversation
# id from the streaming path. Gate on id shape so a
# non-UUID (legacy ObjectId that wasn't backfilled)
# doesn't reach CAST — the cast raises and spams the
# logs with a stack trace on every call.
if not looks_like_uuid(conversation_id):
return None
result = conn.execute(
sql_text(
"SELECT compression_metadata FROM conversations "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
row = result.fetchone()
return row[0] if row is not None else None
return conv.get("compression_metadata") if conv else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True

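# --- Illustrative sketch, not part of this changeset -----------------------
# ``looks_like_uuid`` is imported from application.storage.db.base_repository;
# its implementation is not shown in this diff. A minimal helper matching how
# it is used above (a cheap shape check so non-UUID legacy ids never reach
# ``CAST(... AS uuid)``) could look like this:
import uuid


def looks_like_uuid_sketch(value) -> bool:
    """Return True if ``value`` parses as a UUID, without raising."""
    try:
        uuid.UUID(str(value))
        return True
    except (ValueError, TypeError):
        return False
# ---------------------------------------------------------------------------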
@@ -5,10 +5,6 @@ import os
from pathlib import Path
from typing import Any, Dict, Optional, Set

from bson.dbref import DBRef

from bson.objectid import ObjectId

from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
@@ -20,8 +16,16 @@ from application.core.model_utils import (
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from sqlalchemy import text as sql_text

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
@@ -32,28 +36,41 @@ logger = logging.getLogger(__name__)


def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).

The ``prompts_collection`` parameter is retained for backwards
compatibility with call sites that still pass it positionally; it is
ignored post-cutover.
"""
Get a prompt by preset name or MongoDB ID
"""
del prompts_collection # unused — retained for call-site compatibility
# Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
# plain string ("default"/"creative"/legacy ObjectId). Normalise to str
# so both the preset lookup and the UUID-vs-legacy branching work.
# ``None`` / empty means "use the default prompt" — agents that never
# set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
if prompt_id is None or prompt_id == "":
prompt_id = "default"
elif not isinstance(prompt_id, str):
prompt_id = str(prompt_id)
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"

# Maps for classic agent types
CLASSIC_PRESETS = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}

# Agentic counterparts — same styles, but with search tool instructions
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}

preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
preset_mapping = {
**CLASSIC_PRESETS,
**{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
}

if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
@@ -63,14 +80,18 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
with db_readonly() as conn:
repo = PromptsRepository(conn)
prompt_doc = None
if looks_like_uuid(prompt_id):
prompt_doc = repo.get_for_rendering(prompt_id)
if prompt_doc is None:
prompt_doc = repo.get_by_legacy_id(prompt_id)
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e

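# --- Illustrative usage, not part of this changeset ------------------------
# ``get_prompt`` accepts a preset name, a Postgres UUID, or a legacy Mongo
# ObjectId string. The ids below are made-up examples, shown only to
# illustrate the resolution order implemented above.
#
# get_prompt("default")                               # bundled preset file
# get_prompt("agentic_strict")                        # agentic preset file
# get_prompt("7b0e2c1a-0a3f-4c5d-9b8e-1f2a3b4c5d6e")  # prompts.id UUID via get_for_rendering
# get_prompt("64fa12ab34cd56ef78901234")              # legacy ObjectId via get_by_legacy_id
# ---------------------------------------------------------------------------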
@@ -79,12 +100,9 @@ class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]

# Legacy attribute retained as None for any external callers that
# introspect the processor; all DB access uses per-op connections.
self.prompts_collection = None
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
@@ -244,17 +262,21 @@ class StreamProcessor:
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
try:
with db_readonly() as conn:
repo = AttachmentsRepository(conn)
for attachment_id in attachment_ids:
try:
attachment_doc = repo.get_any(str(attachment_id), user_id)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
return attachments

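# --- Illustrative sketch, not part of this changeset -----------------------
# ``db_readonly`` / ``db_session`` come from application.storage.db.session;
# their real implementation is not part of this diff. The per-operation
# connection pattern used above assumes context managers roughly like the
# following (engine URI and helper names here are placeholders, not project
# code):
from contextlib import contextmanager

from sqlalchemy import create_engine

_engine = create_engine("postgresql+psycopg2://user:pass@localhost/docsgpt")  # placeholder


@contextmanager
def db_session_sketch():
    """Yield a connection inside a transaction; commit on success, roll back on error."""
    with _engine.begin() as conn:
        yield conn


@contextmanager
def db_readonly_sketch():
    """Yield a plain connection for reads; nothing is committed."""
    with _engine.connect() as conn:
        yield conn
# ---------------------------------------------------------------------------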
def _validate_and_set_model(self):
|
||||
@@ -285,78 +307,101 @@ class StreamProcessor:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
"""Get API key for agent with access control."""
|
||||
if not agent_id:
|
||||
return None, False, None
|
||||
try:
|
||||
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
with db_readonly() as conn:
|
||||
# Lookup without user scoping — access control is done
|
||||
# against ``user_id`` / ``shared_with`` / ``shared`` flags
|
||||
# right below, matching the legacy Mongo semantics.
|
||||
repo = AgentsRepository(conn)
|
||||
agent = None
|
||||
if looks_like_uuid(str(agent_id)):
|
||||
result = conn.execute(
|
||||
sql_text(
|
||||
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(agent_id)},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is not None:
|
||||
agent = row_to_dict(row)
|
||||
if agent is None:
|
||||
agent = repo.get_by_legacy_id(str(agent_id))
|
||||
if agent is None:
|
||||
raise Exception("Agent not found")
|
||||
is_owner = agent.get("user") == user_id
|
||||
is_shared_with_user = agent.get(
|
||||
"shared_publicly", False
|
||||
) or user_id in agent.get("shared_with", [])
|
||||
agent_owner = agent.get("user_id")
|
||||
is_owner = agent_owner == user_id
|
||||
is_shared_with_user = bool(agent.get("shared", False))
|
||||
|
||||
if not (is_owner or is_shared_with_user):
|
||||
raise Exception("Unauthorized access to the agent")
|
||||
if is_owner:
|
||||
self.agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)},
|
||||
{
|
||||
"$set": {
|
||||
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
|
||||
}
|
||||
},
|
||||
)
|
||||
return str(agent["key"]), not is_owner, agent.get("shared_token")
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
AgentsRepository(conn).update(
|
||||
str(agent["id"]), agent_owner,
|
||||
{"last_used_at": now},
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to update last_used_at for agent",
|
||||
exc_info=True,
|
||||
)
|
||||
return (
|
||||
str(agent["key"]) if agent.get("key") else None,
|
||||
not is_owner,
|
||||
agent.get("shared_token"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||
data = self.agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
sources_repo = SourcesRepository(conn)
|
||||
# The repo dict uses "user_id" — the streaming path expects
|
||||
# a "user" key (legacy Mongo shape) for identity propagation.
|
||||
data: Dict[str, Any] = dict(agent)
|
||||
data["user"] = agent.get("user_id")
|
||||
|
||||
# Resolve the primary source row (if any) for retriever/chunks.
|
||||
source_id = agent.get("source_id")
|
||||
if source_id:
|
||||
source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["id"])
|
||||
data["retriever"] = source_doc.get(
|
||||
"retriever", data.get("retriever")
|
||||
)
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
else:
|
||||
data["source"] = None
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
extra = agent.get("extra_source_ids") or []
|
||||
if extra:
|
||||
for sid in extra:
|
||||
source_doc = sources_repo.get(str(sid), agent.get("user_id"))
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
|
||||
sources_list.append(
|
||||
{
|
||||
"id": str(source_doc["id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get(
|
||||
"chunks", data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
data["sources"] = sources_list
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
@@ -484,8 +529,14 @@ class StreamProcessor:
|
||||
# Owner using their own agent
|
||||
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||
|
||||
if self._agent_data.get("workflow"):
|
||||
self.agent_config["workflow"] = self._agent_data["workflow"]
|
||||
# PG row exposes the workflow as ``workflow_id`` (UUID column);
|
||||
# legacy Mongo shape used the key ``workflow``. Accept either so
|
||||
# API-key-invoked workflow agents bind correctly downstream.
|
||||
wf_ref = self._agent_data.get("workflow") or self._agent_data.get(
|
||||
"workflow_id"
|
||||
)
|
||||
if wf_ref:
|
||||
self.agent_config["workflow"] = str(wf_ref)
|
||||
self.agent_config["workflow_owner"] = self._agent_data.get("user")
|
||||
else:
|
||||
# No API key — default/workflow configuration
|
||||
@@ -620,12 +671,9 @@ class StreamProcessor:
|
||||
filtering_enabled = required_tool_actions is not None
|
||||
|
||||
try:
|
||||
user_tools_collection = self.db["user_tools"]
|
||||
user_id = self.initial_user_id or "local"
|
||||
|
||||
user_tools = list(
|
||||
user_tools_collection.find({"user": user_id, "status": True})
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
|
||||
|
||||
if not user_tools:
|
||||
return None
|
||||
@@ -986,8 +1034,10 @@ class StreamProcessor:
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
|
||||
# Compute backup models: agent's configured models minus the active one
|
||||
agent_models = self.agent_config.get("models", [])
|
||||
# Compute backup models: agent's configured models minus the active one.
|
||||
# PG agents may carry an explicit ``models: NULL`` (not absent), so
|
||||
# ``.get("models", [])`` isn't enough — coerce None → [].
|
||||
agent_models = self.agent_config.get("models") or []
|
||||
backup_models = [m for m in agent_models if m != self.model_id]
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import base64
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
@@ -17,22 +15,18 @@ from flask import (
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
sessions_collection = db["connector_sessions"]
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
@@ -68,16 +62,14 @@ class ConnectorAuth(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
with db_session() as conn:
|
||||
session_row = ConnectorSessionsRepository(conn).upsert(
|
||||
user_id, provider, status="pending",
|
||||
)
|
||||
session_pg_id = str(session_row["id"])
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
"object_id": session_pg_id,
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
@@ -160,17 +152,25 @@ class ConnectorsCallback(Resource):
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
# ``object_id`` in the OAuth state is the PG session row
|
||||
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
|
||||
# issued state). Try UUID update first; fall back to
|
||||
# legacy id path.
|
||||
patch = {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized",
|
||||
}
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
if state_object_id:
|
||||
value = str(state_object_id)
|
||||
updated = False
|
||||
if len(value) == 36 and "-" in value:
|
||||
updated = repo.update(value, patch)
|
||||
if not updated:
|
||||
repo.update_by_legacy_id(value, patch)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
@@ -222,8 +222,11 @@ class ConnectorFiles(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session:
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
@@ -288,8 +291,11 @@ class ConnectorValidateSession(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session or "token_info" not in session:
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user or not session.get("token_info"):
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
@@ -300,10 +306,11 @@ class ConnectorValidateSession(Resource):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
row = repo.get_by_session_token(session_token)
|
||||
if row:
|
||||
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
@@ -347,8 +354,11 @@ class ConnectorDisconnect(Resource):
|
||||
|
||||
|
||||
if session_token:
|
||||
sessions_collection.delete_one({"session_token": session_token})
|
||||
|
||||
with db_session() as conn:
|
||||
ConnectorSessionsRepository(conn).delete_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
@@ -385,32 +395,28 @@ class ConnectorSync(Resource):
|
||||
}),
|
||||
400
|
||||
)
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
user_id = decoded_token.get('sub')
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user_id)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
if source.get('user') != decoded_token.get('sub'):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Unauthorized access to source"
|
||||
}),
|
||||
403
|
||||
)
|
||||
# ``get_any`` already scopes by ``user_id``; an extra guard
|
||||
# here would be dead code.
|
||||
|
||||
remote_data = {}
|
||||
try:
|
||||
if source.get('remote_data'):
|
||||
remote_data = json.loads(source.get('remote_data'))
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
remote_data = source.get('remote_data') or {}
|
||||
if isinstance(remote_data, str):
|
||||
try:
|
||||
remote_data = json.loads(remote_data)
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
@@ -438,7 +444,7 @@ class ConnectorSync(Resource):
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=source_id,
|
||||
doc_id=str(source.get('id') or source_id),
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
|
||||
@@ -3,18 +3,16 @@ import datetime
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory, jsonify
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
import logging
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -56,21 +54,21 @@ def upload_index_files():
|
||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||
if "user" not in request.form:
|
||||
return {"status": "no user"}
|
||||
user = request.form["user"]
|
||||
user = request.form["user"]
|
||||
if "name" not in request.form:
|
||||
return {"status": "no name"}
|
||||
job_name = request.form["name"]
|
||||
tokens = request.form["tokens"]
|
||||
retriever = request.form["retriever"]
|
||||
id = request.form["id"]
|
||||
source_id = request.form["id"]
|
||||
type = request.form["type"]
|
||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
file_name_map = request.form.get("file_name_map")
|
||||
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
directory_structure = json.loads(directory_structure)
|
||||
@@ -89,8 +87,8 @@ def upload_index_files():
|
||||
file_name_map = None
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{id}"
|
||||
|
||||
index_base_path = f"indexes/{source_id}"
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
if "file_faiss" not in request.files:
|
||||
logger.error("No file_faiss part")
|
||||
@@ -111,46 +109,48 @@ def upload_index_files():
|
||||
storage.save_file(file_faiss, faiss_storage_path)
|
||||
storage.save_file(file_pkl, pkl_storage_path)
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
update_fields = {
|
||||
"name": job_name,
|
||||
"type": type,
|
||||
"language": job_name,
|
||||
"date": now,
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
update_fields = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(id)},
|
||||
{"$set": update_fields},
|
||||
)
|
||||
else:
|
||||
insert_doc = {
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
insert_doc["file_name_map"] = file_name_map
|
||||
sources_collection.insert_one(insert_doc)
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
existing = None
|
||||
if looks_like_uuid(source_id):
|
||||
existing = repo.get(source_id, user)
|
||||
if existing is None:
|
||||
existing = repo.get_by_legacy_id(source_id, user)
|
||||
if existing is not None:
|
||||
repo.update(str(existing["id"]), user, update_fields)
|
||||
else:
|
||||
repo.create(
|
||||
job_name,
|
||||
source_id=source_id if looks_like_uuid(source_id) else None,
|
||||
user_id=user,
|
||||
type=type,
|
||||
tokens=tokens,
|
||||
retriever=retriever,
|
||||
remote_data=remote_data,
|
||||
sync_frequency=sync_frequency,
|
||||
file_path=file_path,
|
||||
directory_structure=directory_structure,
|
||||
file_name_map=file_name_map,
|
||||
language=job_name,
|
||||
model=settings.EMBEDDINGS_NAME,
|
||||
date=now,
|
||||
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -3,29 +3,50 @@ Agent folders management routes.
|
||||
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
|
||||
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
|
||||
if not folder_id:
|
||||
return None
|
||||
if looks_like_uuid(folder_id):
|
||||
row = repo.get(folder_id, user)
|
||||
if row is not None:
|
||||
return row
|
||||
return repo.get_by_legacy_id(folder_id, user)
|
||||
|
||||
|
||||
def _folder_error_response(message: str, err: Exception):
|
||||
current_app.logger.error(f"{message}: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "message": message}), 400)
|
||||
|
||||
|
||||
def _serialize_folder(f: dict) -> dict:
|
||||
created_at = f.get("created_at")
|
||||
updated_at = f.get("updated_at")
|
||||
return {
|
||||
"id": str(f["id"]),
|
||||
"name": f.get("name"),
|
||||
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
|
||||
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||
}
|
||||
|
||||
|
||||
@agents_folders_ns.route("/")
|
||||
class AgentFolders(Resource):
|
||||
@api.doc(description="Get all folders for the user")
|
||||
@@ -35,17 +56,9 @@ class AgentFolders(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
folders = list(agent_folders_collection.find({"user": user}))
|
||||
result = [
|
||||
{
|
||||
"id": str(f["_id"]),
|
||||
"name": f["name"],
|
||||
"parent_id": f.get("parent_id"),
|
||||
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
|
||||
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
|
||||
}
|
||||
for f in folders
|
||||
]
|
||||
with db_readonly() as conn:
|
||||
folders = AgentFoldersRepository(conn).list_for_user(user)
|
||||
result = [_serialize_folder(f) for f in folders]
|
||||
return make_response(jsonify({"folders": result}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folders", err)
|
||||
@@ -69,28 +82,34 @@ class AgentFolders(Resource):
|
||||
if not data or not data.get("name"):
|
||||
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
|
||||
|
||||
parent_id = data.get("parent_id")
|
||||
if parent_id:
|
||||
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
|
||||
if not parent:
|
||||
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
|
||||
parent_id_input = data.get("parent_id")
|
||||
description = data.get("description")
|
||||
|
||||
try:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
folder = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"parent_id": parent_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
result = agent_folders_collection.insert_one(folder)
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, u=user, n=data["name"]: repo.create(u, n),
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
pg_parent_id = None
|
||||
if parent_id_input:
|
||||
parent = _resolve_folder_id(repo, parent_id_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_parent_id = str(parent["id"])
|
||||
folder = repo.create(
|
||||
user, data["name"],
|
||||
description=description,
|
||||
parent_id=pg_parent_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
||||
jsonify(
|
||||
{
|
||||
"id": str(folder["id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": pg_parent_id,
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -106,26 +125,51 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
|
||||
agents_list = [
|
||||
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
|
||||
for a in agents
|
||||
]
|
||||
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
|
||||
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
|
||||
with db_readonly() as conn:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
agents_rows = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id, name, description FROM agents "
|
||||
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
|
||||
"ORDER BY created_at DESC"
|
||||
),
|
||||
{"user_id": user, "fid": pg_folder_id},
|
||||
).fetchall()
|
||||
agents_list = [
|
||||
{
|
||||
"id": str(row._mapping["id"]),
|
||||
"name": row._mapping["name"],
|
||||
"description": row._mapping.get("description", "") or "",
|
||||
}
|
||||
for row in agents_rows
|
||||
]
|
||||
|
||||
subfolders = folders_repo.list_children(pg_folder_id, user)
|
||||
subfolders_list = [
|
||||
{"id": str(sf["id"]), "name": sf["name"]}
|
||||
for sf in subfolders
|
||||
]
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"id": str(folder["_id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": folder.get("parent_id"),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}),
|
||||
jsonify(
|
||||
{
|
||||
"id": pg_folder_id,
|
||||
"name": folder["name"],
|
||||
"parent_id": (
|
||||
str(folder["parent_id"]) if folder.get("parent_id") else None
|
||||
),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -142,19 +186,57 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
|
||||
|
||||
try:
|
||||
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "parent_id" in data:
|
||||
if data["parent_id"] == folder_id:
|
||||
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
|
||||
update_fields["parent_id"] = data["parent_id"]
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
update_fields: dict = {}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "description" in data:
|
||||
update_fields["description"] = data["description"]
|
||||
if "parent_id" in data:
|
||||
parent_input = data.get("parent_id")
|
||||
if parent_input:
|
||||
if parent_input == folder_id or parent_input == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
parent = _resolve_folder_id(repo, parent_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
if str(parent["id"]) == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields["parent_id"] = str(parent["id"])
|
||||
else:
|
||||
update_fields["parent_id"] = None
|
||||
|
||||
if update_fields:
|
||||
repo.update(pg_folder_id, user, update_fields)
|
||||
|
||||
result = agent_folders_collection.update_one(
|
||||
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to update folder", err)
|
||||
@@ -166,19 +248,24 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
agents_collection.update_many(
|
||||
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
agent_folders_collection.update_many(
|
||||
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
||||
)
|
||||
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
# Clear folder assignments from agents; self-FK
|
||||
# ``ON DELETE SET NULL`` handles child folders.
|
||||
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
|
||||
deleted = repo.delete(pg_folder_id, user)
|
||||
if not deleted:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to delete folder", err)
|
||||
@@ -205,26 +292,29 @@ class MoveAgentToFolder(Resource):
|
||||
if not data or not data.get("agent_id"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
|
||||
|
||||
agent_id = data["agent_id"]
|
||||
folder_id = data.get("folder_id")
|
||||
agent_id_input = data["agent_id"]
|
||||
folder_id_input = data.get("folder_id")
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
|
||||
if not agent:
|
||||
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
|
||||
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agent", err)
|
||||
@@ -252,25 +342,25 @@ class BulkMoveAgents(Resource):
|
||||
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
|
||||
|
||||
agent_ids = data["agent_ids"]
|
||||
folder_id = data.get("folder_id")
|
||||
folder_id_input = data.get("folder_id")
|
||||
|
||||
try:
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in agent_ids]
|
||||
if folder_id:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$set": {"folder_id": folder_id}},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$unset": {"folder_id": ""}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
for agent_id_input in agent_ids:
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if agent is not None:
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agents", err)
|
||||
|
||||
File diff suppressed because it is too large
@@ -3,23 +3,17 @@
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
@@ -27,6 +21,38 @@ agents_sharing_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
def _serialize_agent_basic(agent: dict) -> dict:
|
||||
"""Shape a PG agent row into the API response dict."""
|
||||
source_id = agent.get("source_id")
|
||||
return {
|
||||
"id": str(agent["id"]),
|
||||
"user": agent.get("user_id", ""),
|
||||
"name": agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"description": agent.get("description", ""),
|
||||
"source": str(source_id) if source_id else "",
|
||||
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
|
||||
"retriever": agent.get("retriever", "classic") or "classic",
|
||||
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
@@ -43,73 +69,33 @@ class SharedAgent(Resource):
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
with db_readonly() as conn:
|
||||
shared_agent = AgentsRepository(conn).find_by_shared_token(
|
||||
shared_token,
|
||||
)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
||||
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
||||
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
agent_id = str(shared_agent["id"])
|
||||
data = _serialize_agent_basic(shared_agent)
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
for detail in data["tool_details"]:
|
||||
enriched_tools.append(detail.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user")
|
||||
owner_id = shared_agent.get("user_id")
|
||||
|
||||
if user_id != owner_id:
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
dual_write(UsersRepository,
|
||||
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
|
||||
)
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
users_repo.upsert(user_id)
|
||||
users_repo.add_shared(user_id, agent_id)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
@@ -126,55 +112,73 @@ class SharedAgents(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
user_doc = users_repo.upsert(user_id)
|
||||
shared_with_ids = (
|
||||
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
)
|
||||
dual_write(UsersRepository,
|
||||
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
|
||||
)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
|
||||
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
|
||||
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
|
||||
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
if uuid_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM agents "
|
||||
"WHERE id = ANY(CAST(:ids AS uuid[])) "
|
||||
"AND shared = true"
|
||||
),
|
||||
{"ids": uuid_ids},
|
||||
)
|
||||
shared_agents = [dict(row._mapping) for row in result.fetchall()]
|
||||
else:
|
||||
shared_agents = []
|
||||
|
||||
found_ids_set = {str(agent["id"]) for agent in shared_agents}
|
||||
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
|
||||
stale_ids.extend(non_uuid_ids)
|
||||
if stale_ids:
|
||||
users_repo.remove_shared_bulk(user_id, stale_ids)
|
||||
|
||||
pinned_ids = set(
|
||||
user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
)
|
||||
|
||||
list_shared_agents = []
|
||||
for agent in shared_agents:
|
||||
agent_id_str = str(agent["id"])
|
||||
list_shared_agents.append(
|
||||
{
|
||||
"id": agent_id_str,
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(
|
||||
agent.get("tools", []) or []
|
||||
),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"pinned": agent_id_str in pinned_ids,
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
)
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
@@ -228,44 +232,43 @@ class ShareAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
shared_token = None
|
||||
try:
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
with db_session() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
agent = repo.get_any(agent_id, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
"shared_metadata": shared_metadata,
|
||||
},
|
||||
)
|
||||
else:
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": False,
|
||||
"shared_token": None,
|
||||
"shared_metadata": None,
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.base import require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
@@ -34,9 +35,8 @@ class AgentWebhook(Resource):
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).get_any(agent_id, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
@@ -44,10 +44,11 @@ class AgentWebhook(Resource):
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
AgentsRepository(conn).update(
|
||||
str(agent["id"]), user,
|
||||
{"incoming_webhook_token": webhook_token},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
|
||||
@@ -2,26 +2,84 @@
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
_FILTER_BUCKETS = {
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
}


def _range_for_filter(filter_option: str):
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.

Returns ``None`` on invalid filter.
"""
if filter_option not in _FILTER_BUCKETS:
return None
end_date = datetime.datetime.now(datetime.timezone.utc)
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]

if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
else:
days = {
"last_7_days": 6,
"last_15_days": 14,
"last_30_days": 29,
}[filter_option]
start_date = end_date - datetime.timedelta(days=days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
return start_date, end_date, bucket_unit, pg_fmt


def _intervals_for_filter(filter_option, start_date, end_date):
if filter_option == "last_hour":
return generate_minute_range(start_date, end_date)
if filter_option == "last_24_hour":
return generate_hourly_range(start_date, end_date)
return generate_date_range(start_date, end_date)


def _resolve_api_key(conn, api_key_id, user_id):
"""Look up the ``agents.key`` value for a given agent id.

Scoped by ``user_id`` so an authenticated caller can't probe another
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
"""
if not api_key_id:
return None
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
return (agent or {}).get("key") if agent else None

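# Illustrative sketch, not from the DocsGPT source: how the helpers above
# are meant to compose inside a handler. The invariant they rely on is that
# the to_char formats in _FILTER_BUCKETS produce bucket labels with exactly
# the same shape as the labels emitted by generate_minute_range /
# generate_hourly_range / generate_date_range, so SQL rows can be overlaid
# onto a zero-filled dict by plain string equality. The concrete label
# values below are assumptions for illustration.

window = _range_for_filter("last_24_hour")
if window is None:
    raise ValueError("unsupported filter_option")  # handlers return a 400 here
start_date, end_date, bucket_unit, pg_fmt = window  # pg_fmt == "YYYY-MM-DD HH24:00"

intervals = _intervals_for_filter("last_24_hour", start_date, end_date)
counts = {label: 0 for label in intervals}  # e.g. {"2024-05-07 13:00": 0, ...}
# rows come back from Postgres as (bucket, count) via to_char(..., :fmt):
# for row in rows:
#     counts[row._mapping["bucket"]] = int(row._mapping["count"])
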
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
@@ -32,13 +90,7 @@ class GetMessageAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -50,88 +102,54 @@ class GetMessageAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Count messages per bucket, filtered by the conversation's
# owner (user_id) and optionally the agent api_key. The
# ``user_id`` filter is always applied post-cutover to
# prevent cross-tenant leakage on admin dashboards.
clauses = [
"c.user_id = :user_id",
"m.timestamp >= :start",
"m.timestamp <= :end",
]
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
if api_key:
|
||||
clauses.append("c.api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
sql = (
|
||||
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
|
||||
"COUNT(*) AS count "
|
||||
"FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
f"WHERE {where} "
|
||||
"GROUP BY bucket ORDER BY bucket ASC"
|
||||
)
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
for row in rows:
|
||||
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
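# Illustrative note, not from the DocsGPT source: the handler above only
# works if the Python-side interval labels and the Postgres to_char()
# buckets agree character for character. Assumed pairing of the two format
# languages carried by _FILTER_BUCKETS:
#
#   strftime "%Y-%m-%d %H:%M:00"  <->  to_char "YYYY-MM-DD HH24:MI:00"
#   strftime "%Y-%m-%d %H:00"     <->  to_char "YYYY-MM-DD HH24:00"
#   strftime "%Y-%m-%d"           <->  to_char "YYYY-MM-DD"

import datetime

ts = datetime.datetime(2024, 5, 7, 13, 0, tzinfo=datetime.timezone.utc)
assert ts.strftime("%Y-%m-%d %H:00") == "2024-05-07 13:00"
# Postgres side (assumed equivalent output):
#   to_char(TIMESTAMPTZ '2024-05-07 13:00+00' AT TIME ZONE 'UTC',
#           'YYYY-MM-DD HH24:00')  ->  '2024-05-07 13:00'
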
@@ -152,13 +170,7 @@ class GetTokenAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -170,123 +182,36 @@ class GetTokenAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
start_date, end_date, bucket_unit, _pg_fmt = window
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
# ``bucketed_totals`` applies user_id / api_key filters
|
||||
# directly — no need to reshape a Mongo pipeline.
|
||||
rows = TokenUsageRepository(conn).bucketed_totals(
|
||||
bucket_unit=bucket_unit,
|
||||
user_id=user,
|
||||
api_key=api_key,
|
||||
timestamp_gte=start_date,
|
||||
timestamp_lt=end_date,
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
for entry in rows:
|
||||
daily_token_usage[entry["bucket"]] = int(
|
||||
entry["prompt_tokens"] + entry["generated_tokens"]
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
@@ -307,13 +232,7 @@ class GetFeedbackAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -325,128 +244,64 @@ class GetFeedbackAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Feedback lives inside the ``conversation_messages.feedback``
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
# There is no scalar ``feedback_timestamp`` column — extract
# the timestamp from the JSONB and cast it to timestamptz for
# the range filter + bucket grouping.
clauses = [
"c.user_id = :user_id",
"m.feedback IS NOT NULL",
"(m.feedback->>'timestamp')::timestamptz >= :start",
"(m.feedback->>'timestamp')::timestamptz <= :end",
]
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char("
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
") AS bucket, "
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
for row in rows:
|
||||
bucket = row._mapping["bucket"]
|
||||
daily_feedback[bucket] = {
|
||||
"positive": int(row._mapping["positive"] or 0),
|
||||
"negative": int(row._mapping["negative"] or 0),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
@@ -484,47 +339,89 @@ class GetUserLogs(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
data = request.get_json() or {}
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
logs_repo = UserLogsRepository(conn)
if api_key:
# ``find_by_api_key`` filters on ``data->>'api_key'``
# — the PG shape of the legacy top-level ``api_key``
# filter. Paginate client-side using offset/limit.
all_rows = logs_repo.find_by_api_key(api_key)
offset = (page - 1) * page_size
window = all_rows[offset: offset + page_size + 1]
items = window
else:
items, has_more_flag = logs_repo.list_paginated(
user_id=user,
page=page,
page_size=page_size,
)
# list_paginated already trims to page_size and
# returns has_more separately.
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more_flag,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
has_more = len(items) > page_size
|
||||
items = items[:page_size]
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
current_app.logger.error(
|
||||
f"Error getting user logs: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -4,13 +4,16 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
import uuid
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.stt.constants import (
|
||||
SUPPORTED_AUDIO_EXTENSIONS,
|
||||
SUPPORTED_AUDIO_MIME_TYPES,
|
||||
@@ -48,14 +51,13 @@ def _resolve_authenticated_user():
|
||||
return safe_filename(decoded_token.get("sub"))
|
||||
|
||||
if api_key:
|
||||
from application.api.user.base import agents_collection
|
||||
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
return safe_filename(agent.get("user"))
|
||||
return safe_filename(agent.get("user_id"))
|
||||
|
||||
return None
|
||||
|
||||
@@ -157,7 +159,7 @@ class StoreAttachment(Resource):
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
attachment_id = uuid.uuid4()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
_enforce_uploaded_audio_size_limit(file, original_filename)
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
@@ -8,15 +8,15 @@ import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
@@ -24,56 +24,6 @@ from application.vectorstore.vector_creator import VectorCreator
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
agent_folders_collection = db["agent_folders"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
workflow_runs_collection = db["workflow_runs"]
|
||||
workflows_collection = db["workflows"]
|
||||
workflow_nodes_collection = db["workflow_nodes"]
|
||||
workflow_edges_collection = db["workflow_edges"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
workflows_collection.create_index(
|
||||
[("user", 1)], name="workflow_user_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1)], name="node_workflow_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="node_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1)], name="edge_workflow_index", background=True
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="edge_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
@@ -105,69 +55,95 @@ def generate_date_range(start_date, end_date):
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
Ensure a Postgres ``users`` row exists for ``user_id``.
|
||||
|
||||
Returns the row as a dict with the shape legacy callers expect — in
|
||||
particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
|
||||
``shared_with_me`` list keys always present).
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document
|
||||
The user document as a dict.
|
||||
"""
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
|
||||
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
with db_session() as conn:
user_doc = UsersRepository(conn).upsert(user_id)

prefs = user_doc.get("agent_preferences") or {}
if not isinstance(prefs, dict):
prefs = {}
prefs.setdefault("pinned", [])
prefs.setdefault("shared_with_me", [])
user_doc["agent_preferences"] = prefs
return user_doc

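# Illustrative sketch, not from the DocsGPT source: the intended calling
# contract of ensure_user_doc(). Because the defaults are applied above,
# callers can index agent_preferences without further guards. The user id
# value is made up for the example.

user_doc = ensure_user_doc("auth0|123")
pinned = set(user_doc["agent_preferences"]["pinned"])             # always a list
shared_with_me = user_doc["agent_preferences"]["shared_with_me"]  # always a list
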
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their details.
|
||||
Resolve tool IDs to their display details.
|
||||
|
||||
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
|
||||
lists are supported — each id is looked up via ``get_any``, which
|
||||
resolves to whichever column matches). Unknown ids are silently
|
||||
skipped.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs
|
||||
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
|
||||
|
||||
Returns:
|
||||
List of tool details with id, name, and display_name
|
||||
List of tool details with ``id``, ``name``, and ``display_name``.
|
||||
"""
|
||||
valid_ids = []
|
||||
if not tool_ids:
|
||||
return []
|
||||
|
||||
uuid_ids: list[str] = []
|
||||
legacy_ids: list[str] = []
|
||||
for tid in tool_ids:
|
||||
try:
|
||||
valid_ids.append(ObjectId(tid))
|
||||
except Exception:
|
||||
if not tid:
|
||||
continue
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": valid_ids}}
|
||||
) if valid_ids else []
|
||||
tid_str = str(tid)
|
||||
if looks_like_uuid(tid_str):
|
||||
uuid_ids.append(tid_str)
|
||||
else:
|
||||
legacy_ids.append(tid_str)
|
||||
|
||||
if not uuid_ids and not legacy_ids:
|
||||
return []
|
||||
|
||||
rows: list[dict] = []
|
||||
with db_readonly() as conn:
|
||||
if uuid_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE id = ANY(CAST(:ids AS uuid[]))"
|
||||
),
|
||||
{"ids": uuid_ids},
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
if legacy_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE legacy_mongo_id = ANY(:ids)"
|
||||
),
|
||||
{"ids": legacy_ids},
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("customName")
|
||||
or tool.get("displayName")
|
||||
or tool.get("name", ""),
|
||||
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
|
||||
"name": tool.get("name", "") or "",
|
||||
"display_name": (
|
||||
tool.get("custom_name")
|
||||
or tool.get("display_name")
|
||||
or tool.get("name", "")
|
||||
or ""
|
||||
),
|
||||
}
|
||||
for tool in tools
|
||||
for tool in rows
|
||||
]
|
||||
|
||||
|
||||
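# Illustrative sketch, not from the DocsGPT source: resolve_tool_details()
# accepts a mixed list of Postgres UUIDs and legacy Mongo ObjectId strings;
# UUID-shaped ids are matched on user_tools.id, the rest on
# user_tools.legacy_mongo_id, and ids that match neither simply drop out.
# The ids below are invented for the example.

details = resolve_tool_details([
    "0b6aa967-9f0f-4a4e-9edb-bdb2b1f1c1aa",  # UUID -> matched on id
    "64f1c9a2e4b0a1a2b3c4d5e6",              # ObjectId -> matched on legacy_mongo_id
    "unknown-id",                            # no match -> silently skipped
])
# -> [{"id": "...", "name": "...", "display_name": "..."}, ...]
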
@@ -237,14 +213,15 @@ def require_agent(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
@@ -253,7 +230,7 @@ def require_agent(func):
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
kwargs["agent_id_str"] = str(agent["id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
@@ -34,21 +33,16 @@ class DeleteConversation(Resource):
|
||||
)
|
||||
user_id = decoded_token["sub"]
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id}
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is not None:
|
||||
repo.delete(str(conv["id"]), user_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_delete(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
repo.delete(conv["id"], user_id)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_delete)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
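# Illustrative sketch, not from the DocsGPT source: the dual_write() calls
# in the legacy branches of these handlers are treated here as "run this
# callback against a Postgres repository, and never let a mirror-write
# failure break the primary code path". The real helper lives in
# application/storage/db/dual_write.py and may behave differently; this is
# only the assumed shape.

def dual_write_sketch(repo_cls, fn):
    try:
        with db_session() as conn:
            fn(repo_cls(conn))
    except Exception:
        # Mirror writes are best-effort during the Mongo -> Postgres cutover.
        current_app.logger.exception("dual-write mirror failed")
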
@@ -63,17 +57,13 @@ class DeleteAllConversations(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
with db_session() as conn:
|
||||
ConversationsRepository(conn).delete_all_for_user(user_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
dual_write(
|
||||
ConversationsRepository,
|
||||
lambda r, uid=user_id: r.delete_all_for_user(uid),
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -86,26 +76,21 @@ class GetConversations(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
with db_readonly() as conn:
|
||||
conversations = ConversationsRepository(conn).list_for_user(
|
||||
user_id, limit=30
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"id": str(conversation["id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"])
|
||||
if conversation.get("agent_id")
|
||||
else None
|
||||
),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
@@ -134,38 +119,67 @@ class GetSingleConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conversation = repo.get_any(conversation_id, user_id)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
conv_pg_id = str(conversation["id"])
|
||||
messages = repo.get_messages(conv_pg_id)
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
# Resolve attachment details (id, fileName) for each message.
|
||||
attachments_repo = AttachmentsRepository(conn)
|
||||
queries = []
|
||||
for msg in messages:
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"model_id": msg.get("model_id"),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
query["metadata"] = msg["metadata"]
|
||||
# Feedback on conversation_messages is a JSONB blob with
|
||||
# shape {"text": <str>, "timestamp": <iso>}. The legacy
|
||||
# frontend consumed a flat scalar feedback string, so
|
||||
# unwrap the ``text`` field for compat.
|
||||
feedback = msg.get("feedback")
|
||||
if feedback is not None:
|
||||
if isinstance(feedback, dict):
|
||||
query["feedback"] = feedback.get("text")
|
||||
if feedback.get("timestamp"):
|
||||
query["feedback_timestamp"] = feedback["timestamp"]
|
||||
else:
|
||||
query["feedback"] = feedback
|
||||
attachments = msg.get("attachments") or []
|
||||
if attachments:
|
||||
attachment_details = []
|
||||
for attachment_id in attachments:
|
||||
try:
|
||||
att = attachments_repo.get_any(
|
||||
str(attachment_id), user_id
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
if att:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(att["id"]),
|
||||
"fileName": att.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
queries.append(query)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
@@ -173,7 +187,9 @@ class GetSingleConversation(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"]) if conversation.get("agent_id") else None
|
||||
),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
@@ -207,22 +223,16 @@ class UpdateConversationName(Resource):
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user_id},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(data["id"], user_id)
|
||||
if conv is not None:
|
||||
repo.rename(str(conv["id"]), user_id, data["name"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_rename(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["id"])
|
||||
if conv is not None:
|
||||
repo.rename(conv["id"], user_id, data["name"])
|
||||
|
||||
dual_write(ConversationsRepository, _pg_rename)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -260,61 +270,34 @@ class SubmitFeedback(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
# Normalize string feedback to lowercase so analytics queries
|
||||
# (which match 'like'/'dislike') count rows correctly. Tolerate
|
||||
# legacy uppercase clients on ingest. Non-string values pass through.
|
||||
if isinstance(feedback_value, str):
|
||||
feedback_value = feedback_value.lower()
|
||||
feedback_payload = (
|
||||
None
|
||||
if feedback_value is None
|
||||
else {
|
||||
"text": feedback_value,
|
||||
"timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
)
|
||||
try:
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(data["conversation_id"], user_id)
|
||||
if conv is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Not found"}), 404
|
||||
)
|
||||
repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
# Dual-write to Postgres: mirror the per-message feedback set/unset.
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
feedback_payload = (
|
||||
None if feedback_value is None
|
||||
else {"text": feedback_value, "timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat()}
|
||||
)
|
||||
|
||||
def _pg_feedback(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["conversation_id"])
|
||||
if conv is not None:
|
||||
repo.set_feedback(conv["id"], question_index, feedback_payload)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_feedback)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
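# Illustrative sketch, not from the DocsGPT source: the normalized feedback
# payload built in SubmitFeedback above is what ends up in
# conversation_messages.feedback, and lowercasing is what lets the
# feedback-analytics SQL earlier in this change match the literal strings
# 'like' / 'dislike'. Example values are assumptions.

feedback_payload = {
    "text": "like",  # "LIKE" from a legacy client is lowercased on ingest
    "timestamp": "2024-05-07T13:00:00+00:00",
}
# Postgres side (see GetFeedbackAnalytics):
#   m.feedback->>'text'                       -> 'like' or 'dislike'
#   (m.feedback->>'timestamp')::timestamptz   -> range filter + bucketing
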
@@ -2,14 +2,13 @@
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.api.user.base import current_dir
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
@@ -42,21 +41,9 @@ class CreatePrompt(Resource):
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
|
||||
u, n, c, legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
with db_session() as conn:
|
||||
prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
|
||||
new_id = str(prompt["id"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -72,17 +59,17 @@ class GetPrompts(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
with db_readonly() as conn:
|
||||
prompts = PromptsRepository(conn).list_for_user(user)
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["_id"]),
|
||||
"id": str(prompt["id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
@@ -127,9 +114,12 @@ class GetSinglePrompt(Resource):
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
prompt = PromptsRepository(conn).get_any(prompt_id, user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}), 404
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -156,11 +146,15 @@ class DeletePrompt(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt = repo.get_any(data["id"], user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}),
|
||||
404,
|
||||
)
|
||||
repo.delete(str(prompt["id"]), user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -193,16 +187,15 @@ class UpdatePrompt(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
|
||||
pid, u, n, c,
|
||||
),
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt = repo.get_any(data["id"], user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(str(prompt["id"]), user, data["name"], data["content"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
@@ -2,89 +2,126 @@

import uuid

from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, inputs, Namespace, Resource
from sqlalchemy import text as _sql_text

from application.api import api
from application.api.user.base import (
agents_collection,
attachments_collection,
conversations_collection,
shared_conversations_collections,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import (
SharedConversationsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields


def _dual_write_share(
mongo_conv_id: str,
share_uuid: str,
user: str,
*,
is_promptable: bool,
first_n_queries: int,
api_key: str | None,
prompt_id: str | None = None,
chunks: int | None = None,
) -> None:
"""Mirror a Mongo share-record insert into Postgres.

Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
resolve from both stores during cutover.
"""
def _write(repo: SharedConversationsRepository) -> None:
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
mongo_conv_id, user_id=user,
)
if conv is None:
return
# prompt_id / chunks are only meaningful for promptable shares;
# prompt_id is often the string "default" or an ObjectId that
# hasn't been migrated — pass as-is and let the repo drop
# non-UUID values. Scope the prompt lookup by user_id so an
# authenticated caller can't link another user's prompt into
# their share record.
resolved_prompt_id = None
if prompt_id and len(str(prompt_id)) == 24:
from sqlalchemy import text as _text
row = repo._conn.execute(
_text(
"SELECT id FROM prompts "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": str(prompt_id), "user_id": user},
).fetchone()
if row:
resolved_prompt_id = str(row[0])
# get_or_create is race-free on the PG side thanks to the
# composite partial unique index on the dedup tuple
# (migration 0008). It converges concurrent share requests to
# a single row.
repo.get_or_create(
conv["id"],
user,
is_promptable=is_promptable,
first_n_queries=first_n_queries,
api_key=api_key,
prompt_id=resolved_prompt_id,
chunks=chunks,
share_uuid=share_uuid,
)

dual_write(SharedConversationsRepository, _write)

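A note on the dedup index the comment above relies on: migration 0008 itself is not part of this diff, so the exact column list is not visible here. Purely as an illustration of how a composite partial unique index lets get_or_create converge racing share requests onto one row, a sketch with assumed column names could look like this (not the project's actual migration):

from sqlalchemy import text

# Assumed columns; the real dedup tuple is defined in migration 0008.
ILLUSTRATIVE_DEDUP_INDEX = text(
    """
    CREATE UNIQUE INDEX IF NOT EXISTS uq_shared_conversations_dedup
    ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries)
    WHERE api_key IS NOT NULL
    """
)

def apply_dedup_index(conn) -> None:
    # With the index in place, the second of two racing inserts hits the
    # uniqueness violation and can simply re-select the surviving row.
    conn.execute(ILLUSTRATIVE_DEDUP_INDEX)
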
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)


def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.

Scoped by ``user_id`` so a caller can't link another user's prompt
into their share record. Returns ``None`` for sentinel values
(``"default"``) or unresolved ids.
"""
if not prompt_id_raw or prompt_id_raw == "default":
return None
value = str(prompt_id_raw)
# Already UUID — trust it but still require ownership. A shape-gate
# (rather than a loose ``len == 36 and '-' in value`` check) keeps
# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
# and poison the readonly transaction otherwise.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None
# Legacy Mongo ObjectId fallback.
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None


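Both resolvers here gate on looks_like_uuid before feeding a value into CAST(... AS uuid). That helper is imported from application.storage.db.base_repository and its body is not shown in this diff; a minimal stand-in that captures the same idea (illustrative only, not the repository's implementation) would be:

import uuid

def _looks_like_uuid_sketch(value) -> bool:
    # Accept only values that parse as a UUID, so malformed input never
    # reaches the CAST and cannot poison a read-only transaction.
    try:
        uuid.UUID(str(value))
        return True
    except (TypeError, ValueError):
        return False
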
def _resolve_source_pg_id(conn, source_raw):
"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
if not source_raw:
return None
value = str(source_raw)
# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
),
{"sid": value},
).fetchone()
return str(row[0]) if row else None
row = conn.execute(
_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
{"sid": value},
).fetchone()
return str(row[0]) if row else None


def _find_reusable_share_agent(
|
||||
conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
|
||||
):
|
||||
"""Find an existing share-as-agent key row matching these parameters.
|
||||
|
||||
Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
|
||||
check. Used to reuse an api key across repeated shares of the same
|
||||
conversation with the same prompt/chunks/source/retriever.
|
||||
"""
|
||||
clauses = ["user_id = :uid", "key IS NOT NULL"]
|
||||
params: dict = {"uid": user_id}
|
||||
if prompt_pg_id is None:
|
||||
clauses.append("prompt_id IS NULL")
|
||||
else:
|
||||
clauses.append("prompt_id = CAST(:pid AS uuid)")
|
||||
params["pid"] = prompt_pg_id
|
||||
if chunks is None:
|
||||
clauses.append("chunks IS NULL")
|
||||
else:
|
||||
clauses.append("chunks = :chunks")
|
||||
params["chunks"] = int(chunks)
|
||||
if source_pg_id is None:
|
||||
clauses.append("source_id IS NULL")
|
||||
else:
|
||||
clauses.append("source_id = CAST(:sid AS uuid)")
|
||||
params["sid"] = source_pg_id
|
||||
if retriever is None:
|
||||
clauses.append("retriever IS NULL")
|
||||
else:
|
||||
clauses.append("retriever = :retr")
|
||||
params["retr"] = retriever
|
||||
sql = (
|
||||
"SELECT * FROM agents WHERE "
|
||||
+ " AND ".join(clauses)
|
||||
+ " LIMIT 1"
|
||||
)
|
||||
row = conn.execute(_sql_text(sql), params).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
mapping = dict(row._mapping)
|
||||
mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
|
||||
return mapping
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
@@ -119,173 +156,93 @@ class ShareConversation(Resource):
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
current_n_queries = len(conversation["queries"])
|
||||
explicit_binary = Binary.from_uuid(
|
||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||
)
|
||||
with db_session() as conn:
|
||||
conv_repo = ConversationsRepository(conn)
|
||||
shared_repo = SharedConversationsRepository(conn)
|
||||
agents_repo = AgentsRepository(conn)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = data.get("chunks", "2")
|
||||
|
||||
name = conversation["name"] + "(shared)"
|
||||
new_api_key_data = {
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||
if pre_existing_api_document:
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
conversation = conv_repo.get_any(conversation_id, user)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
new_api_key_data["key"] = api_uuid
|
||||
new_api_key_data["name"] = name
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
agents_collection.insert_one(new_api_key_data)
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
conv_pg_id = str(conversation["id"])
|
||||
current_n_queries = conv_repo.message_count(conv_pg_id)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id_raw = data.get("prompt_id", "default")
|
||||
chunks_raw = data.get("chunks", "2")
|
||||
try:
|
||||
chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
|
||||
except (TypeError, ValueError):
|
||||
chunks_int = None
|
||||
|
||||
prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
|
||||
source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
|
||||
retriever = data.get("retriever")
|
||||
|
||||
reusable = _find_reusable_share_agent(
|
||||
conn, user,
|
||||
prompt_pg_id=prompt_pg_id,
|
||||
chunks=chunks_int,
|
||||
source_pg_id=source_pg_id,
|
||||
retriever=retriever,
|
||||
)
|
||||
if reusable:
|
||||
api_uuid = reusable.get("key")
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
name = (conversation.get("name") or "") + "(shared)"
|
||||
agents_repo.create(
|
||||
user,
|
||||
name,
|
||||
"published",
|
||||
key=api_uuid,
|
||||
retriever=retriever,
|
||||
chunks=chunks_int,
|
||||
prompt_id=prompt_pg_id,
|
||||
source_id=source_pg_id,
|
||||
)
|
||||
|
||||
share = shared_repo.get_or_create(
|
||||
conv_pg_id,
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
is_promptable=True,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
prompt_id=prompt_pg_id,
|
||||
chunks=chunks_int,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
"identifier": str(share["uuid"]),
|
||||
}
|
||||
),
|
||||
201,
|
||||
201 if reusable is None else 200,
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
|
||||
# Non-promptable share path.
|
||||
share = shared_repo.get_or_create(
|
||||
conv_pg_id,
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
is_promptable=False,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(share["uuid"]),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
@@ -301,37 +258,13 @@ class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
try:
|
||||
query_uuid = Binary.from_uuid(
|
||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||
)
|
||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||
conversation_queries = []
|
||||
with db_readonly() as conn:
|
||||
shared_repo = SharedConversationsRepository(conn)
|
||||
conv_repo = ConversationsRepository(conn)
|
||||
attach_repo = AttachmentsRepository(conn)
|
||||
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
):
|
||||
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
|
||||
conversation_id = shared["conversation_id"]
|
||||
if isinstance(conversation_id, DBRef):
|
||||
conversation_id = conversation_id.id
|
||||
elif isinstance(conversation_id, dict):
|
||||
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
|
||||
if "$id" in conversation_id:
|
||||
conv_id = conversation_id["$id"]
|
||||
# $id might be a dict like {"$oid": "..."} or a string
|
||||
if isinstance(conv_id, dict) and "$oid" in conv_id:
|
||||
conversation_id = ObjectId(conv_id["$oid"])
|
||||
else:
|
||||
conversation_id = ObjectId(conv_id)
|
||||
elif "_id" in conversation_id:
|
||||
conversation_id = ObjectId(conversation_id["_id"])
|
||||
elif isinstance(conversation_id, str):
|
||||
conversation_id = ObjectId(conversation_id)
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": conversation_id}
|
||||
)
|
||||
if conversation is None:
|
||||
shared = shared_repo.find_by_uuid(identifier)
|
||||
if not shared or not shared.get("conversation_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -341,22 +274,60 @@ class GetPubliclySharedConversations(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation_queries = conversation["queries"][
|
||||
: (shared["first_n_queries"])
|
||||
]
|
||||
conv_pg_id = str(shared["conversation_id"])
|
||||
owner_user = shared.get("user_id")
|
||||
|
||||
for query in conversation_queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
|
||||
if conversation is None:
|
||||
# Fall back to any-user lookup in case shared row's
|
||||
# user_id is missing — still keyed by PG UUID.
|
||||
row = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conv_pg_id},
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation = dict(row._mapping)
|
||||
|
||||
messages = conv_repo.get_messages(conv_pg_id)
|
||||
first_n = shared.get("first_n_queries") or 0
|
||||
conversation_queries = []
|
||||
for msg in messages[:first_n]:
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": (
|
||||
msg["timestamp"].isoformat()
|
||||
if hasattr(msg.get("timestamp"), "isoformat")
|
||||
else msg.get("timestamp")
|
||||
),
|
||||
"feedback": msg.get("feedback"),
|
||||
}
|
||||
attachments = msg.get("attachments") or []
|
||||
if attachments:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
for attachment_id in attachments:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
attachment = attach_repo.get_any(
|
||||
str(attachment_id), owner_user,
|
||||
) if owner_user else None
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"id": str(attachment["id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
@@ -368,26 +339,23 @@ class GetPubliclySharedConversations(Resource):
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
conversation_queries.append(query)
|
||||
|
||||
created = conversation.get("created_at") or conversation.get("date")
|
||||
date_iso = (
|
||||
created.isoformat()
|
||||
if hasattr(created, "isoformat")
|
||||
else (str(created) if created is not None else None)
|
||||
)
|
||||
date = conversation["_id"].generation_time.isoformat()
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation["name"],
|
||||
"timestamp": date,
|
||||
}
|
||||
if shared["isPromptable"] and "api_key" in shared:
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation.get("name"),
|
||||
"timestamp": date_iso,
|
||||
}
|
||||
if shared.get("is_promptable") and shared.get("api_key"):
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting shared conversation: {err}", exc_info=True
|
||||
|
||||
@@ -1,11 +1,12 @@
"""Source document management chunk management."""

from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource

from application.api import api
from application.api.user.base import get_vector_store, sources_collection
from application.api.user.base import get_vector_store
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly
from application.utils import check_required_fields, num_tokens_from_string

sources_chunks_ns = Namespace(
@@ -13,6 +14,15 @@ sources_chunks_ns = Namespace(
)


def _resolve_source(doc_id: str, user: str):
"""Resolve a source (UUID or legacy ObjectId) for the caller.

Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
"""
with db_readonly() as conn:
return SourcesRepository(conn).get_any(doc_id, user)


@sources_chunks_ns.route("/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
@@ -36,36 +46,34 @@ class GetChunks(Resource):
|
||||
path = request.args.get("path")
|
||||
search_term = request.args.get("search", "").strip().lower()
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
resolved_id = str(doc["id"])
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
store = get_vector_store(resolved_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
chunk_file_path = metadata.get("file_path", "")
|
||||
# Check if the chunk matches the requested path
|
||||
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
|
||||
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
|
||||
source_match = chunk_source and chunk_source.endswith(path)
|
||||
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
|
||||
|
||||
if not (source_match or file_path_match):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
@@ -132,15 +140,17 @@ class AddChunk(Resource):
|
||||
token_count = num_tokens_from_string(text)
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||
@@ -165,15 +175,17 @@ class DeleteChunk(Resource):
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if deleted:
|
||||
return make_response(
|
||||
@@ -232,15 +244,17 @@ class UpdateChunk(Resource):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
|
||||
chunks = store.get_chunks()
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
import json
|
||||
import math
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
@@ -56,11 +56,20 @@ class CombinedJson(Resource):
|
||||
]
|
||||
|
||||
try:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
with db_readonly() as conn:
|
||||
indexes = SourcesRepository(conn).list_for_user(user)
|
||||
# list_for_user sorts by created_at DESC; legacy shape sorted by
|
||||
# "date" DESC. Both are monotonic on creation so the ordering is
|
||||
# equivalent for dev; re-sort defensively.
|
||||
indexes = sorted(
|
||||
indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
|
||||
reverse=True,
|
||||
)
|
||||
for index in indexes:
|
||||
provider = _get_provider_from_remote_data(index.get("remote_data"))
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["_id"]),
|
||||
"id": str(index["id"]),
|
||||
"name": index.get("name"),
|
||||
"date": index.get("date"),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
@@ -70,9 +79,7 @@ class CombinedJson(Resource):
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
"type": index.get("type", "file"),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -89,61 +96,55 @@ class PaginatedSources(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare query for filtering
|
||||
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(
|
||||
max(1, page), total_pages
|
||||
) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
sort_field = request.args.get("sort", "date")
|
||||
sort_order = request.args.get("order", "desc")
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
rows_per_page = max(1, int(request.args.get("rows", 10)))
|
||||
search_term = request.args.get("search", "").strip() or None
|
||||
|
||||
try:
|
||||
documents = (
|
||||
sources_collection.find(query)
|
||||
.sort(sort_field, sort_order)
|
||||
.skip(skip)
|
||||
.limit(rows_per_page)
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
total_documents = repo.count_for_user(
|
||||
user, search_term=search_term,
|
||||
)
|
||||
# Prior in-Python implementation returned ``totalPages = 1``
|
||||
# for empty result sets (``max(1, ceil(0/rows))``); we
|
||||
# preserve that contract so the frontend pager stays stable.
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
effective_page = min(page, total_pages)
|
||||
offset = (effective_page - 1) * rows_per_page
|
||||
window = repo.list_for_user(
|
||||
user,
|
||||
limit=rows_per_page,
|
||||
offset=offset,
|
||||
search_term=search_term,
|
||||
sort_field=sort_field,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginated_docs = []
|
||||
for doc in documents:
|
||||
for doc in window:
|
||||
provider = _get_provider_from_remote_data(doc.get("remote_data"))
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
paginated_docs.append(
|
||||
{
|
||||
"id": str(doc["id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
)
|
||||
response = {
|
||||
"total": total_documents,
|
||||
"totalPages": total_pages,
|
||||
"currentPage": page,
|
||||
"currentPage": effective_page,
|
||||
"paginated": paginated_docs,
|
||||
}
|
||||
return make_response(jsonify(response), 200)
|
||||
@@ -154,28 +155,6 @@ class PaginatedSources(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
params={"path": "Comma-separated list of IDs"},
|
||||
)
|
||||
def get(self):
|
||||
ids = request.args.get("path")
|
||||
if not ids:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_index(ids=ids)
|
||||
if result:
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
@@ -186,30 +165,33 @@ class DeleteOldIndexes(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
storage = StorageCreator.get_storage()
|
||||
resolved_id = str(doc["id"])
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
index_path = f"indexes/{resolved_id}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
storage.delete_file(f"{index_path}/index.faiss")
|
||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||
storage.delete_file(f"{index_path}/index.pkl")
|
||||
else:
|
||||
vectorstore = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
settings.VECTOR_STORE, source_id=resolved_id
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
@@ -227,7 +209,14 @@ class DeleteOldIndexes(Resource):
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
try:
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).delete(resolved_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting source row: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -272,15 +261,16 @@ class ManageSync(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||
)
|
||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||
try:
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
doc = repo.get_any(source_id, user)
|
||||
if doc is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating sync frequency: {err}", exc_info=True
|
||||
@@ -309,19 +299,20 @@ class SyncSource(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
if not ObjectId.is_valid(source_id):
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
source_type = doc.get("type", "")
|
||||
if source_type.startswith("connector"):
|
||||
if source_type and source_type.startswith("connector"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -344,7 +335,7 @@ class SyncSource(Resource):
|
||||
loader=source_type,
|
||||
sync_frequency=doc.get("sync_frequency", "never"),
|
||||
retriever=doc.get("retriever", "classic"),
|
||||
doc_id=source_id,
|
||||
doc_id=str(doc["id"]),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
@@ -370,10 +361,9 @@ class DirectoryStructure(Resource):
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
try:
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(doc_id, user)
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
@@ -387,6 +377,8 @@ class DirectoryStructure(Resource):
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
elif isinstance(remote_data, dict):
|
||||
provider = remote_data.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
@@ -406,4 +398,7 @@ class DirectoryStructure(Resource):
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
|
||||
500,
|
||||
)
|
||||
|
||||
@@ -5,16 +5,16 @@ import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.stt.upload_limits import (
|
||||
AudioFileTooLargeError,
|
||||
@@ -329,15 +329,8 @@ class ManageSourceFiles(Resource):
|
||||
400,
|
||||
)
|
||||
try:
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -353,6 +346,7 @@ class ManageSourceFiles(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
resolved_source_id = str(source["id"])
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -411,15 +405,18 @@ class ManageSourceFiles(Resource):
|
||||
map_updated = True
|
||||
|
||||
if map_updated:
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -485,15 +482,18 @@ class ManageSourceFiles(Resource):
|
||||
map_updated = True
|
||||
|
||||
if map_updated and isinstance(file_name_map, dict):
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -581,16 +581,19 @@ class ManageSourceFiles(Resource):
|
||||
if keys_to_remove:
|
||||
for key in keys_to_remove:
|
||||
file_name_map.pop(key, None)
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -3,27 +3,24 @@
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.api.user.tools.routes import transform_actions
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
_mongo = MongoDB.get_client()
|
||||
_db = _mongo[settings.MONGO_DB_NAME]
|
||||
_connector_sessions = _db["connector_sessions"]
|
||||
|
||||
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
|
||||
|
||||
|
||||
@@ -252,15 +249,18 @@ class MCPServerSave(Resource):
|
||||
storage_config = config.copy()
|
||||
|
||||
tool_id = data.get("id")
|
||||
existing_doc = None
|
||||
existing_encrypted = None
|
||||
if tool_id:
|
||||
existing_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
|
||||
)
|
||||
if existing_doc:
|
||||
existing_encrypted = existing_doc.get("config", {}).get(
|
||||
with db_readonly() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
existing_doc = repo.get_any(tool_id, user)
|
||||
if existing_doc and existing_doc.get("name") == "mcp_tool":
|
||||
existing_encrypted = (existing_doc.get("config") or {}).get(
|
||||
"encrypted_credentials"
|
||||
)
|
||||
else:
|
||||
existing_doc = None
|
||||
|
||||
if auth_credentials:
|
||||
if existing_encrypted:
|
||||
@@ -283,47 +283,88 @@ class MCPServerSave(Resource):
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = transform_actions(actions_metadata)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"customName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
display_name = data["displayName"]
|
||||
description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
|
||||
status_bool = bool(data.get("status", True))
|
||||
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
if existing_doc:
|
||||
repo.update(
|
||||
str(existing_doc["id"]), user,
|
||||
{
|
||||
"display_name": display_name,
|
||||
"custom_name": display_name,
|
||||
"description": description,
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": status_bool,
|
||||
},
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
saved_id = str(existing_doc["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
# Fall back to find_by_user_and_name — the original
|
||||
# dual-write path also ran an existence check before
|
||||
# deciding between insert and update.
|
||||
existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
|
||||
if tool_id is None and existing_by_name and (
|
||||
(existing_by_name.get("config") or {}).get("server_url")
|
||||
== storage_config.get("server_url")
|
||||
):
|
||||
repo.update(
|
||||
str(existing_by_name["id"]), user,
|
||||
{
|
||||
"display_name": display_name,
|
||||
"custom_name": display_name,
|
||||
"description": description,
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": status_bool,
|
||||
},
|
||||
)
|
||||
saved_id = str(existing_by_name["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
created = repo.create(
|
||||
user, "mcp_tool",
|
||||
config=storage_config,
|
||||
custom_name=display_name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
config_requirements={},
|
||||
actions=transformed_actions,
|
||||
status=status_bool,
|
||||
)
|
||||
saved_id = str(created["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
if tool_id and existing_doc is None:
|
||||
# Client requested update on a non-existent tool id.
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except ValueError as e:
|
||||
current_app.logger.warning(f"Invalid MCP server save request: {e}")
|
||||
@@ -459,49 +500,59 @@ class MCPAuthStatus(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
mcp_tools = list(
|
||||
user_tools_collection.find(
|
||||
{"user": user, "name": "mcp_tool"},
|
||||
{"_id": 1, "config": 1},
|
||||
)
|
||||
)
|
||||
if not mcp_tools:
|
||||
return make_response(jsonify({"success": True, "statuses": {}}), 200)
|
||||
|
||||
oauth_server_urls = {}
|
||||
statuses = {}
|
||||
for tool in mcp_tools:
|
||||
tool_id = str(tool["_id"])
|
||||
config = tool.get("config", {})
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "oauth":
|
||||
server_url = config.get("server_url", "")
|
||||
if server_url:
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
oauth_server_urls[tool_id] = base_url
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
else:
|
||||
statuses[tool_id] = "configured"
|
||||
|
||||
if oauth_server_urls:
|
||||
unique_urls = list(set(oauth_server_urls.values()))
|
||||
sessions = list(
|
||||
_connector_sessions.find(
|
||||
{"user_id": user, "server_url": {"$in": unique_urls}},
|
||||
{"server_url": 1, "tokens": 1},
|
||||
with db_readonly() as conn:
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
sessions_repo = ConnectorSessionsRepository(conn)
|
||||
all_tools = tools_repo.list_for_user(user)
|
||||
mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
|
||||
if not mcp_tools:
|
||||
return make_response(
|
||||
jsonify({"success": True, "statuses": {}}), 200
|
||||
)
|
||||
)
|
||||
url_has_tokens = {
|
||||
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
|
||||
for doc in sessions
|
||||
}
|
||||
for tool_id, base_url in oauth_server_urls.items():
|
||||
if url_has_tokens.get(base_url):
|
||||
statuses[tool_id] = "connected"
|
||||
|
||||
oauth_server_urls: dict = {}
|
||||
statuses: dict = {}
|
||||
for tool in mcp_tools:
|
||||
tool_id = str(tool["id"])
|
||||
config = tool.get("config") or {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "oauth":
|
||||
server_url = config.get("server_url", "")
|
||||
if server_url:
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
oauth_server_urls[tool_id] = base_url
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
statuses[tool_id] = "configured"
|
||||
|
||||
if oauth_server_urls:
|
||||
# Look up a session per distinct base URL. MCP sessions
|
||||
# are stored with ``provider = "mcp:<server_url>"``
|
||||
# and the URL in ``server_url``; reuse the repo's
|
||||
# per-URL accessor rather than an ad-hoc $in query.
|
||||
url_has_tokens: dict = {}
|
||||
for base_url in set(oauth_server_urls.values()):
|
||||
session = sessions_repo.get_by_user_and_server_url(
|
||||
user, base_url,
|
||||
)
|
||||
tokens = (
|
||||
(session or {}).get("session_data", {}) or {}
|
||||
).get("tokens", {}) or {}
|
||||
# MCP code also stashes tokens into token_info on
|
||||
# the row; consider either present as "connected".
|
||||
token_info = (session or {}).get("token_info") or {}
|
||||
url_has_tokens[base_url] = bool(
|
||||
tokens.get("access_token")
|
||||
or token_info.get("access_token")
|
||||
)
|
||||
|
||||
for tool_id, base_url in oauth_server_urls.items():
|
||||
if url_has_tokens.get(base_url):
|
||||
statuses[tool_id] = "connected"
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
|
||||
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,23 +1,59 @@
"""Tool management routes."""

from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource

from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields, validate_function_name

tool_config = {}
tool_manager = ToolManager(config=tool_config)


# ---------------------------------------------------------------------------
# Shape translation helpers
# ---------------------------------------------------------------------------
# The frontend speaks camelCase (``displayName`` / ``customName`` /
# ``configRequirements``). The PG ``user_tools`` table stores snake_case
# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
# translation localized to this module so repositories stay pure.

_CAMEL_TO_SNAKE = {
"displayName": "display_name",
"customName": "custom_name",
"configRequirements": "config_requirements",
}
_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}


def _row_to_api(row: dict) -> dict:
"""Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
out = dict(row)
for snake, camel in _SNAKE_TO_CAMEL.items():
if snake in out:
out[camel] = out.pop(snake)
# ``user_id`` is exposed as ``user`` in the legacy API shape.
if "user_id" in out:
out["user"] = out.pop("user_id")
return out


def _api_to_update_fields(data: dict) -> dict:
"""Rename incoming camelCase update keys to the repo's snake_case columns."""
fields_out: dict = {}
for key, value in data.items():
fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
return fields_out


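A quick round trip through the two translation helpers above, with made-up values, shows the shape conversion in both directions:

row = {
    "id": "9b2f...",  # hypothetical PG UUID, truncated for the example
    "user_id": "alice",
    "display_name": "GitHub Tool",
    "custom_name": "gh",
    "config_requirements": {},
}
assert _row_to_api(row)["displayName"] == "GitHub Tool"
assert _row_to_api(row)["user"] == "alice"
assert _api_to_update_fields({"customName": "gh2", "status": True}) == {
    "custom_name": "gh2",
    "status": True,
}
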
def _encrypt_secret_fields(config, config_requirements, user_id):
|
||||
secret_keys = [
|
||||
key for key, spec in config_requirements.items()
|
||||
@@ -170,12 +206,11 @@ class GetTools(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
with db_readonly() as conn:
|
||||
rows = UserToolsRepository(conn).list_for_user(user)
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
tool_copy = {**tool}
|
||||
tool_copy["id"] = str(tool["_id"])
|
||||
tool_copy.pop("_id", None)
|
||||
for row in rows:
|
||||
tool_copy = _row_to_api(row)
|
||||
|
||||
config_req = tool_copy.get("configRequirements", {})
|
||||
if not config_req:
|
||||
@@ -283,26 +318,19 @@ class CreateTool(Resource):
|
||||
storage_config = _encrypt_secret_fields(
|
||||
data["config"], config_requirements, user
|
||||
)
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": storage_config,
|
||||
"configRequirements": config_requirements,
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
UserToolsRepository,
|
||||
lambda repo, u=user, t=new_tool: repo.create(
|
||||
u, t["name"], config=t.get("config"),
|
||||
custom_name=t.get("customName"), display_name=t.get("displayName"),
|
||||
),
|
||||
)
|
||||
with db_session() as conn:
|
||||
created = UserToolsRepository(conn).create(
|
||||
user,
|
||||
data["name"],
|
||||
config=storage_config,
|
||||
custom_name=data.get("customName", ""),
|
||||
display_name=data["displayName"],
|
||||
description=data["description"],
|
||||
config_requirements=config_requirements,
|
||||
actions=transformed_actions,
|
||||
status=bool(data.get("status", True)),
|
||||
)
|
||||
new_id = str(created["id"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -340,17 +368,10 @@ class UpdateTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "customName" in data:
|
||||
update_data["customName"] = data["customName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
update_data: dict = {}
|
||||
for key in ("name", "displayName", "customName", "description", "actions"):
|
||||
if key in data:
|
||||
update_data[key] = data[key]
|
||||
if "config" in data:
|
||||
if "actions" in data["config"]:
|
||||
for action_name in list(data["config"]["actions"].keys()):
|
||||
@@ -365,46 +386,61 @@ class UpdateTool(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
tool_name = tool_doc.get("name", data.get("name"))
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {})
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
if validation_errors:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
tool_name = tool_doc.get("name", data.get("name"))
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements()
|
||||
if tool_instance
|
||||
else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {}) or {}
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
update_data["config"] = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
|
||||
update_data["config"] = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
if "status" in data:
|
||||
update_data["status"] = bool(data["status"])
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||
)
|
||||
else:
|
||||
if "status" in data:
|
||||
update_data["status"] = bool(data["status"])
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -436,53 +472,50 @@ class UpdateToolConfig(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if not tool_doc:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
|
||||
tool_name = tool_doc.get("name")
|
||||
if tool_name == "mcp_tool":
|
||||
server_url = (data["config"].get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
tool_name = tool_doc.get("name")
|
||||
if tool_name == "mcp_tool":
|
||||
server_url = (data["config"].get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {}) or {}
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {})
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
final_config = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
|
||||
final_config = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": final_config}},
|
||||
)
|
||||
repo.update(str(tool_doc["id"]), user, {"config": final_config})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool config: {err}", exc_info=True
|
||||
@@ -518,10 +551,17 @@ class UpdateToolActions(Resource):
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"actions": data["actions"]},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool actions: {err}", exc_info=True
@@ -555,10 +595,17 @@ class UpdateToolStatus(Resource):
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"status": bool(data["status"])},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool status: {err}", exc_info=True
@@ -587,17 +634,14 @@ class DeleteTool(Resource):
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
dual_write(
UserToolsRepository,
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
repo.delete(str(tool_doc["id"]), user)
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
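Each write handler above resolves the incoming id through `repo.get_any`, which has to accept both a Postgres UUID and a legacy Mongo ObjectId string left over from pre-cutover data. A minimal sketch of that branching, assuming `get` and `get_by_legacy_id` lookups on the repository (the real logic lives in the repository layer and may differ):

```python
import re

_UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.IGNORECASE
)
_OBJECT_ID_RE = re.compile(r"^[0-9a-f]{24}$", re.IGNORECASE)


def get_any_sketch(repo, resource_id: str, user: str):
    """Resolve by UUID first, then fall back to the legacy Mongo id column."""
    if resource_id and _UUID_RE.match(resource_id):
        return repo.get(resource_id, user)               # matches the `id` UUID column
    if resource_id and _OBJECT_ID_RE.match(resource_id):
        return repo.get_by_legacy_id(resource_id, user)  # matches `legacy_mongo_id`
    return None
```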
@@ -666,70 +710,88 @@ class GetArtifact(Resource):
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(artifact_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
|
||||
with db_readonly() as conn:
|
||||
notes_repo = NotesRepository(conn)
|
||||
todos_repo = TodosRepository(conn)
|
||||
|
||||
# Artifact IDs may be PG UUIDs (post-cutover) or legacy
|
||||
# Mongo ObjectIds embedded in older conversation history.
|
||||
# Both repos' ``get_any`` handles the id-shape branching
|
||||
# internally so a non-UUID input never reaches
|
||||
# ``CAST(:id AS uuid)`` (which would poison the readonly
|
||||
# transaction and break the fallback below).
|
||||
note_doc = notes_repo.get_any(artifact_id, user_id)
|
||||
|
||||
if note_doc:
|
||||
content = note_doc.get("note", "") or note_doc.get("content", "")
|
||||
line_count = len(content.split("\n")) if content else 0
|
||||
updated = note_doc.get("updated_at")
|
||||
artifact = {
|
||||
"artifact_type": "note",
|
||||
"data": {
|
||||
"content": content,
|
||||
"line_count": line_count,
|
||||
"updated_at": (
|
||||
updated.isoformat()
|
||||
if hasattr(updated, "isoformat")
|
||||
else updated
|
||||
),
|
||||
},
|
||||
}
|
||||
return make_response(
|
||||
jsonify({"success": True, "artifact": artifact}), 200
|
||||
)
|
||||
|
||||
todo_doc = todos_repo.get_any(artifact_id, user_id)
|
||||
if todo_doc:
|
||||
tool_id = todo_doc.get("tool_id")
|
||||
all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
|
||||
items = []
|
||||
open_count = 0
|
||||
completed_count = 0
|
||||
for t in all_todos:
|
||||
# PG ``todos`` stores a ``completed BOOLEAN`` column;
|
||||
# the legacy Mongo shape used a ``status`` string.
|
||||
# Keep the response shape stable by translating here.
|
||||
status = "completed" if t.get("completed") else "open"
|
||||
if status == "open":
|
||||
open_count += 1
|
||||
else:
|
||||
completed_count += 1
|
||||
created = t.get("created_at")
|
||||
updated = t.get("updated_at")
|
||||
items.append({
|
||||
"todo_id": t.get("todo_id"),
|
||||
"title": t.get("title", ""),
|
||||
"status": status,
|
||||
"created_at": (
|
||||
created.isoformat()
|
||||
if hasattr(created, "isoformat")
|
||||
else created
|
||||
),
|
||||
"updated_at": (
|
||||
updated.isoformat()
|
||||
if hasattr(updated, "isoformat")
|
||||
else updated
|
||||
),
|
||||
})
|
||||
artifact = {
|
||||
"artifact_type": "todo_list",
|
||||
"data": {
|
||||
"items": items,
|
||||
"total_count": len(items),
|
||||
"open_count": open_count,
|
||||
"completed_count": completed_count,
|
||||
},
|
||||
}
|
||||
return make_response(
|
||||
jsonify({"success": True, "artifact": artifact}), 200
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving artifact: {err}", exc_info=True
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
|
||||
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
|
||||
if note_doc:
|
||||
content = note_doc.get("note", "")
|
||||
line_count = len(content.split("\n")) if content else 0
|
||||
artifact = {
|
||||
"artifact_type": "note",
|
||||
"data": {
|
||||
"content": content,
|
||||
"line_count": line_count,
|
||||
"updated_at": (
|
||||
note_doc["updated_at"].isoformat()
|
||||
if note_doc.get("updated_at")
|
||||
else None
|
||||
),
|
||||
},
|
||||
}
|
||||
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
|
||||
|
||||
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
|
||||
if todo_doc:
|
||||
tool_id = todo_doc.get("tool_id")
|
||||
query = {"user_id": user_id, "tool_id": tool_id}
|
||||
all_todos = list(db["todos"].find(query))
|
||||
items = []
|
||||
open_count = 0
|
||||
completed_count = 0
|
||||
for t in all_todos:
|
||||
status = t.get("status", "open")
|
||||
if status == "open":
|
||||
open_count += 1
|
||||
elif status == "completed":
|
||||
completed_count += 1
|
||||
items.append({
|
||||
"todo_id": t.get("todo_id"),
|
||||
"title": t.get("title", ""),
|
||||
"status": status,
|
||||
"created_at": (
|
||||
t["created_at"].isoformat() if t.get("created_at") else None
|
||||
),
|
||||
"updated_at": (
|
||||
t["updated_at"].isoformat() if t.get("updated_at") else None
|
||||
),
|
||||
})
|
||||
artifact = {
|
||||
"artifact_type": "todo_list",
|
||||
"data": {
|
||||
"items": items,
|
||||
"total_count": len(items),
|
||||
"open_count": open_count,
|
||||
"completed_count": completed_count,
|
||||
},
|
||||
}
|
||||
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Artifact not found"}), 404
|
||||
|
||||
@@ -1,290 +1,61 @@
|
||||
"""Centralized utilities for API routes."""
|
||||
"""Centralized utilities for API routes.
|
||||
|
||||
Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
|
||||
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
|
||||
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
|
||||
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
from bson.errors import InvalidId
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Response,
|
||||
current_app,
|
||||
has_app_context,
|
||||
jsonify,
|
||||
make_response,
|
||||
request,
|
||||
)
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
def get_user_id() -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from decoded JWT token.
|
||||
|
||||
Returns:
|
||||
User ID string or None if not authenticated
|
||||
"""
|
||||
"""Extract user ID from decoded JWT token, or None if unauthenticated."""
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
return decoded_token.get("sub") if decoded_token else None
|
||||
|
||||
|
||||
def require_auth(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require authentication for route handlers.
|
||||
|
||||
Usage:
|
||||
@require_auth
|
||||
def get(self):
|
||||
user_id = get_user_id()
|
||||
...
|
||||
"""
|
||||
"""Decorator to require authentication. Returns 401 when absent."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = get_user_id()
|
||||
if not user_id:
|
||||
return error_response("Unauthorized", 401)
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def success_response(
|
||||
data: Optional[Dict[str, Any]] = None, status: int = 200
|
||||
data=None, message: Optional[str] = None, status: int = 200
|
||||
) -> Response:
|
||||
"""
|
||||
Create a standardized success response.
|
||||
|
||||
Args:
|
||||
data: Optional data dictionary to include in response
|
||||
status: HTTP status code (default: 200)
|
||||
|
||||
Returns:
|
||||
Flask Response object
|
||||
|
||||
Example:
|
||||
return success_response({"users": [...], "total": 10})
|
||||
"""
|
||||
response = {"success": True}
|
||||
if data:
|
||||
response.update(data)
|
||||
return make_response(jsonify(response), status)
|
||||
"""Shape a successful JSON response."""
|
||||
body = {"success": True}
|
||||
if data is not None:
|
||||
body["data"] = data
|
||||
if message is not None:
|
||||
body["message"] = message
|
||||
return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def error_response(message: str, status: int = 400, **kwargs) -> Response:
|
||||
"""
|
||||
Create a standardized error response.
|
||||
|
||||
Args:
|
||||
message: Error message string
|
||||
status: HTTP status code (default: 400)
|
||||
**kwargs: Additional fields to include in response
|
||||
|
||||
Returns:
|
||||
Flask Response object
|
||||
|
||||
Example:
|
||||
return error_response("Resource not found", 404)
|
||||
return error_response("Invalid input", 400, errors=["field1", "field2"])
|
||||
"""
|
||||
response = {"success": False, "message": message}
|
||||
response.update(kwargs)
|
||||
return make_response(jsonify(response), status)
|
||||
"""Shape an error JSON response; any kwargs are merged into the body."""
|
||||
body = {"success": False, "error": message, **kwargs}
|
||||
return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def validate_object_id(
|
||||
id_string: str, resource_name: str = "Resource"
|
||||
) -> Tuple[Optional[ObjectId], Optional[Response]]:
|
||||
"""
|
||||
Validate and convert string to ObjectId.
|
||||
|
||||
Args:
|
||||
id_string: String to convert
|
||||
resource_name: Name of resource for error message
|
||||
|
||||
Returns:
|
||||
Tuple of (ObjectId or None, error_response or None)
|
||||
|
||||
Example:
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
return ObjectId(id_string), None
|
||||
except (InvalidId, TypeError):
|
||||
return None, error_response(f"Invalid {resource_name} ID format")
|
||||
|
||||
|
||||
def validate_pagination(
|
||||
default_limit: int = 20, max_limit: int = 100
|
||||
) -> Tuple[int, int, Optional[Response]]:
|
||||
"""
|
||||
Extract and validate pagination parameters from request.
|
||||
|
||||
Args:
|
||||
default_limit: Default items per page
|
||||
max_limit: Maximum allowed items per page
|
||||
|
||||
Returns:
|
||||
Tuple of (limit, skip, error_response or None)
|
||||
|
||||
Example:
|
||||
limit, skip, error = validate_pagination()
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
limit = min(int(request.args.get("limit", default_limit)), max_limit)
|
||||
skip = int(request.args.get("skip", 0))
|
||||
if limit < 1 or skip < 0:
|
||||
return 0, 0, error_response("Invalid pagination parameters")
|
||||
return limit, skip, None
|
||||
except ValueError:
|
||||
return 0, 0, error_response("Invalid pagination parameters")
|
||||
|
||||
|
||||
def check_resource_ownership(
|
||||
collection: Collection,
|
||||
resource_id: ObjectId,
|
||||
user_id: str,
|
||||
resource_name: str = "Resource",
|
||||
) -> Tuple[Optional[Dict], Optional[Response]]:
|
||||
"""
|
||||
Check if resource exists and belongs to user.
|
||||
|
||||
Args:
|
||||
collection: MongoDB collection
|
||||
resource_id: Resource ObjectId
|
||||
user_id: User ID string
|
||||
resource_name: Name of resource for error messages
|
||||
|
||||
Returns:
|
||||
Tuple of (resource_dict or None, error_response or None)
|
||||
|
||||
Example:
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection,
|
||||
workflow_id,
|
||||
user_id,
|
||||
"Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
resource = collection.find_one({"_id": resource_id, "user": user_id})
|
||||
if not resource:
|
||||
return None, error_response(f"{resource_name} not found", 404)
|
||||
return resource, None
|
||||
|
||||
|
||||
def serialize_object_id(
|
||||
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert ObjectId to string in a dictionary.
|
||||
|
||||
Args:
|
||||
obj: Dictionary containing ObjectId
|
||||
id_field: Field name containing ObjectId
|
||||
new_field: New field name for string ID
|
||||
|
||||
Returns:
|
||||
Modified dictionary
|
||||
|
||||
Example:
|
||||
user = serialize_object_id(user_doc)
|
||||
# user["id"] = "507f1f77bcf86cd799439011"
|
||||
"""
|
||||
if id_field in obj:
|
||||
obj[new_field] = str(obj[id_field])
|
||||
if id_field != new_field:
|
||||
obj.pop(id_field, None)
|
||||
return obj
|
||||
|
||||
|
||||
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
|
||||
"""
|
||||
Apply serializer function to list of items.
|
||||
|
||||
Args:
|
||||
items: List of dictionaries
|
||||
serializer: Function to apply to each item
|
||||
|
||||
Returns:
|
||||
List of serialized items
|
||||
|
||||
Example:
|
||||
workflows = serialize_list(workflow_docs, serialize_workflow)
|
||||
"""
|
||||
return [serializer(item) for item in items]
|
||||
|
||||
|
||||
def paginated_response(
|
||||
collection: Collection,
|
||||
query: Dict[str, Any],
|
||||
serializer: Callable[[Dict], Dict],
|
||||
limit: int,
|
||||
skip: int,
|
||||
sort_field: str = "created_at",
|
||||
sort_order: int = -1,
|
||||
response_key: str = "items",
|
||||
) -> Response:
|
||||
"""
|
||||
Create paginated response for collection query.
|
||||
|
||||
Args:
|
||||
collection: MongoDB collection
|
||||
query: Query dictionary
|
||||
serializer: Function to serialize each item
|
||||
limit: Items per page
|
||||
skip: Number of items to skip
|
||||
sort_field: Field to sort by
|
||||
sort_order: Sort order (1=asc, -1=desc)
|
||||
response_key: Key name for items in response
|
||||
|
||||
Returns:
|
||||
Flask Response with paginated data
|
||||
|
||||
Example:
|
||||
return paginated_response(
|
||||
workflows_collection,
|
||||
{"user": user_id},
|
||||
serialize_workflow,
|
||||
limit, skip,
|
||||
response_key="workflows"
|
||||
)
|
||||
"""
|
||||
items = list(
|
||||
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
|
||||
)
|
||||
total = collection.count_documents(query)
|
||||
|
||||
return success_response(
|
||||
{
|
||||
response_key: serialize_list(items, serializer),
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"skip": skip,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def require_fields(required: List[str]) -> Callable:
|
||||
"""
|
||||
Decorator to validate required fields in request JSON.
|
||||
|
||||
Args:
|
||||
required: List of required field names
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
@require_fields(["name", "description"])
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
...
|
||||
"""
|
||||
def require_fields(required: list) -> Callable:
|
||||
"""Decorator: return 400 if any listed field is missing/falsy in the JSON body."""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
@@ -294,94 +65,11 @@ def require_fields(required: List[str]) -> Callable:
|
||||
return error_response("Request body required")
|
||||
missing = [field for field in required if not data.get(field)]
|
||||
if missing:
|
||||
return error_response(f"Missing required fields: {', '.join(missing)}")
|
||||
return error_response(
|
||||
f"Missing required fields: {', '.join(missing)}"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def safe_db_operation(
|
||||
operation: Callable, error_message: str = "Database operation failed"
|
||||
) -> Tuple[Any, Optional[Response]]:
|
||||
"""
|
||||
Safely execute database operation with error handling.
|
||||
|
||||
Args:
|
||||
operation: Function to execute
|
||||
error_message: Error message if operation fails
|
||||
|
||||
Returns:
|
||||
Tuple of (result or None, error_response or None)
|
||||
|
||||
Example:
|
||||
result, error = safe_db_operation(
|
||||
lambda: collection.insert_one(doc),
|
||||
"Failed to create resource"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
result = operation()
|
||||
return result, None
|
||||
except Exception as err:
|
||||
if has_app_context():
|
||||
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
|
||||
return None, error_response(error_message)
|
||||
|
||||
|
||||
def validate_enum(
|
||||
value: Any, allowed: List[Any], field_name: str
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
Validate that value is in allowed list.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
allowed: List of allowed values
|
||||
field_name: Field name for error message
|
||||
|
||||
Returns:
|
||||
error_response if invalid, None if valid
|
||||
|
||||
Example:
|
||||
error = validate_enum(status, ["draft", "published"], "status")
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
if value not in allowed:
|
||||
allowed_str = ", ".join(f"'{v}'" for v in allowed)
|
||||
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_sort_params(
|
||||
default_field: str = "created_at",
|
||||
default_order: str = "desc",
|
||||
allowed_fields: Optional[List[str]] = None,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Extract and validate sort parameters from request.
|
||||
|
||||
Args:
|
||||
default_field: Default sort field
|
||||
default_order: Default sort order ("asc" or "desc")
|
||||
allowed_fields: List of allowed sort fields (None = no validation)
|
||||
|
||||
Returns:
|
||||
Tuple of (sort_field, sort_order)
|
||||
|
||||
Example:
|
||||
sort_field, sort_order = extract_sort_params(
|
||||
allowed_fields=["name", "date", "status"]
|
||||
)
|
||||
"""
|
||||
sort_field = request.args.get("sort", default_field)
|
||||
sort_order_str = request.args.get("order", default_order).lower()
|
||||
|
||||
if allowed_fields and sort_field not in allowed_fields:
|
||||
sort_field = default_field
|
||||
sort_order = -1 if sort_order_str == "desc" else 1
|
||||
return sort_field, sort_order
|
||||
|
||||
@@ -1,34 +1,26 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api.user.base import (
|
||||
workflow_edges_collection,
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
get_user_id,
|
||||
require_auth,
|
||||
require_fields,
|
||||
safe_db_operation,
|
||||
success_response,
|
||||
validate_object_id,
|
||||
)
|
||||
|
||||
workflows_ns = Namespace("workflows", path="/api")
|
||||
@@ -39,109 +31,15 @@ def _workflow_error_response(message: str, err: Exception):
|
||||
return error_response(message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postgres dual-write helpers
|
||||
#
|
||||
# Workflows are unusual relative to other Phase 3 tables: a single user
|
||||
# action (create / update) writes to three collections in concert
|
||||
# (workflows + workflow_nodes + workflow_edges) and the edges reference
|
||||
# nodes by user-provided string ids. The Postgres mirror needs to:
|
||||
#
|
||||
# 1. Run all three writes inside one PG transaction (so the just-created
|
||||
# nodes are visible when we resolve their UUIDs for the edge insert).
|
||||
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
|
||||
# after the bulk_create returns them.
|
||||
#
|
||||
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
|
||||
# the connection from whichever repo it was instantiated with to spin up
|
||||
# any sibling repos it needs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _dual_write_workflow_create(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
graph_version: int = 1,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow create into Postgres."""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = repo.create(
|
||||
user_id,
|
||||
name,
|
||||
description=description,
|
||||
legacy_mongo_id=mongo_workflow_id,
|
||||
)
|
||||
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_update(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
next_graph_version: int,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow update into Postgres.
|
||||
|
||||
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
|
||||
bump the workflow's name/description/current_graph_version, then drop
|
||||
every other graph_version's nodes/edges.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
|
||||
if wf is None:
|
||||
return
|
||||
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
|
||||
repo.update(wf["id"], user_id, {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"current_graph_version": next_graph_version,
|
||||
})
|
||||
WorkflowNodesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
|
||||
"""Mirror a Mongo workflow delete into Postgres.
|
||||
|
||||
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
|
||||
care of the children automatically.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
|
||||
if wf is not None:
|
||||
repo.delete(wf["id"], user_id)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
|
||||
"""Look up a Postgres workflow by its Mongo ObjectId string."""
|
||||
from sqlalchemy import text as _text
|
||||
row = conn.execute(
|
||||
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": mongo_workflow_id},
|
||||
).fetchone()
|
||||
return {"id": str(row[0])} if row else None
|
||||
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
|
||||
"""Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
|
||||
if not workflow_id:
|
||||
return None
|
||||
if looks_like_uuid(workflow_id):
|
||||
row = repo.get(workflow_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return repo.get_by_legacy_id(workflow_id, user_id)
|
||||
|
||||
|
||||
def _write_graph(
|
||||
@@ -150,14 +48,13 @@ def _write_graph(
|
||||
graph_version: int,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
) -> None:
|
||||
"""Bulk-create nodes + edges for one graph version inside one txn.
|
||||
) -> List[Dict]:
|
||||
"""Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.
|
||||
|
||||
Edges arrive with source/target as user-provided node-id strings
|
||||
(the same shape the Mongo route stores). We bulk-insert nodes first,
|
||||
capture their ``node_id → UUID`` map from the returned rows, then
|
||||
translate edge source/target strings to those UUIDs before the edge
|
||||
bulk insert. Edges referencing missing nodes are dropped (logged).
|
||||
Edges arrive with source/target as user-provided node-id strings. We
|
||||
insert nodes first, capture their ``node_id → UUID`` map, then
|
||||
translate edges before insertion. Edges referencing missing nodes are
|
||||
dropped with a warning.
|
||||
"""
|
||||
nodes_repo = WorkflowNodesRepository(conn)
|
||||
edges_repo = WorkflowEdgesRepository(conn)
|
||||
@@ -173,13 +70,13 @@ def _write_graph(
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||
}
|
||||
for n in nodes_data
|
||||
],
|
||||
)
|
||||
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||
else:
|
||||
created_nodes = []
|
||||
node_uuid_by_str = {}
|
||||
|
||||
if edges_data:
|
||||
@@ -191,7 +88,7 @@ def _write_graph(
|
||||
to_uuid = node_uuid_by_str.get(tgt)
|
||||
if not from_uuid or not to_uuid:
|
||||
current_app.logger.warning(
|
||||
"PG dual-write: dropping edge %s; node refs unresolved "
|
||||
"Workflow graph write: dropping edge %s; node refs unresolved "
|
||||
"(source=%s, target=%s)",
|
||||
e.get("id"), src, tgt,
|
||||
)
|
||||
@@ -204,36 +101,42 @@ def _write_graph(
|
||||
"target_handle": e.get("targetHandle"),
|
||||
})
|
||||
if translated_edges:
|
||||
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
|
||||
edges_repo.bulk_create(
|
||||
pg_workflow_id, graph_version, translated_edges,
|
||||
)
|
||||
|
||||
return created_nodes
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
|
||||
"""Serialize workflow document to API response format."""
|
||||
"""Serialize workflow row to API response format."""
|
||||
created_at = w.get("created_at")
|
||||
updated_at = w.get("updated_at")
|
||||
return {
|
||||
"id": str(w["_id"]),
|
||||
"id": str(w["id"]),
|
||||
"name": w.get("name"),
|
||||
"description": w.get("description"),
|
||||
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
|
||||
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
|
||||
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||
}
|
||||
|
||||
|
||||
def serialize_node(n: Dict) -> Dict:
|
||||
"""Serialize workflow node document to API response format."""
|
||||
"""Serialize workflow node row to API response format."""
|
||||
return {
|
||||
"id": n["id"],
|
||||
"type": n["type"],
|
||||
"id": n["node_id"],
|
||||
"type": n["node_type"],
|
||||
"title": n.get("title"),
|
||||
"description": n.get("description"),
|
||||
"position": n.get("position"),
|
||||
"data": n.get("config", {}),
|
||||
"data": n.get("config", {}) or {},
|
||||
}
|
||||
|
||||
|
||||
def serialize_edge(e: Dict) -> Dict:
|
||||
"""Serialize workflow edge document to API response format."""
|
||||
"""Serialize workflow edge row to API response format."""
|
||||
return {
|
||||
"id": e["id"],
|
||||
"id": e["edge_id"],
|
||||
"source": e.get("source_id"),
|
||||
"target": e.get("target_id"),
|
||||
"sourceHandle": e.get("source_handle"),
|
||||
@@ -242,7 +145,7 @@ def serialize_edge(e: Dict) -> Dict:
|
||||
|
||||
|
||||
def get_workflow_graph_version(workflow: Dict) -> int:
|
||||
"""Get current graph version with legacy fallback."""
|
||||
"""Get current graph version with fallback."""
|
||||
raw_version = workflow.get("current_graph_version", 1)
|
||||
try:
|
||||
version = int(raw_version)
|
||||
@@ -251,22 +154,6 @@ def get_workflow_graph_version(workflow: Dict) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
|
||||
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
|
||||
docs = list(
|
||||
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
|
||||
)
|
||||
if docs:
|
||||
return docs
|
||||
if graph_version == 1:
|
||||
return list(
|
||||
collection.find(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
def validate_json_schema_payload(
|
||||
json_schema: Any,
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
@@ -487,53 +374,6 @@ def _can_reach_end(
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> List[Dict]:
|
||||
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
|
||||
if nodes_data:
|
||||
mongo_nodes = [
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
result = workflow_nodes_collection.insert_many(mongo_nodes)
|
||||
return [
|
||||
{**node, "legacy_mongo_id": str(inserted_id)}
|
||||
for node, inserted_id in zip(nodes_data, result.inserted_ids)
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def create_workflow_edges(
|
||||
workflow_id: str, edges_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
"""Insert workflow edges into database."""
|
||||
if edges_data:
|
||||
workflow_edges_collection.insert_many(
|
||||
[
|
||||
{
|
||||
"id": e["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"source_id": e.get("source"),
|
||||
"target_id": e.get("target"),
|
||||
"source_handle": e.get("sourceHandle"),
|
||||
"target_handle": e.get("targetHandle"),
|
||||
}
|
||||
for e in edges_data
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows")
|
||||
class WorkflowList(Resource):
|
||||
|
||||
@@ -545,6 +385,7 @@ class WorkflowList(Resource):
|
||||
data = request.get_json()
|
||||
|
||||
name = data.get("name", "").strip()
|
||||
description = data.get("description", "")
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
@@ -555,44 +396,16 @@ class WorkflowList(Resource):
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
workflow_doc = {
|
||||
"name": name,
|
||||
"description": data.get("description", ""),
|
||||
"user": user_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
result, error = safe_db_operation(
|
||||
lambda: workflows_collection.insert_one(workflow_doc),
|
||||
"Failed to create workflow",
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow_id = str(result.inserted_id)
|
||||
|
||||
try:
|
||||
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
create_workflow_edges(workflow_id, edges_data, 1)
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = repo.create(user_id, name, description=description)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
_write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflows_collection.delete_one({"_id": result.inserted_id})
|
||||
return _workflow_error_response("Failed to create workflow structure", err)
|
||||
return _workflow_error_response("Failed to create workflow", err)
|
||||
|
||||
_dual_write_workflow_create(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
)
|
||||
|
||||
return success_response({"id": workflow_id}, 201)
|
||||
return success_response({"id": pg_workflow_id}, 201)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows/<string:workflow_id>")
|
||||
@@ -602,23 +415,22 @@ class WorkflowDetail(Resource):
|
||||
def get(self, workflow_id: str):
|
||||
"""Get workflow details with nodes and edges."""
|
||||
user_id = get_user_id()
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
graph_version = get_workflow_graph_version(workflow)
|
||||
nodes = fetch_graph_documents(
|
||||
workflow_nodes_collection, workflow_id, graph_version
|
||||
)
|
||||
edges = fetch_graph_documents(
|
||||
workflow_edges_collection, workflow_id, graph_version
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
graph_version = get_workflow_graph_version(workflow)
|
||||
nodes = WorkflowNodesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
edges = WorkflowEdgesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to fetch workflow", err)
|
||||
|
||||
return success_response(
|
||||
{
|
||||
@@ -633,18 +445,9 @@ class WorkflowDetail(Resource):
|
||||
def put(self, workflow_id: str):
|
||||
"""Update workflow and replace nodes/edges."""
|
||||
user_id = get_user_id()
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
data = request.get_json()
|
||||
name = data.get("name", "").strip()
|
||||
description = data.get("description", "")
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
@@ -655,67 +458,36 @@ class WorkflowDetail(Resource):
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
try:
|
||||
created_nodes = create_workflow_nodes(
|
||||
workflow_id, nodes_data, next_graph_version,
|
||||
)
|
||||
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
return _workflow_error_response("Failed to update workflow structure", err)
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
_, error = safe_db_operation(
|
||||
lambda: workflows_collection.update_one(
|
||||
{"_id": obj_id},
|
||||
{
|
||||
"$set": {
|
||||
_write_graph(
|
||||
conn, pg_workflow_id, next_graph_version,
|
||||
nodes_data, edges_data,
|
||||
)
|
||||
repo.update(
|
||||
pg_workflow_id, user_id,
|
||||
{
|
||||
"name": name,
|
||||
"description": data.get("description", ""),
|
||||
"updated_at": now,
|
||||
"description": description,
|
||||
"current_graph_version": next_graph_version,
|
||||
}
|
||||
},
|
||||
),
|
||||
"Failed to update workflow",
|
||||
)
|
||||
if error:
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
return error
|
||||
|
||||
try:
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
|
||||
)
|
||||
except Exception as cleanup_err:
|
||||
current_app.logger.warning(
|
||||
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
||||
)
|
||||
|
||||
_dual_write_workflow_update(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
next_graph_version,
|
||||
)
|
||||
},
|
||||
)
|
||||
WorkflowNodesRepository(conn).delete_other_versions(
|
||||
pg_workflow_id, next_graph_version,
|
||||
)
|
||||
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||
pg_workflow_id, next_graph_version,
|
||||
)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to update workflow", err)
|
||||
|
||||
return success_response()
|
||||
|
||||
@@ -723,23 +495,15 @@ class WorkflowDetail(Resource):
|
||||
def delete(self, workflow_id: str):
|
||||
"""Delete workflow and its graph."""
|
||||
user_id = get_user_id()
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
try:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
# ON DELETE CASCADE on workflow_nodes/edges cleans children.
|
||||
repo.delete(str(workflow["id"]), user_id)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to delete workflow", err)
|
||||
|
||||
_dual_write_workflow_delete(workflow_id, user_id)
|
||||
|
||||
return success_response()
|
||||
|
||||
@@ -20,8 +20,8 @@ from application.api.v1.translator import (
translate_response,
translate_stream_event,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly

logger = logging.getLogger(__name__)

@@ -39,9 +39,8 @@ def _extract_bearer_token() -> Optional[str]:
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
return db["agents"].find_one({"key": api_key})
with db_readonly() as conn:
return AgentsRepository(conn).find_by_key(api_key)
except Exception:
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
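With `_lookup_agent` now backed by Postgres, nothing changes for callers of the OpenAI-compatible surface: the bearer token is still a per-agent API key. A typical client call looks like the following; the base URL is a placeholder for wherever the DocsGPT API is served.

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:7091/v1",  # placeholder: your DocsGPT API host
    api_key="<agent-api-key>",            # the key resolved by _lookup_agent
)

resp = client.chat.completions.create(
    model="<agent-id>",  # GET /v1/models returns the agent behind this key
    messages=[{"role": "user", "content": "Summarize the install docs."}],
)
print(resp.choices[0].message.content)
```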
@@ -90,8 +89,14 @@ def chat_completions():
)

# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity.
agent_user = agent_doc.get("user") if agent_doc else None
# logs, and tool execution use the correct user identity. The PG
# ``agents`` row exposes the owner via ``user_id`` (``user`` is the
# legacy Mongo field name kept in ``row_to_dict`` only for the
# mapping ``id``/``_id``).
agent_user = (
(agent_doc.get("user_id") or agent_doc.get("user"))
if agent_doc else None
)
decoded_token = {"sub": agent_user or "api_key_user"}

try:
@@ -290,39 +295,32 @@ def list_models():
)

try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
with db_readonly() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.find_by_key(api_key)
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)

# Find the agent for this api_key
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)

user = agent.get("user")

# Return all agents belonging to this user
user_agents = list(agents_collection.find({"user": user}))

models = []
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
model_id = str(ag.get("_id") or ag.get("id") or "")
models.append({
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": ag.get("name", ""),
"description": ag.get("description", ""),
})
created = agent.get("created_at") or agent.get("createdAt")
created_ts = (
int(created.timestamp()) if hasattr(created, "timestamp")
else int(time.time())
)
model_id = str(agent.get("id") or agent.get("_id") or "")
model = {
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": agent.get("name", ""),
"description": agent.get("description", ""),
}

return make_response(
jsonify({"object": "list", "data": models}),
jsonify({"object": "list", "data": [model]}),
200,
)
except Exception as e:

@@ -1,3 +1,4 @@
import logging
import os
import platform
import uuid
@@ -20,6 +21,7 @@ from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.bootstrap import ensure_database_ready # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
@@ -32,6 +34,17 @@ if platform.system() == "Windows":
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()

# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
# repository touches the engine, so the first request can't race the
# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
# (default ON for dev; disable in prod if schema is managed out-of-band).
ensure_database_ready(
settings.POSTGRES_URI,
create_db=settings.AUTO_CREATE_DB,
migrate=settings.AUTO_MIGRATE,
logger=logging.getLogger("application.app"),
)
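For readers who want to know what a hook like `ensure_database_ready` amounts to, a rough sketch is below. It is illustrative only: the real implementation lives in `application/storage/db/bootstrap.py`, and the Alembic config path here is an assumption.

```python
import logging

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url


def ensure_database_ready_sketch(uri, create_db=True, migrate=True, logger=None):
    log = logger or logging.getLogger(__name__)
    url = make_url(uri)
    if create_db:
        # Connect to the maintenance DB and create the target DB if it is missing.
        admin = create_engine(url.set(database="postgres"), isolation_level="AUTOCOMMIT")
        with admin.connect() as conn:
            exists = conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :name"),
                {"name": url.database},
            ).scalar()
            if not exists:
                conn.execute(text(f'CREATE DATABASE "{url.database}"'))
                log.info("Created database %s", url.database)
    if migrate:
        # Apply any pending Alembic migrations up to head.
        cfg = Config("alembic.ini")  # assumed location of the migration config
        cfg.set_main_option("sqlalchemy.url", uri)
        command.upgrade(cfg, "head")
```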

app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
@@ -120,6 +133,12 @@ def enforce_stt_request_size_limits():
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
# OpenAI-compatible routes authenticate via opaque agent API keys in the
# Authorization header, which the JWT decoder below would reject. Defer
# auth to the route handlers (see application/api/v1/routes.py).
if request.path.startswith("/v1/"):
request.decoded_token = None
return None
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None

@@ -1,24 +0,0 @@
from application.core.settings import settings
from pymongo import MongoClient


class MongoDB:
_client = None

@classmethod
def get_client(cls):
"""
Get the MongoDB client instance, creating it if necessary.
"""
if cls._client is None:
cls._client = MongoClient(settings.MONGO_URI)
return cls._client

@classmethod
def close_client(cls):
"""
Close the MongoDB client connection.
"""
if cls._client is not None:
cls._client.close()
cls._client = None
@@ -26,13 +26,14 @@ class Settings(BaseSettings):

CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
MONGO_URI: Optional[str] = None
# User-data Postgres DB.
POSTGRES_URI: Optional[str] = None

# MongoDB→Postgres migration: dual-write to Postgres (Mongo stays source of truth)
USE_POSTGRES: bool = False
# On app startup, apply pending Alembic migrations. Default ON for dev; disable in prod if you manage schema out-of-band.
AUTO_MIGRATE: bool = True
# On app startup, create the target Postgres database if it's missing (requires CREATEDB privilege). Dev-friendly default.
AUTO_CREATE_DB: bool = True
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry

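Concretely, a local Postgres-backed install now only needs something like the following environment; the URI and driver are placeholders, and `MONGO_URI` stays unset unless the Mongo vector store or the backfill script is in use.

```python
import os

# Placeholder values for a local development setup.
os.environ.setdefault(
    "POSTGRES_URI", "postgresql+psycopg2://docsgpt:docsgpt@localhost:5432/docsgpt"
)
os.environ.setdefault("AUTO_CREATE_DB", "true")  # create the DB on startup if missing
os.environ.setdefault("AUTO_MIGRATE", "true")    # apply Alembic migrations on startup
# os.environ.setdefault("MONGO_URI", "mongodb://localhost:27017/docsgpt")  # optional

from application.core.settings import settings  # reads the env on import

print(settings.POSTGRES_URI, settings.AUTO_MIGRATE)
```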
@@ -127,15 +127,33 @@ class GoogleLLM(BaseLLM):
).uri,
)

from application.core.mongo_db import MongoDB

mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
# Cache the Google file URI on the attachment row so we don't
# re-upload on the next LLM call. Accept either a PG UUID
# (``id``) or a legacy Mongo ObjectId (``_id``). Opened per
# write — this runs mid-LLM-call, so we don't wrap the
# surrounding generator in a long-lived session.
attachment_id = attachment.get("id") or attachment.get("_id")
if attachment_id:
user_id = None
decoded = getattr(self, "decoded_token", None)
if isinstance(decoded, dict):
user_id = decoded.get("sub")
from application.storage.db.repositories.attachments import (
AttachmentsRepository,
)
from application.storage.db.session import db_session

try:
with db_session() as conn:
AttachmentsRepository(conn).update_any(
str(attachment_id),
user_id,
{"google_file_uri": file_uri},
)
except Exception as cache_err:
logging.warning(
f"Failed to cache google_file_uri on attachment {attachment_id}: {cache_err}"
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)

@@ -527,15 +527,34 @@ class OpenAILLM(BaseLLM):
|
||||
).id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||
# Cache the OpenAI file id on the attachment row so we don't
|
||||
# re-upload the same blob on the next LLM call. Prefer the PG
|
||||
# UUID (``id``) when present; fall back to the legacy Mongo
|
||||
# ObjectId string (``_id``). Opened per-write — this runs
|
||||
# inside the hot LLM path, so we don't want a long-lived
|
||||
# session wrapping the generator.
|
||||
attachment_id = attachment.get("id") or attachment.get("_id")
|
||||
if attachment_id:
|
||||
user_id = None
|
||||
decoded = getattr(self, "decoded_token", None)
|
||||
if isinstance(decoded, dict):
|
||||
user_id = decoded.get("sub")
|
||||
from application.storage.db.repositories.attachments import (
|
||||
AttachmentsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
AttachmentsRepository(conn).update_any(
|
||||
str(attachment_id),
|
||||
user_id,
|
||||
{"openai_file_id": file_id},
|
||||
)
|
||||
except Exception as cache_err:
|
||||
logging.warning(
|
||||
f"Failed to cache openai_file_id on attachment {attachment_id}: {cache_err}"
|
||||
)
|
||||
return file_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
|
||||
|
||||
@@ -6,8 +6,8 @@ import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Generator, List
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
@@ -101,7 +101,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
|
||||
except Exception as e:
|
||||
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
|
||||
context.stacks.append({"component": "error", "data": {"message": str(e)}})
|
||||
_log_to_mongodb(
|
||||
_log_activity_to_db(
|
||||
endpoint=context.endpoint,
|
||||
activity_id=context.activity_id,
|
||||
user=context.user,
|
||||
@@ -112,7 +112,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_log_to_mongodb(
|
||||
_log_activity_to_db(
|
||||
endpoint=context.endpoint,
|
||||
activity_id=context.activity_id,
|
||||
user=context.user,
|
||||
@@ -123,7 +123,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
|
||||
)
|
||||
|
||||
|
||||
def _log_to_mongodb(
|
||||
def _log_activity_to_db(
|
||||
endpoint: str,
|
||||
activity_id: str,
|
||||
user: str,
|
||||
@@ -132,46 +132,26 @@ def _log_to_mongodb(
|
||||
stacks: List[Dict],
|
||||
level: str,
|
||||
) -> None:
|
||||
"""Append a per-request activity log row to Postgres (``stack_logs``)."""
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_logs_collection = db["stack_logs"]
|
||||
|
||||
|
||||
|
||||
log_entry = {
|
||||
"endpoint": endpoint,
|
||||
"id": activity_id,
|
||||
"level": level,
|
||||
"user": user,
|
||||
"api_key": api_key,
|
||||
"query": query,
|
||||
"stacks": stacks,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_entry.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_entry[key] = value[:10000]
|
||||
|
||||
user_logs_collection.insert_one(log_entry)
|
||||
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
dual_write(
|
||||
StackLogsRepository,
|
||||
lambda repo, e=log_entry: repo.insert(
|
||||
activity_id=e["id"],
|
||||
endpoint=e.get("endpoint"),
|
||||
level=e.get("level"),
|
||||
user_id=e.get("user"),
|
||||
api_key=e.get("api_key"),
|
||||
query=e.get("query"),
|
||||
stacks=e.get("stacks"),
|
||||
),
|
||||
)
|
||||
# Clean up text fields to be no longer than 10000 characters so a
|
||||
# runaway payload can't blow up the insert.
|
||||
def _truncate(val):
|
||||
if isinstance(val, str) and len(val) > 10000:
|
||||
return val[:10000]
|
||||
return val
|
||||
|
||||
with db_session() as conn:
|
||||
StackLogsRepository(conn).insert(
|
||||
activity_id=activity_id,
|
||||
endpoint=_truncate(endpoint),
|
||||
level=_truncate(level),
|
||||
user_id=_truncate(user),
|
||||
api_key=_truncate(api_key),
|
||||
query=_truncate(query),
|
||||
stacks=stacks,
|
||||
timestamp=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
logging.debug(f"Logged activity to Postgres: {activity_id}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
|
||||
logging.error(f"Failed to log activity to Postgres: {e}", exc_info=True)
|
||||
|
||||
37
application/parser/connectors/_auth_utils.py
Normal file
@@ -0,0 +1,37 @@
"""Shared helpers for connector auth modules.

These helpers exist so that sensitive values (session tokens, bearer
credentials) never end up interpolated into exception messages or log
lines. Exception messages frequently flow into ``stack_logs`` (Postgres)
and Sentry via ``exc_info=True``, so the raw value must never be the
thing we format.
"""

from __future__ import annotations

import hashlib


def session_token_fingerprint(session_token: str) -> str:
    """Return a short, irreversible fingerprint for a session token.

    The returned string is safe to embed in exception messages and log
    lines: it is a prefix of a SHA-256 digest, clearly tagged so an
    operator reading the log knows it is a hash and not the token
    itself. It is stable for a given input, which lets operators
    correlate "which token failed" across log lines without exposing
    the credential.

    Args:
        session_token: The raw session token. Accepts ``None`` or the
            empty string for defensive callers; both yield a distinct
            sentinel rather than raising.

    Returns:
        A string of the form ``"sha256:<6 hex chars>"``, or
        ``"sha256:<empty>"`` when the input is falsy.
    """
    if not session_token:
        return "sha256:<empty>"
    digest = hashlib.sha256(session_token.encode("utf-8")).hexdigest()
    return f"sha256:{digest[:6]}"
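A short usage sketch for the helper above; the printed digest prefix is illustrative, not a real value.

from application.parser.connectors._auth_utils import session_token_fingerprint

fp = session_token_fingerprint("super-secret-session-token")
print(fp)                             # e.g. "sha256:1a2b3c" -- stable for the same token
print(session_token_fingerprint(""))  # "sha256:<empty>" sentinel for falsy input
msg = f"Invalid session token ({fp})" # the shape the connector auth classes now log and raise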
@@ -6,6 +6,7 @@ from urllib.parse import urlencode
|
||||
import requests
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors._auth_utils import session_token_fingerprint
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -152,15 +153,19 @@ class ConfluenceAuth(BaseConnectorAuth):
|
||||
return True
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
|
||||
session = db["connector_sessions"].find_one({"session_token": session_token})
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
raise ValueError(
|
||||
f"Invalid session token ({session_token_fingerprint(session_token)})"
|
||||
)
|
||||
|
||||
token_info = session.get("token_info")
|
||||
if not token_info:
|
||||
|
||||
@@ -83,16 +83,17 @@ class ConfluenceLoader(BaseConnectorLoader):
|
||||
|
||||
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
sanitized = self.auth.sanitize_token_info(token_info)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
db["connector_sessions"].update_one(
|
||||
{"session_token": self.session_token},
|
||||
{"$set": {"token_info": sanitized}},
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
session = repo.get_by_session_token(self.session_token)
|
||||
if session:
|
||||
repo.update(str(session["id"]), {"token_info": sanitized})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist refreshed tokens: %s", e)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors._auth_utils import session_token_fingerprint
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
@@ -209,23 +210,23 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
sessions_collection = db["connector_sessions"]
|
||||
session = sessions_collection.find_one({"session_token": session_token})
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
raise ValueError(
|
||||
f"Invalid session token ({session_token_fingerprint(session_token)})"
|
||||
)
|
||||
|
||||
if "token_info" not in session:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
token_info = session["token_info"]
|
||||
token_info = session.get("token_info")
|
||||
if not token_info:
|
||||
raise ValueError("Invalid token information")
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
required_fields = ["access_token", "refresh_token"]
|
||||
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Dict, Any
|
||||
from msal import ConfidentialClientApplication
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors._auth_utils import session_token_fingerprint
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -77,24 +78,24 @@ class SharePointAuth(BaseConnectorAuth):
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
sessions_collection = db["connector_sessions"]
|
||||
session = sessions_collection.find_one({"session_token": session_token})
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
raise ValueError(
|
||||
f"Invalid session token ({session_token_fingerprint(session_token)})"
|
||||
)
|
||||
|
||||
if "token_info" not in session:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
token_info = session["token_info"]
|
||||
token_info = session.get("token_info")
|
||||
if not token_info:
|
||||
raise ValueError("Invalid token information")
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
required_fields = ["access_token", "refresh_token"]
|
||||
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import tempfile
|
||||
import mimetypes
|
||||
from typing import List, Optional
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
@@ -108,6 +109,11 @@ class S3Loader(BaseRemote):
|
||||
logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
|
||||
logger.info(f"Bucket name: '{corrected_bucket}'")
|
||||
|
||||
try:
|
||||
normalized_endpoint = validate_url(normalized_endpoint)
|
||||
except SSRFError as e:
|
||||
raise ValueError(f"Invalid S3 endpoint_url: {e}") from e
|
||||
|
||||
client_kwargs["endpoint_url"] = normalized_endpoint
|
||||
# Use path-style addressing for S3-compatible services
|
||||
# (DigitalOcean Spaces, MinIO, etc.)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
headers = {
|
||||
@@ -29,7 +29,9 @@ class WebLoader(BaseRemote):
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
logging.error(f"URL validation failed for {url}: {e}")
|
||||
logging.warning(
|
||||
f"Skipping URL due to SSRF validation failure: {url} - {e}"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
loader = self.loader([url], header_template=headers)
|
||||
|
||||
@@ -64,7 +64,6 @@ py==1.11.0
pydantic
pydantic-core
pydantic-settings
pymongo==4.16.0
pypdf==6.9.2
python-dateutil==2.9.0.post0
python-dotenv
@@ -1,7 +1,5 @@
import click

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.seed.seeder import DatabaseSeeder


@@ -15,10 +13,7 @@ def seed():
@click.option("--force", is_flag=True, help="Force reseeding even if data exists")
def init(force):
    """Initialize database with seed data"""
    mongo = MongoDB.get_client()
    db = mongo[settings.MONGO_DB_NAME]

    seeder = DatabaseSeeder(db)
    seeder = DatabaseSeeder()
    seeder.seed_initial_data(force=force)
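After this change the seeder needs no Mongo handle. A sketch of both entry points; the CLI line is the documented path, and the direct call assumes the default constructor shown below in seeder.py.

# CLI:  python -m application.seed.commands init --force
from application.seed.seeder import DatabaseSeeder

DatabaseSeeder().seed_initial_data(force=True)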
@@ -1,35 +1,56 @@
|
||||
"""Database seeder — Postgres-native.
|
||||
|
||||
Post-Part-2 cutover: writes template prompts/tools/agents/sources directly
|
||||
into Postgres via the repository layer. No MongoDB dependencies.
|
||||
|
||||
The seeder is invoked by the ``python -m application.seed.commands init``
|
||||
CLI (not at Flask app startup). All template rows are owned by the
|
||||
sentinel user id ``__system__`` — kept in sync with the migration
|
||||
backfill/cleanup-trigger sentinel so template ownership is predictable.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from bson import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api.user.tasks import ingest_remote
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
load_dotenv()
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
# Sentinel user id for template rows (agents/prompts/sources/tools).
|
||||
# Kept in sync with the Postgres backfill / cleanup-trigger sentinel so
|
||||
# template ownership is predictable across the cutover.
|
||||
SYSTEM_USER_ID = "__system__"
|
||||
|
||||
|
||||
class DatabaseSeeder:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.tools_collection = self.db["user_tools"]
|
||||
self.sources_collection = self.db["sources"]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
self.system_user_id = "system"
|
||||
"""Postgres-backed seeder.
|
||||
|
||||
The constructor accepts an optional positional argument for back
|
||||
compatibility with legacy callers that used to pass a Mongo ``db``
|
||||
handle. The value is ignored — all persistence goes through the
|
||||
Postgres repositories.
|
||||
"""
|
||||
|
||||
def __init__(self, db=None):
|
||||
self._legacy_db = db # unused; retained for call-site compatibility
|
||||
self.system_user_id = SYSTEM_USER_ID
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def seed_initial_data(self, config_path: str = None, force=False):
|
||||
"""Main entry point for seeding all initial data"""
|
||||
"""Main entry point for seeding all initial data."""
|
||||
if not force and self._is_already_seeded():
|
||||
self.logger.info("Database already seeded. Use force=True to reseed.")
|
||||
return
|
||||
@@ -46,20 +67,18 @@ class DatabaseSeeder:
|
||||
raise
|
||||
|
||||
def _seed_from_config(self, config: Dict):
|
||||
"""Seed all data from configuration"""
|
||||
self.logger.info("🌱 Starting seeding...")
|
||||
"""Seed all data from configuration."""
|
||||
self.logger.info("Starting seeding...")
|
||||
|
||||
if not config.get("agents"):
|
||||
self.logger.warning("No agents found in config")
|
||||
return
|
||||
used_tool_ids = set()
|
||||
|
||||
for agent_config in config["agents"]:
|
||||
try:
|
||||
self.logger.info(f"Processing agent: {agent_config['name']}")
|
||||
|
||||
# 1. Handle Source
|
||||
|
||||
source_result = self._handle_source(agent_config)
|
||||
if source_result is False:
|
||||
self.logger.error(
|
||||
@@ -67,64 +86,100 @@ class DatabaseSeeder:
|
||||
)
|
||||
continue
|
||||
source_id = source_result
|
||||
# 2. Handle Tools
|
||||
|
||||
# 2. Handle Tools
|
||||
tool_ids = self._handle_tools(agent_config)
|
||||
if len(tool_ids) == 0:
|
||||
self.logger.warning(
|
||||
f"No valid tools for agent {agent_config['name']}"
|
||||
)
|
||||
used_tool_ids.update(tool_ids)
|
||||
|
||||
# 3. Handle Prompt
|
||||
|
||||
prompt_id = self._handle_prompt(agent_config)
|
||||
|
||||
# 4. Create Agent
|
||||
# 4. Create or update Agent
|
||||
self._upsert_agent(agent_config, source_id, tool_ids, prompt_id)
|
||||
|
||||
agent_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": agent_config["name"],
|
||||
"description": agent_config["description"],
|
||||
"image": agent_config.get("image", ""),
|
||||
"source": (
|
||||
DBRef("sources", ObjectId(source_id)) if source_id else ""
|
||||
),
|
||||
"tools": [str(tid) for tid in tool_ids],
|
||||
"agent_type": agent_config["agent_type"],
|
||||
"prompt_id": prompt_id or agent_config.get("prompt_id", "default"),
|
||||
"chunks": agent_config.get("chunks", "0"),
|
||||
"retriever": agent_config.get("retriever", ""),
|
||||
"status": "template",
|
||||
"createdAt": datetime.now(timezone.utc),
|
||||
"updatedAt": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
existing = self.agents_collection.find_one(
|
||||
{"user": self.system_user_id, "name": agent_config["name"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Updating existing agent: {agent_config['name']}")
|
||||
self.agents_collection.update_one(
|
||||
{"_id": existing["_id"]}, {"$set": agent_data}
|
||||
)
|
||||
agent_id = existing["_id"]
|
||||
else:
|
||||
self.logger.info(f"Creating new agent: {agent_config['name']}")
|
||||
result = self.agents_collection.insert_one(agent_data)
|
||||
agent_id = result.inserted_id
|
||||
self.logger.info(
|
||||
f"Successfully processed agent: {agent_config['name']} (ID: {agent_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Error processing agent {agent_config['name']}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
self.logger.info("✅ Database seeding completed")
|
||||
self.logger.info("Database seeding completed")
|
||||
|
||||
def _handle_source(self, agent_config: Dict) -> Union[ObjectId, None, bool]:
|
||||
"""Handle source ingestion and return source ID"""
|
||||
@staticmethod
|
||||
def _coerce_uuid_fk(raw) -> Optional[str]:
|
||||
"""Coerce sentinel/blank values to ``None`` for nullable UUID FK columns.
|
||||
|
||||
Mirrors the route-side handling in ``application/api/user/agents/routes.py``:
|
||||
the literal string ``"default"``, empty string, and ``None`` all map
|
||||
to ``None`` so the repository layer skips the column and Postgres
|
||||
keeps the FK NULL (FKs are ``ON DELETE SET NULL``).
|
||||
"""
|
||||
if raw in (None, "", "default"):
|
||||
return None
|
||||
return str(raw)
|
||||
|
||||
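Illustrative mapping for the coercion helper above (the UUID value is made up):

DatabaseSeeder._coerce_uuid_fk("default")  # -> None, so the FK column stays NULL
DatabaseSeeder._coerce_uuid_fk("")         # -> None
DatabaseSeeder._coerce_uuid_fk(None)       # -> None
DatabaseSeeder._coerce_uuid_fk("2f9c7d6e-1b3a-4c5d-8e9f-0a1b2c3d4e5f")  # -> the same string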
def _upsert_agent(
|
||||
self,
|
||||
agent_config: Dict,
|
||||
source_id: Optional[str],
|
||||
tool_ids: List[str],
|
||||
prompt_id: Optional[str],
|
||||
) -> None:
|
||||
"""Create or update a template agent owned by ``__system__``."""
|
||||
name = agent_config["name"]
|
||||
prompt_id_val = self._coerce_uuid_fk(
|
||||
prompt_id if prompt_id is not None else agent_config.get("prompt_id")
|
||||
)
|
||||
folder_id_val = self._coerce_uuid_fk(agent_config.get("folder_id"))
|
||||
workflow_id_val = self._coerce_uuid_fk(agent_config.get("workflow_id"))
|
||||
source_id_val = self._coerce_uuid_fk(source_id)
|
||||
agent_fields = {
|
||||
"description": agent_config["description"],
|
||||
"image": agent_config.get("image", ""),
|
||||
"tools": [str(tid) for tid in tool_ids],
|
||||
"agent_type": agent_config["agent_type"],
|
||||
"prompt_id": prompt_id_val,
|
||||
"chunks": agent_config.get("chunks", "0"),
|
||||
"retriever": agent_config.get("retriever", ""),
|
||||
}
|
||||
if folder_id_val is not None:
|
||||
agent_fields["folder_id"] = folder_id_val
|
||||
if workflow_id_val is not None:
|
||||
agent_fields["workflow_id"] = workflow_id_val
|
||||
if source_id_val is not None:
|
||||
agent_fields["source_id"] = source_id_val
|
||||
|
||||
with db_session() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
existing = self._find_system_agent_by_name(repo, name)
|
||||
if existing:
|
||||
self.logger.info(f"Updating existing agent: {name}")
|
||||
repo.update(str(existing["id"]), self.system_user_id, agent_fields)
|
||||
self.logger.info(f"Successfully updated agent: {name} (ID: {existing['id']})")
|
||||
else:
|
||||
self.logger.info(f"Creating new agent: {name}")
|
||||
created = repo.create(
|
||||
user_id=self.system_user_id,
|
||||
name=name,
|
||||
status="template",
|
||||
**agent_fields,
|
||||
)
|
||||
self.logger.info(
|
||||
f"Successfully created agent: {name} (ID: {created.get('id')})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _find_system_agent_by_name(repo: AgentsRepository, name: str) -> Optional[dict]:
|
||||
"""Find a system-owned agent by name among the template rows."""
|
||||
for row in repo.list_for_user(SYSTEM_USER_ID):
|
||||
if row.get("name") == name:
|
||||
return row
|
||||
return None
|
||||
|
||||
def _handle_source(self, agent_config: Dict):
|
||||
"""Handle source ingestion and return a source id (UUID string) or ``None``/``False``."""
|
||||
if not agent_config.get("source"):
|
||||
self.logger.info(
|
||||
"No source provided for agent - will create agent without source"
|
||||
@@ -134,14 +189,15 @@ class DatabaseSeeder:
|
||||
self.logger.info(f"Ingesting source: {source_config['url']}")
|
||||
|
||||
try:
|
||||
existing = self.sources_collection.find_one(
|
||||
{"user": self.system_user_id, "remote_data": source_config["url"]}
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
existing = self._find_system_source_by_remote_url(
|
||||
SourcesRepository(conn), source_config["url"]
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Source already exists: {existing['_id']}")
|
||||
return existing["_id"]
|
||||
# Ingest new source using worker
|
||||
self.logger.info(f"Source already exists: {existing['id']}")
|
||||
return existing["id"]
|
||||
|
||||
# Ingest new source using worker
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_config["url"],
|
||||
job_name=source_config["name"],
|
||||
@@ -164,9 +220,29 @@ class DatabaseSeeder:
|
||||
self.logger.error(f"Failed to ingest source: {str(e)}")
|
||||
return False
|
||||
|
||||
def _handle_tools(self, agent_config: Dict) -> List[ObjectId]:
|
||||
"""Handle tool creation and return list of tool IDs"""
|
||||
tool_ids = []
|
||||
@staticmethod
|
||||
def _find_system_source_by_remote_url(
|
||||
repo: SourcesRepository, url: str
|
||||
) -> Optional[dict]:
|
||||
"""Scan system-owned sources for a row whose remote_data matches ``url``."""
|
||||
# TODO(migration-postgres): push this into SourcesRepository once a
|
||||
# remote_data search helper exists; today we keep the scan here to
|
||||
# stay within this slice's boundaries.
|
||||
try:
|
||||
rows = repo.list_for_user(SYSTEM_USER_ID) # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
return None
|
||||
for row in rows:
|
||||
remote = row.get("remote_data")
|
||||
if remote == url:
|
||||
return row
|
||||
if isinstance(remote, dict) and remote.get("url") == url:
|
||||
return row
|
||||
return None
|
||||
|
||||
def _handle_tools(self, agent_config: Dict) -> List[str]:
|
||||
"""Handle tool creation and return list of tool ids (UUID strings)."""
|
||||
tool_ids: List[str] = []
|
||||
if not agent_config.get("tools"):
|
||||
return tool_ids
|
||||
for tool_config in agent_config["tools"]:
|
||||
@@ -175,37 +251,43 @@ class DatabaseSeeder:
|
||||
processed_config = self._process_config(tool_config.get("config", {}))
|
||||
self.logger.info(f"Processing tool: {tool_name}")
|
||||
|
||||
existing = self.tools_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"config": processed_config,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Tool already exists: {existing['_id']}")
|
||||
tool_ids.append(existing["_id"])
|
||||
continue
|
||||
tool_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"displayName": tool_config.get("display_name", tool_name),
|
||||
"description": tool_config.get("description", ""),
|
||||
"actions": tool_manager.tools[tool_name].get_actions_metadata(),
|
||||
"config": processed_config,
|
||||
"status": True,
|
||||
}
|
||||
|
||||
result = self.tools_collection.insert_one(tool_data)
|
||||
tool_ids.append(result.inserted_id)
|
||||
self.logger.info(f"Created new tool: {result.inserted_id}")
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
existing = self._find_system_tool(
|
||||
repo, tool_name, processed_config
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Tool already exists: {existing['id']}")
|
||||
tool_ids.append(existing["id"])
|
||||
continue
|
||||
created = repo.create(
|
||||
user_id=self.system_user_id,
|
||||
name=tool_name,
|
||||
display_name=tool_config.get("display_name", tool_name),
|
||||
description=tool_config.get("description", ""),
|
||||
actions=tool_manager.tools[tool_name].get_actions_metadata(),
|
||||
config=processed_config,
|
||||
status=True,
|
||||
)
|
||||
tool_ids.append(created["id"])
|
||||
self.logger.info(f"Created new tool: {created['id']}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
|
||||
continue
|
||||
return tool_ids
|
||||
|
||||
@staticmethod
|
||||
def _find_system_tool(
|
||||
repo: UserToolsRepository, name: str, config: dict
|
||||
) -> Optional[dict]:
|
||||
"""Locate a system-owned tool by (name, config) among existing rows."""
|
||||
existing = repo.find_by_user_and_name(SYSTEM_USER_ID, name)
|
||||
if existing and existing.get("config") == config:
|
||||
return existing
|
||||
return None
|
||||
|
||||
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
|
||||
"""Handle prompt creation and return prompt ID"""
|
||||
"""Handle prompt creation and return prompt id (UUID string)."""
|
||||
if not agent_config.get("prompt"):
|
||||
return None
|
||||
|
||||
@@ -222,34 +304,20 @@ class DatabaseSeeder:
|
||||
self.logger.info(f"Processing prompt: {prompt_name}")
|
||||
|
||||
try:
|
||||
existing = self.prompts_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Prompt already exists: {existing['_id']}")
|
||||
return str(existing["_id"])
|
||||
|
||||
prompt_data = {
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
"user": self.system_user_id,
|
||||
}
|
||||
|
||||
result = self.prompts_collection.insert_one(prompt_data)
|
||||
prompt_id = str(result.inserted_id)
|
||||
self.logger.info(f"Created new prompt: {prompt_id}")
|
||||
return prompt_id
|
||||
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
row = repo.find_or_create(
|
||||
self.system_user_id, prompt_name, prompt_content
|
||||
)
|
||||
prompt_id = str(row["id"])
|
||||
self.logger.info(f"Prompt ready: {prompt_id}")
|
||||
return prompt_id
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process prompt {prompt_name}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _process_config(self, config: Dict) -> Dict:
|
||||
"""Process config values to replace environment variables"""
|
||||
"""Process config values to replace environment variables."""
|
||||
processed = {}
|
||||
for key, value in config.items():
|
||||
if (
|
||||
@@ -264,14 +332,18 @@ class DatabaseSeeder:
|
||||
return processed
|
||||
|
||||
def _is_already_seeded(self) -> bool:
|
||||
"""Check if premade agents already exist"""
|
||||
return self.agents_collection.count_documents({"user": self.system_user_id}) > 0
|
||||
"""Check if premade (system-owned) agents already exist in Postgres."""
|
||||
with db_readonly() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
return len(repo.list_for_user(SYSTEM_USER_ID)) > 0
|
||||
|
||||
@classmethod
|
||||
def initialize_from_env(cls, worker=None):
|
||||
"""Factory method to create seeder from environment"""
|
||||
mongo_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
|
||||
db_name = os.getenv("MONGO_DB_NAME", "docsgpt")
|
||||
client = MongoClient(mongo_uri)
|
||||
db = client[db_name]
|
||||
return cls(db)
|
||||
"""Factory method to create seeder from environment.
|
||||
|
||||
Retained for back compatibility with existing call sites. The
|
||||
Postgres connection is resolved lazily via the repository layer
|
||||
(``application.storage.db.engine``), so no explicit wiring is
|
||||
required here.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
@@ -7,10 +7,32 @@ cutover is complete, a follow-up phase may migrate repo return types to
|
||||
Pydantic DTOs (tracked in the migration plan as a post-migration item).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Mapping
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def looks_like_uuid(value: Any) -> bool:
|
||||
"""Return True if ``value`` is a canonical UUID (string or ``UUID`` instance).
|
||||
|
||||
Used by ``get_any`` accessors to pick the UUID lookup path vs. the
|
||||
``legacy_mongo_id`` fallback during the Mongo→PG cutover window.
|
||||
Accepting ``uuid.UUID`` directly matters for callers that receive an
|
||||
id straight from a PG column (SQLAlchemy maps ``UUID`` columns to the
|
||||
Python ``UUID`` type) — without this, the call falls through to the
|
||||
legacy-text lookup and crashes on ``operator does not exist: text = uuid``.
|
||||
"""
|
||||
if isinstance(value, UUID):
|
||||
return True
|
||||
return isinstance(value, str) and bool(_UUID_RE.match(value))
|
||||
|
||||
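Illustrative behaviour of the helper above (both UUID values are made up):

from uuid import UUID

looks_like_uuid("2f9c7d6e-1b3a-4c5d-8e9f-0a1b2c3d4e5f")        # True -> UUID lookup path
looks_like_uuid(UUID("2f9c7d6e-1b3a-4c5d-8e9f-0a1b2c3d4e5f"))  # True -> id straight from a PG column
looks_like_uuid("507f1f77bcf86cd799439011")                    # False -> legacy_mongo_id fallback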
|
||||
def row_to_dict(row: Any) -> dict:
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||
|
||||
|
||||
320
application/storage/db/bootstrap.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Self-bootstrapping database setup for the DocsGPT user-data Postgres DB.
|
||||
|
||||
On app startup the Flask factory (and Celery worker init) can call
|
||||
:func:`ensure_database_ready` to:
|
||||
|
||||
1. Create the target database if it's missing (dev-friendly; requires the
|
||||
configured role to have ``CREATEDB`` privilege).
|
||||
2. Apply every pending Alembic migration up to ``head``.
|
||||
|
||||
Both steps are gated by settings that default ON for dev convenience and
|
||||
can be turned off in prod (``AUTO_CREATE_DB`` / ``AUTO_MIGRATE``) where
|
||||
schema is managed out-of-band by a deploy pipeline.
|
||||
|
||||
All heavy imports (alembic, psycopg, sqlalchemy.exc sub-symbols) are
|
||||
deferred to inside the function so merely importing this module has no
|
||||
side effects and is cheap for test collection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def ensure_database_ready(
|
||||
uri: Optional[str],
|
||||
*,
|
||||
create_db: bool,
|
||||
migrate: bool,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
"""Make sure the target Postgres DB exists and is migrated to ``head``.
|
||||
|
||||
This is idempotent and safe to call once per process. Each step is
|
||||
independently gated so prod deployments that manage schema externally
|
||||
can disable the migrate step while still allowing the process to boot
|
||||
against an already-provisioned database.
|
||||
|
||||
Args:
|
||||
uri: SQLAlchemy URI for the user-data Postgres database. If
|
||||
``None`` or empty, the function logs and returns — the app
|
||||
supports running without a configured URI for certain dev
|
||||
flows that don't touch user data.
|
||||
create_db: If ``True``, auto-create the database when it's
|
||||
missing. Requires the configured role to have ``CREATEDB``.
|
||||
migrate: If ``True``, run ``alembic upgrade head`` after the
|
||||
database is reachable.
|
||||
logger: Optional logger to use. Defaults to this module's logger.
|
||||
|
||||
Raises:
|
||||
Exception: Any failure in an explicitly-enabled step is re-raised
|
||||
so the app fails fast rather than booting into a broken state.
|
||||
Missing-role / auth errors surface cleanly without a
|
||||
mis-directed auto-create attempt.
|
||||
"""
|
||||
log = logger or logging.getLogger(__name__)
|
||||
|
||||
if not uri:
|
||||
log.info(
|
||||
"ensure_database_ready: POSTGRES_URI is not set; "
|
||||
"skipping database bootstrap."
|
||||
)
|
||||
return
|
||||
|
||||
if create_db:
|
||||
_ensure_database_exists(uri, log)
|
||||
|
||||
if migrate:
|
||||
_run_migrations(log)
|
||||
|
||||
|
||||
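A sketch of how the Flask factory (or Celery worker init) mentioned in the module docstring might call this; the create_app wrapper here is assumed, not taken from the diff.

from application.core.settings import settings
from application.storage.db.bootstrap import ensure_database_ready

def create_app():
    ensure_database_ready(
        settings.POSTGRES_URI,
        create_db=settings.AUTO_CREATE_DB,
        migrate=settings.AUTO_MIGRATE,
    )
    # ... continue building the Flask app; a failure above aborts boot on purpose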
def _ensure_database_exists(uri: str, log: logging.Logger) -> None:
|
||||
"""Create the target database if a connection reveals it's missing.
|
||||
|
||||
We probe with a lightweight ``connect().close()``. If Postgres
|
||||
reports ``InvalidCatalogName`` (SQLSTATE ``3D000``), we reconnect to
|
||||
the server's ``postgres`` maintenance DB and issue ``CREATE DATABASE``
|
||||
in AUTOCOMMIT mode (required — CREATE DATABASE can't run in a
|
||||
transaction). Any other connection failure (bad host, auth failure,
|
||||
missing role) is re-raised untouched so the operator sees the true
|
||||
cause instead of a mis-directed auto-create attempt.
|
||||
"""
|
||||
# Lazy imports keep module import side-effect free.
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
url = make_url(uri)
|
||||
target_db = url.database
|
||||
if not target_db:
|
||||
raise RuntimeError(
|
||||
f"POSTGRES_URI is missing a database name: {uri!r}. "
|
||||
"Expected something like "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
|
||||
probe_engine = create_engine(uri, pool_pre_ping=False)
|
||||
try:
|
||||
try:
|
||||
conn = probe_engine.connect()
|
||||
except OperationalError as exc:
|
||||
if _is_missing_database(exc):
|
||||
log.info(
|
||||
"ensure_database_ready: database %r is missing; "
|
||||
"creating it...",
|
||||
target_db,
|
||||
)
|
||||
_create_database(url, target_db, log)
|
||||
log.info("ensure_database_ready: database %r ready.", target_db)
|
||||
return
|
||||
# Not a missing-DB error — surface it as-is. This is the path
|
||||
# for bad host/auth/role-missing, and auto-creating would be
|
||||
# actively wrong there.
|
||||
log.error(
|
||||
"ensure_database_ready: cannot connect to Postgres for "
|
||||
"database %r: %s",
|
||||
target_db,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
conn.close()
|
||||
log.info("ensure_database_ready: database %r ready.", target_db)
|
||||
finally:
|
||||
probe_engine.dispose()
|
||||
|
||||
|
||||
def _create_database(url, target_db: str, log: logging.Logger) -> None:
|
||||
"""Issue ``CREATE DATABASE`` against the server's ``postgres`` DB.
|
||||
|
||||
Uses AUTOCOMMIT (required by Postgres — ``CREATE DATABASE`` cannot run
|
||||
inside a transaction). The database identifier is quoted via
|
||||
``psycopg.sql.Identifier`` so unusual names (hyphens, reserved words)
|
||||
are handled correctly.
|
||||
|
||||
Args:
|
||||
url: Parsed SQLAlchemy URL for the target DB; we reuse
|
||||
host/port/credentials and swap the database to ``postgres``.
|
||||
target_db: The target database name to create.
|
||||
log: Logger for INFO/ERROR breadcrumbs.
|
||||
"""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
|
||||
# psycopg is imported lazily — its error classes are the canonical
|
||||
# cause markers Postgres hands us back.
|
||||
import psycopg
|
||||
from psycopg import sql as pg_sql
|
||||
|
||||
maintenance_url = url.set(database="postgres")
|
||||
maintenance_engine = create_engine(
|
||||
maintenance_url,
|
||||
isolation_level="AUTOCOMMIT",
|
||||
pool_pre_ping=False,
|
||||
)
|
||||
try:
|
||||
with maintenance_engine.connect() as conn:
|
||||
# Use psycopg's Identifier to quote the DB name safely. The
|
||||
# SQL object renders as a literal ``CREATE DATABASE "<name>"``
|
||||
# which SQLAlchemy passes through to psycopg verbatim.
|
||||
stmt = pg_sql.SQL("CREATE DATABASE {}").format(
|
||||
pg_sql.Identifier(target_db)
|
||||
)
|
||||
raw = conn.connection.dbapi_connection # psycopg connection
|
||||
with raw.cursor() as cur:
|
||||
try:
|
||||
cur.execute(stmt)
|
||||
except psycopg.errors.DuplicateDatabase:
|
||||
# Another worker won the race — benign.
|
||||
log.info(
|
||||
"ensure_database_ready: database %r already "
|
||||
"created by a concurrent worker; continuing.",
|
||||
target_db,
|
||||
)
|
||||
except psycopg.errors.InsufficientPrivilege as exc:
|
||||
log.error(
|
||||
"ensure_database_ready: role lacks CREATEDB "
|
||||
"privilege to create %r. Either GRANT CREATEDB "
|
||||
"to the role, create the database manually, or "
|
||||
"set AUTO_CREATE_DB=False and provision it "
|
||||
"out-of-band. See docs/Deploying/Postgres-"
|
||||
"Migration for guidance. Underlying error: %s",
|
||||
target_db,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
except (OperationalError, ProgrammingError) as exc:
|
||||
log.error(
|
||||
"ensure_database_ready: failed to create database %r: %s. "
|
||||
"See docs/Deploying/Postgres-Migration for manual setup.",
|
||||
target_db,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
maintenance_engine.dispose()
|
||||
|
||||
|
||||
def _is_missing_database(exc: Exception) -> bool:
|
||||
"""Return True if ``exc`` indicates the target database doesn't exist.
|
||||
|
||||
We check three signals in the cause chain:
|
||||
|
||||
1. ``psycopg.errors.InvalidCatalogName`` — the canonical class for
|
||||
SQLSTATE ``3D000`` when raised during a query.
|
||||
2. ``pgcode`` / ``diag.sqlstate`` equal to ``3D000`` — defensive, for
|
||||
driver versions that surface the code on a generic class.
|
||||
3. The canonical server message phrasing ``database "..." does not
|
||||
exist`` — **required** for connection-time failures, because
|
||||
psycopg 3's ``OperationalError`` raised by ``connect()`` does NOT
|
||||
populate ``sqlstate`` (the connection never completed the protocol
|
||||
handshake, so the attributes stay ``None``). The server's error
|
||||
message itself is stable across Postgres versions, so this is a
|
||||
reliable fallback for the only case that matters: DB missing at
|
||||
boot.
|
||||
"""
|
||||
try:
|
||||
import psycopg
|
||||
|
||||
invalid_catalog = psycopg.errors.InvalidCatalogName
|
||||
except Exception: # noqa: BLE001 — defensive; never break on import
|
||||
invalid_catalog = None
|
||||
|
||||
seen: set[int] = set()
|
||||
cursor: Optional[BaseException] = exc
|
||||
while cursor is not None and id(cursor) not in seen:
|
||||
seen.add(id(cursor))
|
||||
if invalid_catalog is not None and isinstance(cursor, invalid_catalog):
|
||||
return True
|
||||
pgcode = getattr(cursor, "pgcode", None) or getattr(
|
||||
getattr(cursor, "diag", None), "sqlstate", None
|
||||
)
|
||||
if pgcode == "3D000":
|
||||
return True
|
||||
msg = str(cursor)
|
||||
if 'database "' in msg and "does not exist" in msg:
|
||||
return True
|
||||
cursor = cursor.__cause__ or cursor.__context__
|
||||
return False
|
||||
|
||||
|
||||
def _run_migrations(log: logging.Logger) -> None:
|
||||
"""Run ``alembic upgrade head`` against ``POSTGRES_URI``.
|
||||
|
||||
Alembic serializes concurrent workers via its ``alembic_version``
|
||||
table, so no extra application-level locking is needed. Failures are
|
||||
logged and re-raised so the app fails fast.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# Lazy imports — alembic pulls in a fair amount of code.
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# Mirror the discovery path used by scripts/db/init_postgres.py so
|
||||
# both entry points resolve the same alembic.ini regardless of cwd.
|
||||
alembic_ini = Path(__file__).resolve().parents[2] / "alembic.ini"
|
||||
if not alembic_ini.exists():
|
||||
raise RuntimeError(f"alembic.ini not found at {alembic_ini}")
|
||||
|
||||
cfg = Config(str(alembic_ini))
|
||||
cfg.set_main_option("script_location", str(alembic_ini.parent / "alembic"))
|
||||
|
||||
# Cheap pre-check: if we're already at head, say so explicitly.
|
||||
try:
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
head_rev = script.get_current_head()
|
||||
url = cfg.get_main_option("sqlalchemy.url")
|
||||
# env.py populates sqlalchemy.url from settings.POSTGRES_URI when
|
||||
# it's imported, but our Config instance hasn't loaded env.py
|
||||
# yet. Fall back to reading settings directly for the precheck.
|
||||
if not url:
|
||||
from application.core.settings import settings as _settings
|
||||
|
||||
url = _settings.POSTGRES_URI
|
||||
current_rev: Optional[str] = None
|
||||
if url:
|
||||
precheck_engine = create_engine(url, pool_pre_ping=False)
|
||||
try:
|
||||
with precheck_engine.connect() as conn:
|
||||
ctx = MigrationContext.configure(conn)
|
||||
current_rev = ctx.get_current_revision()
|
||||
finally:
|
||||
precheck_engine.dispose()
|
||||
if current_rev is not None and current_rev == head_rev:
|
||||
log.info(
|
||||
"ensure_database_ready: migrations already at head (%s); "
|
||||
"nothing to do.",
|
||||
head_rev,
|
||||
)
|
||||
return
|
||||
log.info(
|
||||
"ensure_database_ready: applying Alembic migrations "
|
||||
"(current=%s, target=%s)...",
|
||||
current_rev,
|
||||
head_rev,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — precheck is best-effort
|
||||
# If the precheck itself fails we still want to try the upgrade;
|
||||
# alembic will give a more actionable error if something's off.
|
||||
log.info(
|
||||
"ensure_database_ready: revision precheck failed (%s); "
|
||||
"proceeding with upgrade anyway.",
|
||||
exc,
|
||||
)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, "head")
|
||||
except Exception as exc: # noqa: BLE001 — surface everything
|
||||
log.error(
|
||||
"ensure_database_ready: alembic upgrade failed: %s. "
|
||||
"Check migration logs and DB connectivity; the app will not "
|
||||
"boot until this is resolved (or AUTO_MIGRATE is disabled).",
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
log.info("ensure_database_ready: migrations applied.")
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
|
||||
migration.
|
||||
|
||||
The helper:
|
||||
|
||||
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
|
||||
call sites add literally zero work.
|
||||
* Opens a transactional connection from the user-data SQLAlchemy engine.
|
||||
* Instantiates the caller's repository class on that connection.
|
||||
* Runs the caller's operation.
|
||||
* Swallows and logs any exception. **Mongo remains the source of truth
|
||||
during the dual-write window** — a Postgres-side failure must never
|
||||
break a user-facing request. Drift that builds up from swallowed
|
||||
failures is caught separately by re-running the backfill script.
|
||||
|
||||
Call sites look like::
|
||||
|
||||
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
|
||||
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
|
||||
|
||||
A single parameterised helper rather than one function per collection
|
||||
means a new collection just needs its repository class — no new helper
|
||||
function, no new feature flag. The whole helper is deleted at Phase 5
|
||||
when the migration is complete.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_Repo = TypeVar("_Repo")
|
||||
|
||||
|
||||
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
|
||||
"""Mirror a Mongo write into Postgres via ``repo_cls``, best-effort.
|
||||
|
||||
No-op when ``settings.USE_POSTGRES`` is false. Any exception
|
||||
(connection pool exhaustion, migration drift, SQL error) is logged
|
||||
and swallowed so the caller's primary Mongo write remains the source
|
||||
of truth.
|
||||
|
||||
Args:
|
||||
repo_cls: The repository class to instantiate (e.g. ``UsersRepository``).
|
||||
fn: A callable that takes the instantiated repository and performs
|
||||
the desired write.
|
||||
"""
|
||||
if not settings.USE_POSTGRES:
|
||||
return
|
||||
|
||||
try:
|
||||
# Lazy import so modules that import dual_write don't pay the
|
||||
# SQLAlchemy import cost when the flag is off.
|
||||
from application.storage.db.engine import get_engine
|
||||
|
||||
with get_engine().begin() as conn:
|
||||
fn(repo_cls(conn))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Postgres dual-write failed for %s — Mongo write already committed",
|
||||
repo_cls.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -40,9 +40,21 @@ def _resolve_uri() -> str:
|
||||
return settings.POSTGRES_URI
|
||||
|
||||
|
||||
#: Per-statement wall-clock cap applied to every connection handed out by
|
||||
#: the engine. 30s is generous for interactive hot paths (reads under a few
|
||||
#: hundred ms are normal) but still catches a runaway query before it
|
||||
#: stacks up on PgBouncer or holds locks indefinitely. Override by
|
||||
#: rebuilding the engine with a different ``connect_args`` in tests.
|
||||
STATEMENT_TIMEOUT_MS = 30_000
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||
|
||||
The engine applies a server-side ``statement_timeout`` to every
|
||||
connection it hands out, so both :func:`db_session` and
|
||||
:func:`db_readonly` inherit the same guardrail.
|
||||
|
||||
Returns:
|
||||
A SQLAlchemy ``Engine`` configured with a pooled connection to
|
||||
Postgres via psycopg3.
|
||||
@@ -56,6 +68,12 @@ def get_engine() -> Engine:
|
||||
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
|
||||
pool_recycle=1800,
|
||||
future=True,
|
||||
connect_args={
|
||||
# ``-c`` passes a GUC to the backend at connect time. This
|
||||
# covers *all* sessions — interactive, Celery, seeder — so
|
||||
# no route-handler can opt out by accident.
|
||||
"options": f"-c statement_timeout={STATEMENT_TIMEOUT_MS}",
|
||||
},
|
||||
)
|
||||
return _engine
|
||||
|
||||
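A quick check of the guardrail described above, assuming POSTGRES_URI points at a reachable database:

from sqlalchemy import text
from application.storage.db.engine import get_engine

with get_engine().connect() as conn:
    print(conn.execute(text("SHOW statement_timeout")).scalar())  # expected "30s" with STATEMENT_TIMEOUT_MS = 30_000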
|
||||
@@ -71,9 +71,14 @@ user_tools_table = Table(
|
||||
Column("name", Text, nullable=False),
|
||||
Column("custom_name", Text),
|
||||
Column("display_name", Text),
|
||||
Column("description", Text),
|
||||
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||
Column("config_requirements", JSONB, nullable=False, server_default="{}"),
|
||||
Column("actions", JSONB, nullable=False, server_default="[]"),
|
||||
Column("status", Boolean, nullable=False, server_default="true"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
token_usage_table = Table(
|
||||
@@ -122,8 +127,10 @@ agent_folders_table = Table(
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("description", Text),
|
||||
Column("parent_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
sources_table = Table(
|
||||
@@ -132,10 +139,21 @@ sources_table = Table(
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("language", Text),
|
||||
Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("model", Text),
|
||||
Column("type", Text),
|
||||
Column("metadata", JSONB, nullable=False, server_default="{}"),
|
||||
Column("retriever", Text),
|
||||
Column("sync_frequency", Text),
|
||||
Column("tokens", Text),
|
||||
Column("file_path", Text),
|
||||
Column("remote_data", JSONB),
|
||||
Column("directory_structure", JSONB),
|
||||
Column("file_name_map", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
agents_table = Table(
|
||||
@@ -148,6 +166,7 @@ agents_table = Table(
|
||||
Column("agent_type", Text),
|
||||
Column("status", Text, nullable=False),
|
||||
Column("key", CITEXT, unique=True),
|
||||
Column("image", Text),
|
||||
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
|
||||
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||
Column("chunks", Integer),
|
||||
@@ -158,11 +177,15 @@ agents_table = Table(
|
||||
Column("models", JSONB),
|
||||
Column("default_model_id", Text),
|
||||
Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
|
||||
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="SET NULL")),
|
||||
Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
|
||||
Column("token_limit", Integer),
|
||||
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
|
||||
Column("request_limit", Integer),
|
||||
Column("allow_system_prompt_override", Boolean, nullable=False, server_default="false"),
|
||||
Column("shared", Boolean, nullable=False, server_default="false"),
|
||||
Column("shared_token", CITEXT, unique=True),
|
||||
Column("shared_metadata", JSONB),
|
||||
Column("incoming_webhook_token", CITEXT, unique=True),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
@@ -179,6 +202,11 @@ attachments_table = Table(
|
||||
Column("upload_path", Text, nullable=False),
|
||||
Column("mime_type", Text),
|
||||
Column("size", BigInteger),
|
||||
Column("content", Text),
|
||||
Column("token_count", Integer),
|
||||
Column("openai_file_id", Text),
|
||||
Column("google_file_uri", Text),
|
||||
Column("metadata", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
@@ -191,6 +219,7 @@ memories_table = Table(
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
Column("path", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
|
||||
)
|
||||
@@ -201,10 +230,12 @@ todos_table = Table(
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
Column("todo_id", Integer),
|
||||
Column("title", Text, nullable=False),
|
||||
Column("completed", Boolean, nullable=False, server_default="false"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
notes_table = Table(
|
||||
@@ -226,10 +257,15 @@ connector_sessions_table = Table(
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("provider", Text, nullable=False),
|
||||
Column("session_data", JSONB, nullable=False),
|
||||
Column("server_url", Text),
|
||||
Column("session_token", Text, unique=True),
|
||||
Column("user_email", Text),
|
||||
Column("status", Text),
|
||||
Column("token_info", JSONB),
|
||||
Column("session_data", JSONB, nullable=False, server_default="{}"),
|
||||
Column("expires_at", DateTime(timezone=True)),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
@@ -1,28 +1,60 @@
"""Repository for the ``agent_folders`` table."""
"""Repository for the ``agent_folders`` table.

Folders are self-referential via ``parent_id`` to model nested folder
hierarchies — a folder can sit inside another folder, and on delete the
DB sets each child's ``parent_id`` to NULL (no cascade) so children
survive their parent's removal but flatten to the top level. The legacy
Mongo route used ``$unset: {parent_id: ""}`` against children before
deleting the parent; that pre-step is no longer needed because the FK
``ON DELETE SET NULL`` action does it automatically.
"""

from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy import Connection, func, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import agent_folders_table
|
||||
|
||||
|
||||
_ALLOWED_UPDATE_COLUMNS = {"name", "description", "parent_id"}
|
||||
|
||||
|
||||
class AgentFoldersRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, name: str, *, description: Optional[str] = None) -> dict:
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO agent_folders (user_id, name, description)
|
||||
VALUES (:user_id, :name, :description)
|
||||
INSERT INTO agent_folders (
|
||||
user_id, name, description, parent_id, legacy_mongo_id
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :name, :description,
|
||||
CAST(:parent_id AS uuid), :legacy_mongo_id
|
||||
)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "name": name, "description": description},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parent_id": str(parent_id) if parent_id else None,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
@@ -34,6 +66,19 @@ class AgentFoldersRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM agent_folders WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
|
||||
@@ -41,46 +86,53 @@ class AgentFoldersRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, folder_id: str, user_id: str, fields: dict) -> bool:
|
||||
allowed = {"name", "description"}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
def list_children(self, parent_id: str, user_id: str) -> list[dict]:
|
||||
"""List immediate children of ``parent_id`` for nested-folder UIs."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM agent_folders "
|
||||
"WHERE parent_id = CAST(:parent_id AS uuid) AND user_id = :user_id "
|
||||
"ORDER BY created_at"
|
||||
),
|
||||
{"parent_id": parent_id, "user_id": user_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, folder_id: str, user_id: str, fields: dict[str, Any]) -> bool:
|
||||
"""Partial update.
|
||||
|
||||
The route validates that ``parent_id != folder_id`` (no self-parenting)
|
||||
before calling here; this layer does not re-check.
|
||||
"""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATE_COLUMNS}
|
||||
if not filtered:
|
||||
return False
|
||||
params: dict = {"id": folder_id, "user_id": user_id}
|
||||
if "name" in filtered and "description" in filtered:
|
||||
params["name"] = filtered["name"]
|
||||
params["description"] = filtered["description"]
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE agent_folders "
|
||||
"SET name = :name, description = :description, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
elif "name" in filtered:
|
||||
params["name"] = filtered["name"]
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE agent_folders "
|
||||
"SET name = :name, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
else:
|
||||
params["description"] = filtered["description"]
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE agent_folders "
|
||||
"SET description = :description, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
values: dict = {}
|
||||
for col, val in filtered.items():
|
||||
if col == "parent_id":
|
||||
values[col] = str(val) if val else None
|
||||
else:
|
||||
values[col] = val
|
||||
values["updated_at"] = func.now()
|
||||
|
||||
t = agent_folders_table
|
||||
stmt = (
|
||||
t.update()
|
||||
.where(t.c.id == folder_id)
|
||||
.where(t.c.user_id == user_id)
|
||||
.values(**values)
|
||||
)
|
||||
result = self._conn.execute(stmt)
|
||||
return result.rowcount > 0
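# Usage sketch (hypothetical ids, not part of the diff): the Core-built
# UPDATE accepts any combination of the whitelisted columns, including
# re-parenting a folder or detaching it by passing parent_id=None.
#
#   repo = AgentFoldersRepository(conn)
#   repo.update(folder_id, user_id, {"name": "Drafts"})
#   repo.update(folder_id, user_id, {"parent_id": other_folder_id})
#   repo.update(folder_id, user_id, {"parent_id": None})  # back to top level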
|
||||
|
||||
def delete(self, folder_id: str, user_id: str) -> bool:
|
||||
"""Delete a folder.
|
||||
|
||||
The schema's ``ON DELETE SET NULL`` on the self-FK takes care of
|
||||
un-parenting any child folders, and the agents table's
|
||||
``folder_id`` FK does the same for agents in the folder.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": folder_id, "user_id": user_id},
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, func, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import agents_table
|
||||
|
||||
|
||||
@@ -38,16 +38,20 @@ class AgentsRepository:
|
||||
_ALLOWED = {
|
||||
"description", "agent_type", "key", "retriever",
|
||||
"default_model_id", "incoming_webhook_token",
|
||||
"source_id", "prompt_id", "folder_id",
|
||||
"source_id", "prompt_id", "folder_id", "workflow_id",
|
||||
"extra_source_ids", "image",
|
||||
"chunks", "token_limit", "request_limit",
|
||||
"limited_token_mode", "limited_request_mode", "shared",
|
||||
"limited_token_mode", "limited_request_mode",
|
||||
"allow_system_prompt_override",
|
||||
"shared", "shared_token", "shared_metadata",
|
||||
"tools", "json_schema", "models", "legacy_mongo_id",
|
||||
"created_at", "updated_at", "last_used_at",
|
||||
}
|
||||
|
||||
for col, val in kwargs.items():
|
||||
if col not in _ALLOWED or val is None:
|
||||
continue
|
||||
if col in ("tools", "json_schema", "models"):
|
||||
if col in ("tools", "json_schema", "models", "shared_metadata"):
|
||||
# JSONB columns: pass the Python object directly. SQLAlchemy
|
||||
# Core's JSONB type processor json.dumps it once during
|
||||
# bind; pre-serialising would double-encode and the value
|
||||
@@ -55,10 +59,16 @@ class AgentsRepository:
|
||||
values[col] = val
|
||||
elif col in ("chunks", "token_limit", "request_limit"):
|
||||
values[col] = int(val)
|
||||
elif col in ("limited_token_mode", "limited_request_mode", "shared"):
|
||||
elif col in (
|
||||
"limited_token_mode", "limited_request_mode",
|
||||
"shared", "allow_system_prompt_override",
|
||||
):
|
||||
values[col] = bool(val)
|
||||
elif col in ("source_id", "prompt_id", "folder_id"):
|
||||
elif col in ("source_id", "prompt_id", "folder_id", "workflow_id"):
|
||||
values[col] = str(val)
|
||||
elif col == "extra_source_ids":
|
||||
# ARRAY(UUID) — pass list of strings; psycopg adapts it.
|
||||
values[col] = [str(x) for x in val] if val else []
|
||||
else:
|
||||
values[col] = self._normalize_unique_text(col, val)
|
||||
|
||||
@@ -74,8 +84,23 @@ class AgentsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, agent_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve an agent by either PG UUID or legacy Mongo ObjectId string.
|
||||
|
||||
Cutover helper: URLs / bookmarks / old client state may still hold
|
||||
Mongo ObjectId-strings. Try the UUID path first (the post-cutover
|
||||
shape) and fall back to ``legacy_mongo_id`` — both are scoped by
|
||||
``user_id`` so cross-user access is impossible.
|
||||
"""
|
||||
if looks_like_uuid(agent_id):
|
||||
row = self.get(agent_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return self.get_by_legacy_id(agent_id, user_id)
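# Illustrative sketch (not part of the diff): ``looks_like_uuid`` is imported
# from ``base_repository`` above but its body is not shown here. A minimal
# implementation consistent with how it is used here (accept anything the
# ``uuid`` module can parse, reject Mongo ObjectId strings) could look like
# this; the real helper may differ.
import uuid

def looks_like_uuid(value: str) -> bool:
    """Return True if ``value`` parses as a UUID."""
    try:
        uuid.UUID(str(value))
    except (ValueError, TypeError):
        return False
    return True

# A 24-char hex Mongo ObjectId such as "64b7f0c2a1b2c3d4e5f60718" fails the
# parse (wrong length), so get_any() falls through to the legacy lookup.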
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch an agent by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
@@ -93,6 +118,23 @@ class AgentsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def find_by_shared_token(self, token: str) -> Optional[dict]:
|
||||
"""Resolve a publicly-shared agent by its rotating share token.
|
||||
|
||||
Only returns rows with ``shared = true`` so revoking a share
|
||||
(setting ``shared = false``) immediately stops token access even
|
||||
if the token value itself is still in the row.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM agents "
|
||||
"WHERE shared_token = :token AND shared = true"
|
||||
),
|
||||
{"token": token},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
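# Usage sketch (hypothetical calling code, not from the diff): revoking a
# share only needs the ``shared`` flag flipped; the token can stay in the
# row and the lookup above still stops matching.
#
#   repo = AgentsRepository(conn)
#   agent = repo.find_by_shared_token("tok_abc123")      # hits while shared
#   repo.update(agent["id"], agent["user_id"], {"shared": False})
#   assert repo.find_by_shared_token("tok_abc123") is None  # revoked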
|
||||
|
||||
def find_by_webhook_token(self, token: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
|
||||
@@ -118,8 +160,12 @@ class AgentsRepository:
|
||||
allowed = {
|
||||
"name", "description", "agent_type", "status", "key", "source_id",
|
||||
"chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
|
||||
"default_model_id", "folder_id", "limited_token_mode", "token_limit",
|
||||
"limited_request_mode", "request_limit", "shared",
|
||||
"default_model_id", "folder_id", "workflow_id",
|
||||
"extra_source_ids", "image",
|
||||
"limited_token_mode", "token_limit",
|
||||
"limited_request_mode", "request_limit",
|
||||
"allow_system_prompt_override",
|
||||
"shared", "shared_token", "shared_metadata",
|
||||
"incoming_webhook_token", "last_used_at",
|
||||
}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
@@ -128,12 +174,19 @@ class AgentsRepository:
|
||||
|
||||
values: dict = {}
|
||||
for col, val in filtered.items():
|
||||
if col in ("tools", "json_schema", "models"):
|
||||
if col in ("tools", "json_schema", "models", "shared_metadata"):
|
||||
# See note in create(): JSONB columns receive Python
|
||||
# objects, the type processor handles serialisation.
|
||||
values[col] = val
|
||||
elif col in ("source_id", "prompt_id", "folder_id"):
|
||||
elif col in ("source_id", "prompt_id", "folder_id", "workflow_id"):
|
||||
values[col] = str(val) if val else None
|
||||
elif col == "extra_source_ids":
|
||||
values[col] = [str(x) for x in val] if val else []
|
||||
elif col in (
|
||||
"limited_token_mode", "limited_request_mode",
|
||||
"shared", "allow_system_prompt_override",
|
||||
):
|
||||
values[col] = bool(val)
|
||||
else:
|
||||
values[col] = self._normalize_unique_text(col, val)
|
||||
values["updated_at"] = func.now()
|
||||
@@ -150,6 +203,7 @@ class AgentsRepository:
|
||||
|
||||
def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Update an agent addressed by the Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
|
||||
if agent is None:
|
||||
return False
|
||||
@@ -164,6 +218,7 @@ class AgentsRepository:
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete an agent addressed by the Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM agents "
|
||||
|
||||
@@ -2,27 +2,53 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
|
||||
|
||||
_UPDATABLE_SCALARS = {
|
||||
"filename", "upload_path", "mime_type", "size",
|
||||
"content", "token_count", "openai_file_id", "google_file_uri",
|
||||
}
|
||||
_UPDATABLE_JSONB = {"metadata"}
|
||||
|
||||
|
||||
class AttachmentsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, filename: str, upload_path: str, *,
|
||||
mime_type: Optional[str] = None, size: Optional[int] = None,
|
||||
legacy_mongo_id: Optional[str] = None) -> dict:
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
upload_path: str,
|
||||
*,
|
||||
mime_type: Optional[str] = None,
|
||||
size: Optional[int] = None,
|
||||
content: Optional[str] = None,
|
||||
token_count: Optional[int] = None,
|
||||
openai_file_id: Optional[str] = None,
|
||||
google_file_uri: Optional[str] = None,
|
||||
metadata: Any = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO attachments
|
||||
(user_id, filename, upload_path, mime_type, size, legacy_mongo_id)
|
||||
VALUES
|
||||
(:user_id, :filename, :upload_path, :mime_type, :size, :legacy_mongo_id)
|
||||
INSERT INTO attachments (
|
||||
user_id, filename, upload_path, mime_type, size,
|
||||
content, token_count, openai_file_id, google_file_uri,
|
||||
metadata, legacy_mongo_id
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :filename, :upload_path, :mime_type, :size,
|
||||
:content, :token_count, :openai_file_id, :google_file_uri,
|
||||
CAST(:metadata AS jsonb), :legacy_mongo_id
|
||||
)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
@@ -32,6 +58,11 @@ class AttachmentsRepository:
|
||||
"upload_path": upload_path,
|
||||
"mime_type": mime_type,
|
||||
"size": size,
|
||||
"content": content,
|
||||
"token_count": token_count,
|
||||
"openai_file_id": openai_file_id,
|
||||
"google_file_uri": google_file_uri,
|
||||
"metadata": json.dumps(metadata) if metadata is not None else None,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
@@ -47,8 +78,76 @@ class AttachmentsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
|
||||
if looks_like_uuid(attachment_id):
|
||||
row = self.get(attachment_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return self.get_by_legacy_id(attachment_id, user_id)
|
||||
|
||||
def resolve_ids(self, ids: list[str]) -> dict[str, str]:
|
||||
"""Batch-resolve a list of attachment ids (PG UUID *or* Mongo
|
||||
ObjectId or post-cutover route-minted UUID stored only in
|
||||
``legacy_mongo_id``) to their canonical PG ``attachments.id``.
|
||||
|
||||
Returns a ``{input_id: pg_uuid}`` map. Inputs that don't match
|
||||
any row are simply absent from the map (caller decides whether
|
||||
to drop or keep). Single round-trip via ``= ANY(:ids)`` to
|
||||
avoid N+1.
|
||||
|
||||
Resolution prefers ``legacy_mongo_id`` matches first, since
|
||||
the post-cutover ``/store_attachment`` route mints a UUID that
|
||||
is UUID-shaped but only ever lives in ``legacy_mongo_id``
|
||||
(the row's own ``id`` is a fresh PG-generated UUID). A
|
||||
UUID-shaped input that is *also* a real ``attachments.id``
|
||||
falls back to the direct PK match.
|
||||
"""
|
||||
if not ids:
|
||||
return {}
|
||||
# Deduplicate while preserving order for stable output mapping.
|
||||
unique_ids: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw in ids:
|
||||
if raw is None:
|
||||
continue
|
||||
s = str(raw)
|
||||
if s in seen:
|
||||
continue
|
||||
seen.add(s)
|
||||
unique_ids.append(s)
|
||||
if not unique_ids:
|
||||
return {}
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT id::text AS id, legacy_mongo_id "
|
||||
"FROM attachments "
|
||||
"WHERE legacy_mongo_id = ANY(:ids) "
|
||||
"OR id::text = ANY(:ids)"
|
||||
),
|
||||
{"ids": unique_ids},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
# Build two indexes so we can apply the legacy-first preference.
|
||||
by_legacy: dict[str, str] = {}
|
||||
by_pk: dict[str, str] = {}
|
||||
for row in rows:
|
||||
pg_id = str(row[0])
|
||||
legacy = row[1]
|
||||
by_pk[pg_id] = pg_id
|
||||
if legacy is not None:
|
||||
by_legacy[str(legacy)] = pg_id
|
||||
out: dict[str, str] = {}
|
||||
for s in unique_ids:
|
||||
if s in by_legacy:
|
||||
out[s] = by_legacy[s]
|
||||
elif s in by_pk:
|
||||
out[s] = by_pk[s]
|
||||
return out
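# Usage sketch (hypothetical ids, not from the diff): a mixed batch of a
# Mongo ObjectId, a route-minted UUID stored only in ``legacy_mongo_id``,
# and an unknown id resolves in one round trip; the unknown key is absent.
#
#   repo = AttachmentsRepository(conn)
#   mapping = repo.resolve_ids([
#       "64b7f0c2a1b2c3d4e5f60718",              # legacy ObjectId
#       "9f1d2c34-5678-4abc-9def-0123456789ab",  # route-minted UUID
#       "deadbeefdeadbeefdeadbeef",              # unknown -> dropped
#   ])
#   pg_ids = list(mapping.values())              # canonical PG UUIDs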
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch an attachment by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
@@ -64,3 +163,86 @@ class AttachmentsRepository:
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Partial update. Used by the LLM providers to cache their
|
||||
uploaded file IDs (``openai_file_id`` / ``google_file_uri``) so we
|
||||
don't re-upload the same blob every call.
|
||||
"""
|
||||
filtered = {
|
||||
k: v for k, v in fields.items()
|
||||
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||
}
|
||||
if not filtered:
|
||||
return False
|
||||
set_clauses: list[str] = []
|
||||
params: dict = {"id": attachment_id, "user_id": user_id}
|
||||
for col, val in filtered.items():
|
||||
if col in _UPDATABLE_JSONB:
|
||||
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = json.dumps(val) if val is not None else None
|
||||
else:
|
||||
set_clauses.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
f"UPDATE attachments SET {', '.join(set_clauses)} "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.rowcount > 0
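# Usage sketch (hypothetical values): the LLM provider paths cache the
# uploaded-file handle back onto the attachment row so the next call can
# skip the re-upload. Field names below mirror _UPDATABLE_SCALARS.
#
#   repo = AttachmentsRepository(conn)
#   repo.update(att["id"], user_id, {"openai_file_id": "file-abc123"})
#   # or, when only the Mongo _id is in hand during dual-write:
#   repo.update_any(att["_id"], user_id, {"google_file_uri": "files/xyz789"})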
|
||||
|
||||
def update_any(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Partial update addressed by either PG UUID or legacy Mongo ObjectId.
|
||||
|
||||
Cutover helper used by the LLM provider file-ID caching hot path:
|
||||
the attachment dict in hand may carry a UUID (post-cutover shape)
|
||||
or an ObjectId-string ``_id`` (legacy). Try the UUID path first
|
||||
when the id looks like a UUID; otherwise fall back to the
|
||||
``legacy_mongo_id`` update. Both branches are user-scoped: the
|
||||
caller must pass the authenticated ``user_id`` so cross-tenant
|
||||
writes are prevented even when the fallback legacy path fires.
|
||||
"""
|
||||
if looks_like_uuid(attachment_id):
|
||||
if self.update(attachment_id, user_id, fields):
|
||||
return True
|
||||
return self.update_by_legacy_id(attachment_id, user_id, fields)
|
||||
|
||||
def update_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: str, fields: dict
|
||||
) -> bool:
|
||||
"""Like ``update`` but addressed by the Mongo ObjectId string.
|
||||
|
||||
Used by the LLM file-ID caching path which, at dual-write time,
|
||||
only has the Mongo ``_id`` in hand (the PG UUID hasn't been
|
||||
looked up yet). Scoped by ``user_id`` so a caller that happens to
|
||||
pass an id matching another user's ``legacy_mongo_id`` cannot
|
||||
mutate the wrong row (IDOR).
|
||||
"""
|
||||
if user_id is None:
|
||||
return False
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
filtered = {
|
||||
k: v for k, v in fields.items()
|
||||
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||
}
|
||||
if not filtered:
|
||||
return False
|
||||
set_clauses: list[str] = []
|
||||
params: dict = {"legacy_id": legacy_mongo_id, "user_id": user_id}
|
||||
for col, val in filtered.items():
|
||||
if col in _UPDATABLE_JSONB:
|
||||
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = json.dumps(val) if val is not None else None
|
||||
else:
|
||||
set_clauses.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
f"UPDATE attachments SET {', '.join(set_clauses)} "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
@@ -1,55 +1,168 @@
"""Repository for the ``connector_sessions`` table.

Covers operations across connector routes and tools:
- upsert session data
- find session by user + provider
- find session by token
- delete session
Shape notes:

* OAuth connectors (Google Drive, SharePoint, Confluence) write one row
per ``(user_id, provider)`` with ``server_url = NULL``. The primary
lookup key post-callback is ``session_token`` (see
``complete_oauth`` style routes), so the table has a standalone
unique constraint on ``session_token``.
* MCP sessions key off ``server_url`` instead — a single user may have
multiple MCP servers, one row each. The composite unique index
``(user_id, COALESCE(server_url, ''), provider)`` makes both patterns
coexist without collision.
* ``session_data`` remains a catch-all JSONB for driver-specific state
(tokens that don't fit anywhere else, per-provider scratch data).
Promoted columns (``session_token``, ``user_email``, ``status``,
``token_info``) are the ones route/auth code queries by.
"""

from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_UPDATABLE_SCALARS = {
|
||||
"server_url", "session_token", "user_email", "status", "expires_at",
|
||||
}
|
||||
_UPDATABLE_JSONB = {"session_data", "token_info"}
|
||||
|
||||
|
||||
def _jsonb(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, default=str)
|
||||
|
||||
|
||||
class ConnectorSessionsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def upsert(self, user_id: str, provider: str, session_data: dict) -> dict:
|
||||
def upsert(
|
||||
self,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
session_data: Optional[dict] = None,
|
||||
*,
|
||||
server_url: Optional[str] = None,
|
||||
session_token: Optional[str] = None,
|
||||
user_email: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
token_info: Optional[dict] = None,
|
||||
expires_at: Any = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Insert or update a connector session row.
|
||||
|
||||
Conflict key is ``(user_id, COALESCE(server_url, ''), provider)``
|
||||
so MCP rows (per-server) and OAuth rows (per-provider) both get
|
||||
idempotent upsert semantics.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO connector_sessions (user_id, provider, session_data)
|
||||
VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
|
||||
ON CONFLICT (user_id, provider)
|
||||
DO UPDATE SET session_data = EXCLUDED.session_data
|
||||
INSERT INTO connector_sessions (
|
||||
user_id, provider, server_url, session_token, user_email,
|
||||
status, token_info, session_data, expires_at, legacy_mongo_id
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :provider, :server_url, :session_token, :user_email,
|
||||
:status, CAST(:token_info AS jsonb),
|
||||
CAST(:session_data AS jsonb), :expires_at, :legacy_mongo_id
|
||||
)
|
||||
ON CONFLICT (user_id, COALESCE(server_url, ''), provider)
|
||||
DO UPDATE SET
|
||||
session_token = COALESCE(EXCLUDED.session_token, connector_sessions.session_token),
|
||||
user_email = COALESCE(EXCLUDED.user_email, connector_sessions.user_email),
|
||||
status = COALESCE(EXCLUDED.status, connector_sessions.status),
|
||||
token_info = COALESCE(EXCLUDED.token_info, connector_sessions.token_info),
|
||||
session_data = EXCLUDED.session_data,
|
||||
expires_at = COALESCE(EXCLUDED.expires_at, connector_sessions.expires_at)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"session_data": json.dumps(session_data),
|
||||
"server_url": server_url,
|
||||
"session_token": session_token,
|
||||
"user_email": user_email,
|
||||
"status": status,
|
||||
"token_info": _jsonb(token_info),
|
||||
"session_data": _jsonb(session_data or {}),
|
||||
"expires_at": expires_at,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
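# Schema note (illustrative, not part of the diff): the ON CONFLICT target
# (user_id, COALESCE(server_url, ''), provider) used above can only be
# inferred from a matching *expression* unique index; the old plain
# (user_id, provider) constraint does not satisfy it. A migration creating
# such an index might look like this; the index name is an assumption.
from sqlalchemy import text

def ensure_connector_sessions_conflict_index(conn) -> None:
    """Create the expression unique index the upsert's conflict target needs."""
    conn.execute(text(
        """
        CREATE UNIQUE INDEX IF NOT EXISTS
            connector_sessions_user_server_provider_uidx
        ON connector_sessions (user_id, COALESCE(server_url, ''), provider)
        """
    ))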
|
||||
|
||||
def get_by_user_provider(self, user_id: str, provider: str) -> Optional[dict]:
|
||||
def get_by_user_provider(
|
||||
self, user_id: str, provider: str, *, server_url: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Legacy (user_id, provider) lookup, optionally scoped by server_url.
|
||||
|
||||
Kept for OAuth providers that only have one row per user — they
|
||||
pass ``server_url=None`` and get the single OAuth row.
|
||||
"""
|
||||
sql = (
|
||||
"SELECT * FROM connector_sessions "
|
||||
"WHERE user_id = :user_id AND provider = :provider"
|
||||
)
|
||||
params: dict[str, Any] = {"user_id": user_id, "provider": provider}
|
||||
if server_url is not None:
|
||||
sql += " AND server_url = :server_url"
|
||||
params["server_url"] = server_url
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_session_token(self, session_token: str) -> Optional[dict]:
|
||||
"""Post-OAuth-callback lookup.
|
||||
|
||||
Every OAuth flow (Google Drive, SharePoint, Confluence) redirects
|
||||
back with the ``session_token`` as the only handle; the callback
|
||||
route resolves it to the full session row.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM connector_sessions WHERE session_token = :token"),
|
||||
{"token": session_token},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_user_and_server_url(
|
||||
self, user_id: str, server_url: str,
|
||||
) -> Optional[dict]:
|
||||
"""MCP-tool lookup: resolve a session by the MCP server URL."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"
|
||||
"SELECT * FROM connector_sessions "
|
||||
"WHERE user_id = :user_id AND server_url = :server_url "
|
||||
"LIMIT 1"
|
||||
),
|
||||
{"user_id": user_id, "provider": provider},
|
||||
{"user_id": user_id, "server_url": server_url},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM connector_sessions WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
|
||||
@@ -57,9 +170,147 @@ class ConnectorSessionsRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def delete(self, user_id: str, provider: str) -> bool:
|
||||
def update(self, session_id: str, fields: dict) -> bool:
|
||||
"""Partial update by PG UUID."""
|
||||
filtered = {
|
||||
k: v for k, v in fields.items()
|
||||
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||
}
|
||||
if not filtered:
|
||||
return False
|
||||
set_clauses: list[str] = []
|
||||
params: dict = {"id": session_id}
|
||||
for col, val in filtered.items():
|
||||
if col in _UPDATABLE_JSONB:
|
||||
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = _jsonb(val)
|
||||
else:
|
||||
set_clauses.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"),
|
||||
{"user_id": user_id, "provider": provider},
|
||||
text(
|
||||
f"UPDATE connector_sessions SET {', '.join(set_clauses)} "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def update_by_legacy_id(self, legacy_mongo_id: str, fields: dict) -> bool:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
filtered = {
|
||||
k: v for k, v in fields.items()
|
||||
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||
}
|
||||
if not filtered:
|
||||
return False
|
||||
set_clauses: list[str] = []
|
||||
params: dict = {"legacy_id": legacy_mongo_id}
|
||||
for col, val in filtered.items():
|
||||
if col in _UPDATABLE_JSONB:
|
||||
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = _jsonb(val)
|
||||
else:
|
||||
set_clauses.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
f"UPDATE connector_sessions SET {', '.join(set_clauses)} "
|
||||
"WHERE legacy_mongo_id = :legacy_id"
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def merge_session_data(
|
||||
self,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
server_url: Optional[str],
|
||||
patch: dict,
|
||||
) -> dict:
|
||||
"""Upsert by shallow-merging ``patch`` into ``session_data``.
|
||||
|
||||
Writes ``server_url`` to the scalar column so downstream
|
||||
``get_by_user_and_server_url`` lookups can find the row. If
|
||||
``patch`` still carries a ``"server_url"`` key (legacy callers)
|
||||
it is stripped before merging so the scalar column stays the
|
||||
single source of truth and we don't duplicate it inside the
|
||||
JSONB blob.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the session.
|
||||
provider: Provider tag (e.g. ``"mcp:<base_url>"`` for MCP).
|
||||
server_url: Endpoint to pin the row to. ``None`` is valid
|
||||
for single-row-per-user OAuth providers.
|
||||
patch: Shallow-merge payload for ``session_data``. Keys
|
||||
mapped to ``None`` are *dropped* from the stored doc
|
||||
(used by the redirect-URI-mismatch clear path).
|
||||
|
||||
Returns:
|
||||
The upserted row as a dict.
|
||||
|
||||
Notes:
|
||||
The conflict target matches the table's composite unique
|
||||
constraint ``(user_id, COALESCE(server_url, ''), provider)``
|
||||
so MCP's per-URL rows and OAuth's single-row-per-user rows
|
||||
both upsert idempotently.
|
||||
"""
|
||||
# Defensively strip ``server_url`` from ``patch`` — the scalar
|
||||
# column is authoritative now. Callers still pass it for
|
||||
# backwards compatibility during the transition.
|
||||
patch = {k: v for k, v in patch.items() if k != "server_url"}
|
||||
set_entries = {k: v for k, v in patch.items() if v is not None}
|
||||
drop_keys = [k for k, v in patch.items() if v is None]
|
||||
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO connector_sessions (
|
||||
user_id, provider, server_url, session_data
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :provider, :server_url,
|
||||
CAST(:patch AS jsonb)
|
||||
)
|
||||
ON CONFLICT (user_id, COALESCE(server_url, ''), provider)
|
||||
DO UPDATE SET
|
||||
server_url = COALESCE(EXCLUDED.server_url, connector_sessions.server_url),
|
||||
session_data =
|
||||
(connector_sessions.session_data || EXCLUDED.session_data)
|
||||
- CAST(:drop_keys AS text[])
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"server_url": server_url,
|
||||
"patch": json.dumps(set_entries),
|
||||
"drop_keys": "{" + ",".join(f'"{k}"' for k in drop_keys) + "}",
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
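# Merge semantics sketch (hypothetical values): ``||`` shallow-merges the
# patch over the stored document and ``- text[]`` drops the keys the caller
# mapped to None, mirroring a Mongo ``$set``/``$unset`` pair in one statement.
#
#   stored = {"flow": "oauth", "redirect_uri": "https://old.example/cb"}
#   patch  = {"flow": "device", "redirect_uri": None}
#   # set_entries -> {"flow": "device"}; drop_keys -> ["redirect_uri"]
#   # stored || set_entries              keeps the stale redirect_uri
#   # ... - '{"redirect_uri"}'::text[]   then drops it -> {"flow": "device"}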
|
||||
|
||||
def delete(
|
||||
self, user_id: str, provider: str, *, server_url: Optional[str] = None,
|
||||
) -> bool:
|
||||
sql = (
|
||||
"DELETE FROM connector_sessions "
|
||||
"WHERE user_id = :user_id AND provider = :provider"
|
||||
)
|
||||
params: dict[str, Any] = {"user_id": user_id, "provider": provider}
|
||||
if server_url is not None:
|
||||
sql += " AND server_url = :server_url"
|
||||
params["server_url"] = server_url
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete_by_session_token(self, session_token: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM connector_sessions WHERE session_token = :token"
|
||||
),
|
||||
{"token": session_token},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||
|
||||
|
||||
@@ -38,6 +38,86 @@ class ConversationsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
# Reference translation helpers
# ------------------------------------------------------------------
#
# During the Mongo→Postgres dual-write window, callers routinely
# hand us Mongo ObjectId strings (24-char hex) for fields that are
# UUID FKs in Postgres (``agent_id``, ``attachments`` entries, ...).
# Casting those straight to ``uuid`` raises and the outer dual-write
# shim swallows the exception, so the write silently drops. These
# helpers translate via the ``legacy_mongo_id`` columns we added
# precisely for this purpose.

def _resolve_agent_ref(self, agent_id_raw: str | None) -> str | None:
|
||||
"""Translate ``agent_id_raw`` to a Postgres UUID string.
|
||||
|
||||
- ``None``/empty → ``None`` (no agent).
|
||||
- Already-UUID-shaped → returned as-is.
|
||||
- Otherwise treated as a Mongo ObjectId and looked up via
|
||||
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
|
||||
exists yet (e.g. the agent was created before Phase 1
|
||||
backfill).
|
||||
"""
|
||||
if not agent_id_raw:
|
||||
return None
|
||||
value = str(agent_id_raw)
|
||||
if looks_like_uuid(value):
|
||||
return value
|
||||
result = self._conn.execute(
|
||||
text("SELECT id FROM agents WHERE legacy_mongo_id = :lid LIMIT 1"),
|
||||
{"lid": value},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return str(row[0]) if row is not None else None
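# Behaviour sketch (hypothetical ids): the three input shapes this helper
# handles and what it hands to the INSERT below.
#
#   _resolve_agent_ref(None)                                    # -> None
#   _resolve_agent_ref("3f2c1b0a-1111-4222-8333-444455556666")  # -> same string
#   _resolve_agent_ref("64b7f0c2a1b2c3d4e5f60718")
#       # -> the agent's PG UUID via legacy_mongo_id, or None if the
#       #    agent was never backfilled into Postgres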
|
||||
|
||||
def _resolve_attachment_refs(
|
||||
self, ids: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""Translate a list of attachment ids to canonical PG
|
||||
``attachments.id`` UUIDs.
|
||||
|
||||
Inputs may be:
|
||||
|
||||
- A Mongo ObjectId string (24-hex), legacy dual-write era —
|
||||
must be looked up via ``attachments.legacy_mongo_id``.
|
||||
- A UUID string that is a real ``attachments.id`` PK.
|
||||
- A UUID string that is *only* present as
|
||||
``attachments.legacy_mongo_id`` — this is the post-cutover
|
||||
shape: ``/store_attachment`` mints a UUID, hands it to the
|
||||
worker, and the worker stashes it in ``legacy_mongo_id``
|
||||
while the row gets a freshly-generated PK. Trusting the
|
||||
input UUID as a PK here orphans the array entry: the column
|
||||
is ``uuid[]`` (no FK), so PG accepts the bad value and all
|
||||
downstream reads via ``AttachmentsRepository.get_any`` miss.
|
||||
|
||||
Resolution therefore tries ``legacy_mongo_id`` first for every
|
||||
id (UUID-shaped or not), then falls back to the direct PK
|
||||
match. Unknown ids are dropped — they'd have failed the
|
||||
``uuid[]`` cast otherwise and the whole row would have vanished
|
||||
via dual-write's exception swallow.
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
# Defer to AttachmentsRepository for the batched lookup so the
|
||||
# legacy-first semantics live in one place.
|
||||
from application.storage.db.repositories.attachments import (
|
||||
AttachmentsRepository,
|
||||
)
|
||||
|
||||
clean: list[str] = [str(raw) for raw in ids if raw is not None]
|
||||
if not clean:
|
||||
return []
|
||||
repo = AttachmentsRepository(self._conn)
|
||||
mapping = repo.resolve_ids(clean)
|
||||
out: list[str] = []
|
||||
for value in clean:
|
||||
mapped = mapping.get(value)
|
||||
if mapped is not None:
|
||||
out.append(mapped)
|
||||
return out
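# Behaviour sketch (hypothetical ids): unknown entries are dropped rather
# than passed through, so the uuid[] column never receives a value that no
# attachments row can back.
#
#   self._resolve_attachment_refs([
#       "64b7f0c2a1b2c3d4e5f60718",              # legacy ObjectId -> PG UUID
#       "9f1d2c34-5678-4abc-9def-0123456789ab",  # route-minted UUID -> PG UUID
#       "not-a-real-id",                         # unknown -> dropped
#   ])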
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conversation CRUD
|
||||
# ------------------------------------------------------------------
|
||||
@@ -65,8 +145,11 @@ class ConversationsRepository:
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
}
|
||||
if agent_id:
|
||||
values["agent_id"] = agent_id
|
||||
# ``agent_id`` may arrive as a Mongo ObjectId during the dual-write
|
||||
# window; resolve to a UUID (or drop silently if not yet backfilled).
|
||||
resolved_agent_id = self._resolve_agent_ref(agent_id)
|
||||
if resolved_agent_id:
|
||||
values["agent_id"] = resolved_agent_id
|
||||
if api_key:
|
||||
values["api_key"] = api_key
|
||||
if is_shared_usage:
|
||||
@@ -90,6 +173,7 @@ class ConversationsRepository:
|
||||
provided, the lookup is scoped to rows owned by that user so
|
||||
callers can't accidentally resolve another user's conversation.
|
||||
"""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
if user_id is not None:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
@@ -121,6 +205,17 @@ class ConversationsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a conversation by either PG UUID or legacy Mongo ObjectId string.
|
||||
|
||||
Returns a conversation the user owns or has shared access to.
|
||||
"""
|
||||
if looks_like_uuid(conversation_id):
|
||||
row = self.get(conversation_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return self.get_by_legacy_id(conversation_id, user_id)
|
||||
|
||||
def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch a conversation owned by the user (no shared access)."""
|
||||
result = self._conn.execute(
|
||||
@@ -150,6 +245,13 @@ class ConversationsRepository:
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
|
||||
# Shape-gate so a non-UUID id (legacy Mongo ObjectId still floating
|
||||
# around in client-side state during the cutover) never reaches the
|
||||
# ``CAST(:id AS uuid)`` — that cast raises on the server and poisons
|
||||
# the enclosing transaction, making every subsequent query on the
|
||||
# same connection fail.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations SET name = :name, updated_at = now() "
|
||||
@@ -159,7 +261,66 @@ class ConversationsRepository:
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def add_shared_user(self, conversation_id: str, user_to_add: str) -> bool:
|
||||
"""Idempotently append ``user_to_add`` to ``shared_with``.
|
||||
|
||||
Accepts either a PG UUID or a legacy Mongo ObjectId as the
|
||||
conversation id. Mirrors Mongo ``$addToSet`` semantics via the
|
||||
``NOT (:user = ANY(shared_with))`` guard.
|
||||
"""
|
||||
if not user_to_add:
|
||||
return False
|
||||
if looks_like_uuid(conversation_id):
|
||||
sql = (
|
||||
"UPDATE conversations "
|
||||
"SET shared_with = array_append(shared_with, :user), "
|
||||
" updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) "
|
||||
"AND NOT (:user = ANY(shared_with))"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"UPDATE conversations "
|
||||
"SET shared_with = array_append(shared_with, :user), "
|
||||
" updated_at = now() "
|
||||
"WHERE legacy_mongo_id = :id "
|
||||
"AND NOT (:user = ANY(shared_with))"
|
||||
)
|
||||
result = self._conn.execute(
|
||||
text(sql), {"id": conversation_id, "user": user_to_add},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def remove_shared_user(self, conversation_id: str, user_to_remove: str) -> bool:
|
||||
"""Remove ``user_to_remove`` from ``shared_with``. Mirror of Mongo ``$pull``."""
|
||||
if not user_to_remove:
|
||||
return False
|
||||
if looks_like_uuid(conversation_id):
|
||||
sql = (
|
||||
"UPDATE conversations "
|
||||
"SET shared_with = array_remove(shared_with, :user), "
|
||||
" updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) "
|
||||
"AND :user = ANY(shared_with)"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"UPDATE conversations "
|
||||
"SET shared_with = array_remove(shared_with, :user), "
|
||||
" updated_at = now() "
|
||||
"WHERE legacy_mongo_id = :id "
|
||||
"AND :user = ANY(shared_with)"
|
||||
)
|
||||
result = self._conn.execute(
|
||||
text(sql), {"id": conversation_id, "user": user_to_remove},
|
||||
)
|
||||
return result.rowcount > 0
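# Idempotence sketch (hypothetical ids): the ANY(shared_with) guards make
# repeat calls no-ops, matching Mongo's $addToSet / $pull behaviour.
#
#   repo.add_shared_user(conv_id, "bob@example.com")     # True  (appended)
#   repo.add_shared_user(conv_id, "bob@example.com")     # False (already there)
#   repo.remove_shared_user(conv_id, "bob@example.com")  # True  (removed)
#   repo.remove_shared_user(conv_id, "bob@example.com")  # False (nothing to do)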
|
||||
|
||||
def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
|
||||
# Shape-gate: see ``rename`` — prevents transaction poisoning when
|
||||
# a non-UUID id reaches this code path.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations SET shared_token = :token, updated_at = now() "
|
||||
@@ -179,6 +340,10 @@ class ConversationsRepository:
|
||||
``$set`` + ``$push $slice``). This method is retained for callers
|
||||
that already compute the full merged blob client-side.
|
||||
"""
|
||||
# Shape-gate: see ``rename`` — prevents transaction poisoning when
|
||||
# a non-UUID id reaches this code path.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations "
|
||||
@@ -205,6 +370,11 @@ class ConversationsRepository:
|
||||
the surrounding object when the row has no ``compression_metadata``
|
||||
yet.
|
||||
"""
|
||||
# Shape-gate: the streaming pipeline may pass through a legacy id
|
||||
# that ``get_by_legacy_id`` couldn't resolve; in that case the id
|
||||
# remains a non-UUID string and the CAST would poison the txn.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -245,6 +415,9 @@ class ConversationsRepository:
|
||||
on ``compression_metadata.compression_points``. Preserves the
|
||||
other top-level keys in ``compression_metadata``.
|
||||
"""
|
||||
# Shape-gate: see ``set_compression_flags``.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -286,6 +459,10 @@ class ConversationsRepository:
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete(self, conversation_id: str, user_id: str) -> bool:
|
||||
# Shape-gate: see ``rename`` — prevents transaction poisoning when
|
||||
# a non-UUID id reaches this code path.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM conversations "
|
||||
@@ -318,6 +495,11 @@ class ConversationsRepository:
|
||||
return [_message_row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
|
||||
# Shape-gate: see ``rename``. Callers today always pass a resolved
|
||||
# UUID (via ``get_any`` first), but the guard costs nothing and
|
||||
# keeps future callers safe from txn-poisoning.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
@@ -371,7 +553,13 @@ class ConversationsRepository:
|
||||
|
||||
attachments = message.get("attachments")
|
||||
if attachments:
|
||||
values["attachments"] = [str(a) for a in attachments]
|
||||
# Attachment ids may arrive as Mongo ObjectIds during the
|
||||
# dual-write window — resolve each to a PG UUID or drop it.
|
||||
resolved = self._resolve_attachment_refs(
|
||||
[str(a) for a in attachments],
|
||||
)
|
||||
if resolved:
|
||||
values["attachments"] = resolved
|
||||
|
||||
stmt = (
|
||||
pg_insert(conversation_messages_table)
|
||||
@@ -399,6 +587,11 @@ class ConversationsRepository:
|
||||
allowed = {
|
||||
"prompt", "response", "thought", "sources", "tool_calls",
|
||||
"attachments", "model_id", "metadata", "timestamp",
|
||||
# Feedback can be re-set in rare continuation flows; without
|
||||
# it in the whitelist an upstream re-append that happens to
|
||||
# carry feedback would silently lose it. Mirrors
|
||||
# ``set_feedback`` — column is JSONB.
|
||||
"feedback", "feedback_timestamp",
|
||||
}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not filtered:
|
||||
@@ -411,12 +604,21 @@ class ConversationsRepository:
|
||||
params: dict = {"conv_id": conversation_id, "pos": position}
|
||||
for key, val in filtered.items():
|
||||
col = api_to_col.get(key, key)
|
||||
if key in ("sources", "tool_calls", "metadata"):
|
||||
if key in ("sources", "tool_calls", "metadata", "feedback"):
|
||||
set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = json.dumps(val) if not isinstance(val, str) else val
|
||||
if val is None:
|
||||
params[col] = None
|
||||
else:
|
||||
params[col] = (
|
||||
json.dumps(val) if not isinstance(val, str) else val
|
||||
)
|
||||
elif key == "attachments":
|
||||
# Attachment ids may be Mongo ObjectIds during the
|
||||
# dual-write window; translate via attachments.legacy_mongo_id.
|
||||
set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
|
||||
params[col] = [str(a) for a in val] if val else []
|
||||
params[col] = self._resolve_attachment_refs(
|
||||
[str(a) for a in val] if val else [],
|
||||
)
|
||||
else:
|
||||
set_parts.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
@@ -436,6 +638,10 @@ class ConversationsRepository:
|
||||
Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
|
||||
index-based update.
|
||||
"""
|
||||
# Shape-gate: see ``rename`` — prevents transaction poisoning when
|
||||
# a non-UUID id reaches this code path.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return 0
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM conversation_messages "
|
||||
@@ -454,6 +660,10 @@ class ConversationsRepository:
|
||||
``feedback`` is a JSONB value, e.g. ``{"text": "thumbs_up",
|
||||
"timestamp": "..."}`` or ``None`` to unset.
|
||||
"""
|
||||
# Shape-gate: see ``rename`` — prevents transaction poisoning when
|
||||
# a non-UUID id reaches this code path.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return False
|
||||
fb_json = json.dumps(feedback) if feedback is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
|
||||
|
||||
class NotesRepository:
|
||||
@@ -60,3 +60,29 @@ class NotesRepository:
|
||||
{"user_id": user_id, "tool_id": tool_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM notes WHERE legacy_mongo_id = :legacy"),
|
||||
{"legacy": legacy_mongo_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, identifier: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a note by PG UUID or legacy Mongo ObjectId.
|
||||
|
||||
Picks the lookup path from the id shape so non-UUID input never
|
||||
reaches ``CAST(:id AS uuid)`` — that cast raises on the server
|
||||
and poisons the enclosing transaction, making any subsequent
|
||||
query on the same connection fail.
|
||||
"""
|
||||
if looks_like_uuid(identifier):
|
||||
doc = self.get(identifier, user_id)
|
||||
if doc is not None:
|
||||
return doc
|
||||
legacy = self.get_by_legacy_id(identifier)
|
||||
if legacy and legacy.get("user_id") == user_id:
|
||||
return legacy
|
||||
return None
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
|
||||
|
||||
class PromptsRepository:
|
||||
@@ -61,6 +61,7 @@ class PromptsRepository:
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch a prompt by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
@@ -70,6 +71,20 @@ class PromptsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, identifier: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a prompt by PG UUID or legacy Mongo ObjectId.
|
||||
|
||||
Picks the lookup path from the id shape so non-UUID input never
|
||||
reaches ``CAST(:id AS uuid)`` — that cast raises on the server
|
||||
and poisons the enclosing transaction, making any subsequent
|
||||
query on the same connection fail.
|
||||
"""
|
||||
if looks_like_uuid(identifier):
|
||||
doc = self.get(identifier, user_id)
|
||||
if doc is not None:
|
||||
return doc
|
||||
return self.get_by_legacy_id(identifier, user_id)
|
||||
|
||||
def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
|
||||
"""Fetch prompt content by ID without user scoping.
|
||||
|
||||
@@ -110,6 +125,7 @@ class PromptsRepository:
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Update a prompt addressed by the Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -135,6 +151,7 @@ class PromptsRepository:
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete a prompt addressed by the Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM prompts "
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import shared_conversations_table
|
||||
|
||||
|
||||
@@ -131,6 +131,14 @@ class SharedConversationsRepository:
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
|
||||
# Shape-gate: the public ``/api/shared_conversation/<identifier>``
|
||||
# endpoint threads the URL path segment straight here. A non-UUID
|
||||
# (e.g. a legacy Mongo ObjectId still embedded in an old link or
|
||||
# an outright garbage path) must resolve to ``None`` rather than
|
||||
# raise — the CAST would otherwise poison the txn and mask the
|
||||
# real "not found" response behind a generic 400.
|
||||
if not looks_like_uuid(share_uuid):
|
||||
return None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM shared_conversations "
|
||||
|
||||
@@ -3,33 +3,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, func, text
|
||||
from sqlalchemy import Connection, func, select, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import sources_table
|
||||
|
||||
|
||||
_SCALAR_COLUMNS = {
|
||||
"name", "type", "retriever", "sync_frequency", "tokens", "file_path",
|
||||
"language", "model", "date",
|
||||
}
|
||||
_JSONB_COLUMNS = {"metadata", "remote_data", "directory_structure", "file_name_map"}
|
||||
_ALLOWED_COLUMNS = _SCALAR_COLUMNS | _JSONB_COLUMNS
|
||||
|
||||
# Whitelist for sort columns exposed via ``list_for_user``. Anything not in
|
||||
# this set falls back to ``date`` so user-supplied sort params can't be
|
||||
# interpolated into SQL unchecked.
|
||||
_SORTABLE_COLUMNS = {"date", "name", "tokens", "type", "created_at", "updated_at"}
|
||||
|
||||
|
||||
def _escape_like(pattern: str) -> str:
|
||||
"""Escape wildcards so a user-supplied substring is matched literally.
|
||||
|
||||
We use ``LIKE ESCAPE '\\'`` on the query side so backslash, percent, and
underscore in the input don't accidentally act as LIKE wildcards.
|
||||
"""
|
||||
return (
|
||||
pattern
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
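For example, a search term containing LIKE metacharacters is neutralised before it is wrapped in ``%...%`` (the term below is illustrative only):

term = "100%_done"
escaped = _escape_like(term)   # -> "100\\%\\_done": % and _ are backslash-escaped
pattern = f"%{escaped}%"       # consumed further down via ilike(..., escape="\\")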
|
||||
|
||||
|
||||
def _coerce_jsonb(value: Any) -> Any:
|
||||
"""Normalize incoming JSONB values for the Core ``Table.update()`` path.
|
||||
|
||||
``remote_data`` in particular arrives as either a dict or a JSON string
|
||||
(the legacy Mongo docs stored both shapes). Strings are parsed so the
|
||||
stored representation is always structured JSONB; dicts/lists pass
|
||||
through untouched for the SQLAlchemy JSONB type processor.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (dict, list)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw": value}
|
||||
return value
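A few illustrative inputs and what they normalise to (values are examples only):

_coerce_jsonb({"branch": "main"})     # dict passes through unchanged
_coerce_jsonb('{"branch": "main"}')   # JSON string is parsed to a dict
_coerce_jsonb("not json")             # non-JSON text is wrapped: {"raw": "not json"}
_coerce_jsonb("   ")                  # blank string collapses to None (NULL in PG)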
|
||||
|
||||
|
||||
class SourcesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, name: str, *, user_id: str,
|
||||
type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
source_id: Optional[str] = None,
|
||||
user_id: str,
|
||||
type: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
retriever: Optional[str] = None,
|
||||
sync_frequency: Optional[str] = None,
|
||||
tokens: Optional[str] = None,
|
||||
file_path: Optional[str] = None,
|
||||
remote_data: Any = None,
|
||||
directory_structure: Any = None,
|
||||
file_name_map: Any = None,
|
||||
language: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
date: Any = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO sources (user_id, name, type, metadata)
|
||||
VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
|
||||
INSERT INTO sources (
|
||||
id, user_id, name, type, metadata,
|
||||
retriever, sync_frequency, tokens, file_path,
|
||||
remote_data, directory_structure, file_name_map,
|
||||
language, model, date, legacy_mongo_id
|
||||
)
|
||||
VALUES (
|
||||
COALESCE(CAST(:source_id AS uuid), gen_random_uuid()),
|
||||
:user_id, :name, :type, CAST(:metadata AS jsonb),
|
||||
:retriever, :sync_frequency, :tokens, :file_path,
|
||||
CAST(:remote_data AS jsonb),
|
||||
CAST(:directory_structure AS jsonb),
|
||||
CAST(:file_name_map AS jsonb),
|
||||
:language, :model,
|
||||
COALESCE(:date, now()),
|
||||
:legacy_mongo_id
|
||||
)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"source_id": source_id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"type": type,
|
||||
"metadata": json.dumps(metadata or {}),
|
||||
"retriever": retriever,
|
||||
"sync_frequency": sync_frequency,
|
||||
"tokens": tokens,
|
||||
"file_path": file_path,
|
||||
"remote_data": (
|
||||
None if remote_data is None
|
||||
else json.dumps(_coerce_jsonb(remote_data))
|
||||
),
|
||||
"directory_structure": (
|
||||
None if directory_structure is None
|
||||
else json.dumps(_coerce_jsonb(directory_structure))
|
||||
),
|
||||
"file_name_map": (
|
||||
None if file_name_map is None
|
||||
else json.dumps(_coerce_jsonb(file_name_map))
|
||||
),
|
||||
"language": language,
|
||||
"model": model,
|
||||
"date": date,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
@@ -42,25 +146,124 @@ class SourcesRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM sources WHERE user_id = :user_id ORDER BY created_at DESC"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
def get_any(self, source_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a source by either PG UUID or legacy Mongo ObjectId string.
|
||||
|
||||
Cutover helper: URLs / bookmarks may still hold Mongo ObjectIds.
|
||||
Tries the UUID path first, then falls back to ``legacy_mongo_id``.
|
||||
Both paths are scoped by ``user_id``.
|
||||
"""
|
||||
if looks_like_uuid(source_id):
|
||||
row = self.get(source_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return self.get_by_legacy_id(source_id, user_id)
|
||||
|
||||
def list_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
search_term: Optional[str] = None,
|
||||
sort_field: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> list[dict]:
|
||||
"""Return sources owned by ``user_id``, paginated and optionally filtered.
|
||||
|
||||
All pagination, filtering, and sorting are pushed into SQL so large
|
||||
accounts don't materialize their full source list in Python for every
|
||||
page. See ``PaginatedSources`` in the sources routes for the matching
|
||||
call site.
|
||||
|
||||
Args:
|
||||
user_id: Scope rows to this owner.
|
||||
limit: Page size. ``None`` returns every matching row (legacy
|
||||
full-list path used by ``CombinedJson``).
|
||||
offset: Rows to skip before collecting ``limit`` results.
|
||||
search_term: Case-insensitive substring filter on ``name``.
|
||||
``%`` and ``_`` in the input are escaped so they match
|
||||
literally rather than as LIKE wildcards.
|
||||
sort_field: Column to sort by. Unknown values fall back to
|
||||
``date``. Resolved against ``sources_table.c`` so the
|
||||
column identity is bound by SQLAlchemy — user input never
|
||||
reaches the emitted SQL as a string.
|
||||
sort_order: ``"asc"`` or ``"desc"``; anything else is treated
|
||||
as ``"desc"``.
|
||||
|
||||
Returns:
|
||||
A list of source rows as plain dicts (via ``row_to_dict``).
|
||||
"""
|
||||
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
|
||||
sort_column = sources_table.c[column_name]
|
||||
ascending = sort_order.lower() == "asc"
|
||||
|
||||
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
|
||||
if search_term:
|
||||
stmt = stmt.where(
|
||||
sources_table.c.name.ilike(
|
||||
f"%{_escape_like(search_term)}%",
|
||||
escape="\\",
|
||||
)
|
||||
)
|
||||
|
||||
# ``id`` is appended as a stable tiebreaker so paginated windows
|
||||
# are deterministic across equal sort keys.
|
||||
id_column = sources_table.c.id
|
||||
if ascending:
|
||||
stmt = stmt.order_by(sort_column.asc(), id_column.asc())
|
||||
else:
|
||||
stmt = stmt.order_by(sort_column.desc(), id_column.desc())
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = self._conn.execute(stmt)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
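A hypothetical call-site sketch for the paginated path; the handler shape, user id, and search term are assumptions, not code from this change:

from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly

page, page_size = 3, 20
with db_readonly() as conn:
    repo = SourcesRepository(conn)
    rows = repo.list_for_user(
        "local-user",
        limit=page_size,
        offset=(page - 1) * page_size,
        search_term="docs",
        sort_field="tokens",
        sort_order="asc",
    )
    # count uses the same filter so the pager total matches the window
    total = repo.count_for_user("local-user", search_term="docs")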
|
||||
|
||||
def count_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
search_term: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Return the count of rows that ``list_for_user`` would produce.
|
||||
|
||||
The filter mirrors ``list_for_user`` exactly so ``total`` and the
|
||||
paginated window stay consistent page-to-page.
|
||||
|
||||
Args:
|
||||
user_id: Scope rows to this owner.
|
||||
search_term: Same substring filter semantics as
|
||||
``list_for_user``; ``None``/empty disables the filter.
|
||||
|
||||
Returns:
|
||||
The total number of matching rows.
|
||||
"""
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(sources_table)
|
||||
.where(sources_table.c.user_id == user_id)
|
||||
)
|
||||
if search_term:
|
||||
stmt = stmt.where(
|
||||
sources_table.c.name.ilike(
|
||||
f"%{_escape_like(search_term)}%",
|
||||
escape="\\",
|
||||
)
|
||||
)
|
||||
result = self._conn.execute(stmt)
|
||||
row = result.fetchone()
|
||||
return int(row[0]) if row is not None else 0
|
||||
|
||||
def update(self, source_id: str, user_id: str, fields: dict) -> None:
|
||||
allowed = {"name", "type", "metadata"}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_COLUMNS}
|
||||
if not filtered:
|
||||
return
|
||||
|
||||
# Pass Python objects directly for JSONB columns when using
|
||||
# SQLAlchemy Core .update() — the JSONB type processor runs json.dumps
# on them itself; pre-serialising here would double-encode and the
|
||||
# value would round-trip as a JSON string instead of the original
|
||||
# dict.
|
||||
values: dict = dict(filtered)
|
||||
values: dict = {}
|
||||
for col, val in filtered.items():
|
||||
values[col] = _coerce_jsonb(val) if col in _JSONB_COLUMNS else val
|
||||
values["updated_at"] = func.now()
|
||||
|
||||
t = sources_table
|
||||
@@ -72,6 +275,47 @@ class SourcesRepository:
|
||||
)
|
||||
self._conn.execute(stmt)
|
||||
|
||||
def get_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM sources WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def update_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: str, fields: dict,
|
||||
) -> bool:
|
||||
"""Update a source addressed by the Mongo ObjectId string.
|
||||
|
||||
Used by dual_write call sites that hold the Mongo ``_id`` but
|
||||
haven't resolved the PG UUID yet. Returns ``True`` if a row was
|
||||
updated (i.e. the legacy id was found).
|
||||
"""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
row = self.get_by_legacy_id(legacy_mongo_id, user_id)
|
||||
if row is None:
|
||||
return False
|
||||
self.update(str(row["id"]), user_id, fields)
|
||||
return True
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete by Mongo ObjectId. Used by dual_write in DeleteOldIndexes."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM sources "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete(self, source_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
|
||||
@@ -1,41 +1,78 @@
|
||||
"""Repository for the ``todos`` table.
|
||||
|
||||
Covers the operations in ``application/agents/tools/todo_list.py``.
|
||||
Note: the Mongo schema uses ``todo_id`` (sequential int) and ``status`` (text),
|
||||
while the Postgres schema uses ``completed`` (boolean) and the UUID ``id`` as PK.
|
||||
The repository bridges both shapes.
|
||||
|
||||
The Mongo schema uses ``todo_id`` (a per-tool monotonic integer that the
|
||||
LLM uses as its handle) and ``status`` ("open"/"completed"). The Postgres
|
||||
schema mirrors that with a dedicated ``todo_id INTEGER`` column (unique
|
||||
per ``tool_id`` via a partial index) for the LLM-facing handle, while the
|
||||
primary key remains a UUID, and ``status`` is collapsed to a ``completed``
|
||||
boolean. ``legacy_mongo_id`` lets the backfill stay idempotent across
|
||||
reruns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
|
||||
|
||||
class TodosRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, tool_id: str, title: str) -> dict:
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_id: str,
|
||||
title: str,
|
||||
*,
|
||||
todo_id: Optional[int] = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Insert a todo row.
|
||||
|
||||
Allocates the per-tool monotonic ``todo_id`` inside the same
|
||||
transaction when the caller does not supply one. The allocation
|
||||
is ``COALESCE(MAX(todo_id), 0) + 1`` scoped to ``tool_id``; the
|
||||
partial unique index ``todos_tool_todo_id_uidx`` enforces
|
||||
correctness if two callers race.
|
||||
"""
|
||||
if todo_id is None:
|
||||
todo_id = self._conn.execute(
|
||||
text(
|
||||
"SELECT COALESCE(MAX(todo_id), 0) + 1 FROM todos "
|
||||
"WHERE tool_id = CAST(:tool_id AS uuid)"
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
).scalar_one()
|
||||
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO todos (user_id, tool_id, title)
|
||||
VALUES (:user_id, CAST(:tool_id AS uuid), :title)
|
||||
INSERT INTO todos (user_id, tool_id, todo_id, title, legacy_mongo_id)
|
||||
VALUES (:user_id, CAST(:tool_id AS uuid), :todo_id, :title, :legacy_mongo_id)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "tool_id": tool_id, "title": title},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"tool_id": tool_id,
|
||||
"todo_id": todo_id,
|
||||
"title": title,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
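If two writers race on the same ``tool_id``, the loser hits the partial unique index; because Postgres aborts the failed transaction, any retry needs a fresh ``db_session``. A sketch of what a caller could do (not part of this change; ``TodosRepository`` is assumed to be in scope):

from sqlalchemy.exc import IntegrityError

from application.storage.db.session import db_session

def create_todo_with_retry(user_id: str, tool_id: str, title: str) -> dict:
    for attempt in (1, 2):
        try:
            with db_session() as conn:
                return TodosRepository(conn).create(user_id, tool_id, title)
        except IntegrityError:
            # Lost the MAX(todo_id)+1 race; try once more in a brand-new
            # transaction before giving up.
            if attempt == 2:
                raise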
|
||||
|
||||
def get(self, todo_id: str, user_id: str) -> Optional[dict]:
|
||||
def get(self, todo_uuid: str, user_id: str) -> Optional[dict]:
|
||||
"""Look up a todo by its UUID primary key."""
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": todo_id, "user_id": user_id},
|
||||
{"id": todo_uuid, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
@@ -50,29 +87,172 @@ class TodosRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update_title(self, todo_id: str, user_id: str, title: str) -> bool:
|
||||
def list_for_tool(self, user_id: str, tool_id: str) -> list[dict]:
|
||||
"""Return all todos for a (user, tool) ordered by ``todo_id``."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM todos WHERE user_id = :user_id "
|
||||
"AND tool_id = CAST(:tool_id AS uuid) "
|
||||
"ORDER BY todo_id NULLS LAST, created_at"
|
||||
),
|
||||
{"user_id": user_id, "tool_id": tool_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def get_by_tool_and_todo_id(
|
||||
self, user_id: str, tool_id: str, todo_id: int
|
||||
) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM todos WHERE user_id = :user_id "
|
||||
"AND tool_id = CAST(:tool_id AS uuid) AND todo_id = :todo_id"
|
||||
),
|
||||
{"user_id": user_id, "tool_id": tool_id, "todo_id": todo_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def update_title(self, todo_uuid: str, user_id: str, title: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE todos SET title = :title, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": todo_id, "user_id": user_id, "title": title},
|
||||
{"id": todo_uuid, "user_id": user_id, "title": title},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def set_completed(self, todo_id: str, user_id: str, completed: bool = True) -> bool:
|
||||
def update_title_by_tool_and_todo_id(
|
||||
self, user_id: str, tool_id: str, todo_id: int, title: str
|
||||
) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE todos SET title = :title, updated_at = now() "
|
||||
"WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) "
|
||||
"AND todo_id = :todo_id"
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"tool_id": tool_id,
|
||||
"todo_id": todo_id,
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def set_completed(
|
||||
self,
|
||||
user_id_or_uuid: str,
|
||||
tool_id_or_user_id: str,
|
||||
todo_id_or_completed: Any,
|
||||
completed: Optional[bool] = None,
|
||||
) -> bool:
|
||||
"""Mark a todo's ``completed`` flag.
|
||||
|
||||
Two call shapes are supported during the migration window:
|
||||
|
||||
* Legacy UUID form (kept for existing tests):
|
||||
``set_completed(todo_uuid, user_id, completed: bool)``.
|
||||
* Per-tool integer-handle form (used by the tool's dual-write):
|
||||
``set_completed(user_id, tool_id, todo_id: int, completed: bool)``.
|
||||
"""
|
||||
if completed is None:
|
||||
# Legacy three-arg form: (todo_uuid, user_id, completed)
|
||||
todo_uuid = user_id_or_uuid
|
||||
user_id = tool_id_or_user_id
|
||||
completed_value = bool(todo_id_or_completed)
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE todos SET completed = :completed, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": todo_uuid, "user_id": user_id, "completed": completed_value},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
# New form: (user_id, tool_id, todo_id, completed)
|
||||
user_id = user_id_or_uuid
|
||||
tool_id = tool_id_or_user_id
|
||||
todo_id = int(todo_id_or_completed)
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE todos SET completed = :completed, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
"WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) "
|
||||
"AND todo_id = :todo_id"
|
||||
),
|
||||
{"id": todo_id, "user_id": user_id, "completed": completed},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"tool_id": tool_id,
|
||||
"todo_id": todo_id,
|
||||
"completed": bool(completed),
|
||||
},
|
||||
)
|
||||
return result.rowcount > 0
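The two accepted shapes, side by side (``repo``, ``todo_uuid``, ``tool_id``, and the user id are placeholders):

# Legacy shape kept for existing tests: (todo_uuid, user_id, completed)
repo.set_completed(todo_uuid, "local-user", True)

# New per-tool integer-handle shape: (user_id, tool_id, todo_id, completed)
repo.set_completed("local-user", tool_id, 3, completed=True)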
|
||||
|
||||
def delete(self, todo_id: str, user_id: str) -> bool:
|
||||
def delete(self, todo_uuid: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": todo_id, "user_id": user_id},
|
||||
{"id": todo_uuid, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete_by_tool_and_todo_id(
|
||||
self, user_id: str, tool_id: str, todo_id: int
|
||||
) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM todos WHERE user_id = :user_id "
|
||||
"AND tool_id = CAST(:tool_id AS uuid) AND todo_id = :todo_id"
|
||||
),
|
||||
{"user_id": user_id, "tool_id": tool_id, "todo_id": todo_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM todos WHERE legacy_mongo_id = :legacy"),
|
||||
{"legacy": legacy_mongo_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, identifier: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a todo by PG UUID or legacy Mongo ObjectId.
|
||||
|
||||
Picks the lookup path from the id shape so non-UUID input never
|
||||
reaches ``CAST(:id AS uuid)`` — that cast raises on the server
|
||||
and poisons the enclosing transaction, making any subsequent
|
||||
query on the same connection fail.
|
||||
"""
|
||||
if looks_like_uuid(identifier):
|
||||
doc = self.get(identifier, user_id)
|
||||
if doc is not None:
|
||||
return doc
|
||||
legacy = self.get_by_legacy_id(identifier)
|
||||
if legacy and legacy.get("user_id") == user_id:
|
||||
return legacy
|
||||
return None
|
||||
|
||||
def update_by_legacy_id(
|
||||
self,
|
||||
legacy_mongo_id: str,
|
||||
*,
|
||||
title: Optional[str] = None,
|
||||
completed: Optional[bool] = None,
|
||||
) -> bool:
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sets = []
|
||||
params: dict[str, Any] = {"legacy": legacy_mongo_id}
|
||||
if title is not None:
|
||||
sets.append("title = :title")
|
||||
params["title"] = title
|
||||
if completed is not None:
|
||||
sets.append("completed = :completed")
|
||||
params["completed"] = bool(completed)
|
||||
if not sets:
|
||||
return False
|
||||
sets.append("updated_at = now()")
|
||||
sql = "UPDATE todos SET " + ", ".join(sets) + " WHERE legacy_mongo_id = :legacy"
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
@@ -33,6 +33,24 @@ class TokenUsageRepository:
|
||||
generated_tokens: int = 0,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
|
||||
# constraint requires at least one of ``user_id`` / ``api_key``
|
||||
# to be non-null. Raise here for a clear error rather than
|
||||
# relying on the DB to reject the row.
|
||||
if not user_id and not api_key:
|
||||
raise ValueError("token_usage insert requires user_id or api_key")
|
||||
|
||||
# ``agent_id`` is a UUID column. Legacy callers occasionally pass
|
||||
# a Mongo ObjectId string (24 hex chars) — those would make
|
||||
# psycopg raise at CAST time. Coerce anything that isn't shaped
|
||||
# like a UUID (36 chars with hyphens) to NULL so a stray legacy
|
||||
# id never breaks token accounting.
|
||||
agent_id_uuid: Optional[str] = None
|
||||
if agent_id:
|
||||
s = str(agent_id)
|
||||
if len(s) == 36 and "-" in s:
|
||||
agent_id_uuid = s
|
||||
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -48,7 +66,7 @@ class TokenUsageRepository:
|
||||
{
|
||||
"user_id": user_id,
|
||||
"api_key": api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_id": agent_id_uuid,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
"timestamp": timestamp,
|
||||
@@ -79,6 +97,74 @@ class TokenUsageRepository:
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
def bucketed_totals(
|
||||
self,
|
||||
*,
|
||||
bucket_unit: str,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
timestamp_gte: Optional[datetime] = None,
|
||||
timestamp_lt: Optional[datetime] = None,
|
||||
) -> list[dict]:
|
||||
"""Sum ``prompt_tokens`` / ``generated_tokens`` bucketed by time.
|
||||
|
||||
Replacement for the legacy Mongo ``$dateToString`` aggregation
|
||||
used by the analytics dashboard. The ``bucket`` format string
|
||||
mirrors Mongo's output so the route layer doesn't have to reshape it:
|
||||
``"YYYY-MM-DD HH:MM:00"`` (minute), ``"YYYY-MM-DD HH:00"``
|
||||
(hour), ``"YYYY-MM-DD"`` (day). Rows are ordered by bucket ASC.
|
||||
"""
|
||||
formats = {
|
||||
"minute": "YYYY-MM-DD HH24:MI:00",
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
"day": "YYYY-MM-DD",
|
||||
}
|
||||
if bucket_unit not in formats:
|
||||
raise ValueError(f"unsupported bucket_unit: {bucket_unit!r}")
|
||||
fmt = formats[bucket_unit]
|
||||
|
||||
clauses: list[str] = []
|
||||
params: dict = {"fmt": fmt}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
params["user_id"] = user_id
|
||||
if api_key is not None:
|
||||
clauses.append("api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
if agent_id is not None:
|
||||
clauses.append("agent_id = CAST(:agent_id AS uuid)")
|
||||
params["agent_id"] = agent_id
|
||||
if timestamp_gte is not None:
|
||||
clauses.append("timestamp >= :timestamp_gte")
|
||||
params["timestamp_gte"] = timestamp_gte
|
||||
if timestamp_lt is not None:
|
||||
clauses.append("timestamp < :timestamp_lt")
|
||||
params["timestamp_lt"] = timestamp_lt
|
||||
where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT to_char(timestamp AT TIME ZONE 'UTC', :fmt) AS bucket,
|
||||
COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
|
||||
COALESCE(SUM(generated_tokens), 0) AS generated_tokens
|
||||
FROM token_usage
|
||||
{where}
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket ASC
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
return [
|
||||
{
|
||||
"bucket": row._mapping["bucket"],
|
||||
"prompt_tokens": int(row._mapping["prompt_tokens"]),
|
||||
"generated_tokens": int(row._mapping["generated_tokens"]),
|
||||
}
|
||||
for row in result.fetchall()
|
||||
]
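A sketch of how the analytics route might consume it; the seven-day window and user id are illustrative:

from datetime import datetime, timedelta

from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.session import db_readonly

since = datetime.utcnow() - timedelta(days=7)
with db_readonly() as conn:
    buckets = TokenUsageRepository(conn).bucketed_totals(
        bucket_unit="day",
        user_id="local-user",
        timestamp_gte=since,
    )
# e.g. [{"bucket": "2024-05-01", "prompt_tokens": 1234, "generated_tokens": 567}, ...]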
|
||||
|
||||
def count_in_range(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -82,3 +82,34 @@ class UserLogsRepository:
|
||||
rows = [row_to_dict(r) for r in result.fetchall()]
|
||||
has_more = len(rows) > page_size
|
||||
return rows[:page_size], has_more
|
||||
|
||||
def find_by_api_key(
|
||||
self,
|
||||
api_key: str,
|
||||
*,
|
||||
timestamp_gte: Optional[datetime] = None,
|
||||
timestamp_lt: Optional[datetime] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[dict]:
|
||||
"""Return user_logs rows whose ``data->>'api_key'`` matches ``api_key``.
|
||||
|
||||
Replacement for the legacy Mongo filter by top-level ``api_key``;
|
||||
on the PG side the per-request payload lives in ``data`` JSONB,
|
||||
so the filter reaches in via ``data->>'api_key'``. Rows are
|
||||
ordered by ``timestamp DESC`` to match the Mongo sort.
|
||||
"""
|
||||
clauses = ["data->>'api_key' = :api_key"]
|
||||
params: dict = {"api_key": api_key}
|
||||
if timestamp_gte is not None:
|
||||
clauses.append("timestamp >= :timestamp_gte")
|
||||
params["timestamp_gte"] = timestamp_gte
|
||||
if timestamp_lt is not None:
|
||||
clauses.append("timestamp < :timestamp_lt")
|
||||
params["timestamp_lt"] = timestamp_lt
|
||||
where = " AND ".join(clauses)
|
||||
sql = f"SELECT * FROM user_logs WHERE {where} ORDER BY timestamp DESC"
|
||||
if limit is not None:
|
||||
sql += " LIMIT :limit"
|
||||
params["limit"] = limit
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
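Because the filter reaches into JSONB, a large ``user_logs`` table would likely want an expression index. A hypothetical one-off a deployment might add; it is not included in this change set:

from sqlalchemy import text

from application.storage.db.session import db_session

with db_session() as conn:
    # Expression index over the JSONB key plus the sort column.
    conn.execute(text(
        "CREATE INDEX IF NOT EXISTS user_logs_api_key_ts_idx "
        "ON user_logs ((data ->> 'api_key'), timestamp DESC)"
    ))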
|
||||
|
||||
@@ -15,11 +15,29 @@ Covers every operation the legacy Mongo code performs on
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
|
||||
|
||||
_JSONB_COLUMNS = {"config", "config_requirements", "actions"}
|
||||
_SCALAR_COLUMNS = {"name", "custom_name", "display_name", "description", "status"}
|
||||
_ALLOWED_COLUMNS = _SCALAR_COLUMNS | _JSONB_COLUMNS
|
||||
|
||||
|
||||
def _encode_jsonb(value: Any) -> Any:
|
||||
"""Serialize a Python value for a JSONB bind parameter.
|
||||
|
||||
Accepts ``None``, already-encoded strings, or Python dict/list. Returns a
|
||||
JSON string suitable for ``CAST(:x AS jsonb)``.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
class UserToolsRepository:
|
||||
@@ -28,9 +46,21 @@ class UserToolsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, name: str, *, config: Optional[dict] = None,
|
||||
custom_name: Optional[str] = None, display_name: Optional[str] = None,
|
||||
extra: Optional[dict] = None) -> dict:
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
*,
|
||||
config: Optional[dict] = None,
|
||||
custom_name: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
config_requirements: Optional[dict] = None,
|
||||
actions: Optional[list] = None,
|
||||
status: bool = True,
|
||||
extra: Optional[dict] = None,
|
||||
legacy_mongo_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Insert a new tool row. ``extra`` is merged into the config JSONB."""
|
||||
cfg = config or {}
|
||||
if extra:
|
||||
@@ -38,8 +68,17 @@ class UserToolsRepository:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
|
||||
VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
|
||||
INSERT INTO user_tools (
|
||||
user_id, name, custom_name, display_name, description,
|
||||
config, config_requirements, actions, status, legacy_mongo_id
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :name, :custom_name, :display_name, :description,
|
||||
CAST(:config AS jsonb),
|
||||
CAST(:config_requirements AS jsonb),
|
||||
CAST(:actions AS jsonb),
|
||||
:status, :legacy_mongo_id
|
||||
)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
@@ -48,7 +87,12 @@ class UserToolsRepository:
|
||||
"name": name,
|
||||
"custom_name": custom_name,
|
||||
"display_name": display_name,
|
||||
"description": description,
|
||||
"config": json.dumps(cfg),
|
||||
"config_requirements": _encode_jsonb(config_requirements or {}),
|
||||
"actions": _encode_jsonb(actions or []),
|
||||
"status": status,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
@@ -61,6 +105,32 @@ class UserToolsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Fetch a user_tool by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM user_tools WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, tool_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve a user_tool by PG UUID or legacy Mongo ObjectId string.
|
||||
|
||||
Cutover helper: route handlers may receive either shape from
|
||||
older clients. Always returns a row scoped to ``user_id``.
|
||||
"""
|
||||
if looks_like_uuid(tool_id):
|
||||
row = self.get(tool_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return self.get_by_legacy_id(tool_id, user_id)
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM user_tools WHERE user_id = :user_id ORDER BY created_at"),
|
||||
@@ -68,43 +138,71 @@ class UserToolsRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, tool_id: str, user_id: str, fields: dict) -> None:
|
||||
def list_active_for_user(self, user_id: str) -> list[dict]:
|
||||
"""Return only tools with ``status = true`` — matches the legacy
|
||||
``find({"user": user, "status": True})`` used by the answer pipeline."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE user_id = :user_id AND status = true "
|
||||
"ORDER BY created_at"
|
||||
),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def find_by_user_and_name(self, user_id: str, name: str) -> Optional[dict]:
|
||||
"""Used by the MCP save flow to decide between insert and update."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE user_id = :user_id AND name = :name "
|
||||
"LIMIT 1"
|
||||
),
|
||||
{"user_id": user_id, "name": name},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
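The MCP save flow the docstring mentions is not in this hunk; roughly, it would branch on this lookup. The tool name, display name, and ``mcp_config`` below are assumptions for illustration:

from application.storage.db.session import db_session

with db_session() as conn:
    repo = UserToolsRepository(conn)
    existing = repo.find_by_user_and_name("local-user", "mcp")
    if existing:
        repo.update(str(existing["id"]), "local-user",
                    {"config": mcp_config, "status": True})
    else:
        repo.create("local-user", "mcp",
                    config=mcp_config, display_name="MCP Server")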
|
||||
|
||||
def update(self, tool_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Update arbitrary fields on a tool row.
|
||||
|
||||
``fields`` maps column names to new values. Only ``name``,
|
||||
``custom_name``, ``display_name``, and ``config`` are allowed.
|
||||
``fields`` maps column names to new values. Only columns in
|
||||
``_ALLOWED_COLUMNS`` are honored; unknown keys are silently dropped
|
||||
to keep route handlers from accidentally leaking DB shape. JSONB
|
||||
values are accepted as dicts/lists and serialized here.
|
||||
|
||||
Returns ``True`` if the row was updated, ``False`` if the id/user
|
||||
didn't match anything.
|
||||
"""
|
||||
allowed = {"name", "custom_name", "display_name", "config"}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_COLUMNS}
|
||||
if not filtered:
|
||||
return
|
||||
params: dict = {
|
||||
"id": tool_id,
|
||||
"user_id": user_id,
|
||||
"name": filtered.get("name"),
|
||||
"custom_name": filtered.get("custom_name"),
|
||||
"display_name": filtered.get("display_name"),
|
||||
"config": (
|
||||
json.dumps(filtered["config"])
|
||||
if "config" in filtered and isinstance(filtered["config"], dict)
|
||||
else filtered.get("config")
|
||||
),
|
||||
}
|
||||
self._conn.execute(
|
||||
return False
|
||||
|
||||
set_clauses: list[str] = []
|
||||
params: dict = {"id": tool_id, "user_id": user_id}
|
||||
for col, val in filtered.items():
|
||||
if col not in _ALLOWED_COLUMNS:
|
||||
raise ValueError(f"disallowed column: {col!r}")
|
||||
if col in _JSONB_COLUMNS:
|
||||
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = _encode_jsonb(val)
|
||||
else:
|
||||
set_clauses.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
set_clauses.append("updated_at = now()")
|
||||
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
f"""
|
||||
UPDATE user_tools
|
||||
SET
|
||||
name = COALESCE(:name, name),
|
||||
custom_name = COALESCE(:custom_name, custom_name),
|
||||
display_name = COALESCE(:display_name, display_name),
|
||||
config = COALESCE(CAST(:config AS jsonb), config),
|
||||
updated_at = now()
|
||||
SET {", ".join(set_clauses)}
|
||||
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.rowcount > 0
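An illustrative call showing a scalar column and a JSONB column updated in one statement, plus the whitelist dropping an unknown key (``tool`` is a previously fetched row, assumed in scope):

with db_session() as conn:
    changed = UserToolsRepository(conn).update(
        str(tool["id"]), "local-user",
        {
            "status": False,                   # scalar column
            "actions": [{"name": "search"}],   # JSONB, serialised via _encode_jsonb
            "not_a_column": 123,               # silently dropped by _ALLOWED_COLUMNS
        },
    )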
|
||||
|
||||
def delete(self, tool_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
|
||||
@@ -82,7 +82,20 @@ class WorkflowEdgesRepository:
|
||||
"config": e.get("config", {}),
|
||||
})
|
||||
|
||||
stmt = pg_insert(workflow_edges_table).values(rows).returning(workflow_edges_table)
|
||||
# See ``WorkflowNodesRepository.bulk_create`` for the race
|
||||
# rationale — same pattern for edges, keyed on the unique
|
||||
# index ``workflow_edges_wf_ver_eid_uidx``.
|
||||
stmt = pg_insert(workflow_edges_table).values(rows)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["workflow_id", "graph_version", "edge_id"],
|
||||
set_={
|
||||
"from_node_id": stmt.excluded.from_node_id,
|
||||
"to_node_id": stmt.excluded.to_node_id,
|
||||
"source_handle": stmt.excluded.source_handle,
|
||||
"target_handle": stmt.excluded.target_handle,
|
||||
"config": stmt.excluded.config,
|
||||
},
|
||||
).returning(workflow_edges_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
|
||||
@@ -82,7 +82,25 @@ class WorkflowNodesRepository:
|
||||
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||
})
|
||||
|
||||
stmt = pg_insert(workflow_nodes_table).values(rows).returning(workflow_nodes_table)
|
||||
# Two concurrent ``PUT /workflows/{id}/graph`` calls at the same
|
||||
# ``next_graph_version`` would race here. Without ON CONFLICT the
|
||||
# unique index ``workflow_nodes_wf_ver_nid_uidx`` would reject the
|
||||
# loser outright; we prefer last-writer-wins semantics so that
|
||||
# "overwrite the whole graph at version N" is idempotent and
|
||||
# resilient. The ON CONFLICT target matches the existing unique
|
||||
# index on (workflow_id, graph_version, node_id).
|
||||
stmt = pg_insert(workflow_nodes_table).values(rows)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["workflow_id", "graph_version", "node_id"],
|
||||
set_={
|
||||
"node_type": stmt.excluded.node_type,
|
||||
"title": stmt.excluded.title,
|
||||
"description": stmt.excluded.description,
|
||||
"position": stmt.excluded.position,
|
||||
"config": stmt.excluded.config,
|
||||
"legacy_mongo_id": stmt.excluded.legacy_mongo_id,
|
||||
},
|
||||
).returning(workflow_nodes_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
@@ -117,6 +135,7 @@ class WorkflowNodesRepository:
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
"""Find a node by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM workflow_nodes WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": legacy_mongo_id},
|
||||
|
||||
@@ -64,6 +64,7 @@ class WorkflowRunsRepository:
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
"""Fetch a workflow run by the original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
res = self._conn.execute(
|
||||
text("SELECT * FROM workflow_runs WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": legacy_mongo_id},
|
||||
|
||||
@@ -63,6 +63,7 @@ class WorkflowsRepository:
|
||||
self, legacy_mongo_id: str, user_id: str | None = None,
|
||||
) -> Optional[dict]:
|
||||
"""Fetch a workflow by its original Mongo ObjectId string."""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
sql = "SELECT * FROM workflows WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
@@ -123,3 +124,19 @@ class WorkflowsRepository:
|
||||
{"id": workflow_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete a workflow addressed by the Mongo ObjectId string.
|
||||
|
||||
The ``workflow_nodes`` and ``workflow_edges`` rows are removed
|
||||
automatically via ``ON DELETE CASCADE``.
|
||||
"""
|
||||
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflows "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
application/storage/db/session.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
"""Per-request connection helpers for route handlers.
|
||||
|
||||
Every route-handler that talks to Postgres opens a short-lived, explicit
|
||||
transaction via the context managers in this module. The pattern is::
|
||||
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt = repo.get(prompt_id, user_id)
|
||||
|
||||
Why explicit, not ``flask.g``: the lifecycle stays local to each handler,
|
||||
which mirrors how the repository test fixtures already work and keeps
|
||||
error handling obvious. Celery tasks and the seeder use the same helper
|
||||
so there's one pattern to learn.
|
||||
|
||||
Two flavors:
|
||||
|
||||
* ``db_session()`` — opens a transaction (``engine.begin()``). Commits on
|
||||
clean exit, rolls back on exception. Use for any handler that may
|
||||
write.
|
||||
* ``db_readonly()`` — opens a plain connection (``engine.connect()``) for
|
||||
read-only paths. Avoids the commit round-trip on pure reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_session() -> Iterator[Connection]:
|
||||
"""Transactional connection. Commits on success, rolls back on error."""
|
||||
with get_engine().begin() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_readonly() -> Iterator[Connection]:
|
||||
"""Read-only connection for handlers that never write.
|
||||
|
||||
The connection is placed into a Postgres ``READ ONLY`` transaction
|
||||
before any caller statement runs, so an accidental ``INSERT`` /
|
||||
``UPDATE`` / ``DELETE`` from inside the block raises
|
||||
``InternalError: cannot execute ... in a read-only transaction``
|
||||
instead of silently mutating data.
|
||||
|
||||
The transaction itself is rolled back on exit — a read-only
|
||||
transaction has nothing meaningful to commit, and rolling back avoids
|
||||
leaving the connection in an open-transaction state when it returns
|
||||
to the pool.
|
||||
"""
|
||||
with get_engine().connect() as conn:
|
||||
trans = conn.begin()
|
||||
try:
|
||||
# Must be the first statement in the txn; psycopg3 + SA both
|
||||
# honor this and Postgres rejects writes for the rest of the
|
||||
# transaction's lifetime.
|
||||
conn.execute(text("SET TRANSACTION READ ONLY"))
|
||||
yield conn
|
||||
finally:
|
||||
trans.rollback()
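A short usage sketch (``prompt_id`` and the repository choice are placeholders; the repository import path is assumed to follow the same ``repositories`` layout as the others): reads work as normal, while a stray write fails immediately instead of mutating data.

with db_readonly() as conn:
    prompt = PromptsRepository(conn).get_for_rendering(prompt_id)
    # conn.execute(text("DELETE FROM prompts"))  # would raise: read-only transaction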
|
||||
@@ -2,16 +2,12 @@ import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.session import db_session
|
||||
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def _serialize_for_token_count(value):
|
||||
"""Normalize payloads into token-countable primitives."""
|
||||
@@ -99,30 +95,18 @@ def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
)
|
||||
return
|
||||
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
"prompt_tokens": token_usage["prompt_tokens"],
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
if normalized_agent_id:
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
dual_write(
|
||||
TokenUsageRepository,
|
||||
lambda repo, d=usage_data: repo.insert(
|
||||
user_id=d.get("user_id"),
|
||||
api_key=d.get("api_key"),
|
||||
agent_id=d.get("agent_id"),
|
||||
prompt_tokens=d["prompt_tokens"],
|
||||
generated_tokens=d["generated_tokens"],
|
||||
),
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
TokenUsageRepository(conn).insert(
|
||||
user_id=user_id,
|
||||
api_key=user_api_key,
|
||||
agent_id=normalized_agent_id,
|
||||
prompt_tokens=token_usage["prompt_tokens"],
|
||||
generated_tokens=token_usage["generated_tokens"],
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record token usage: {e}", exc_info=True)
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
|
||||
@@ -9,18 +9,16 @@ import tempfile
|
||||
from typing import Any, Dict
|
||||
import zipfile
|
||||
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.chunking import Chunker
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
@@ -31,12 +29,13 @@ from application.parser.remote.remote_creator import RemoteCreator
|
||||
from application.parser.schema.base import Document
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import count_tokens_docs, num_tokens_from_string
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
from application.utils import count_tokens_docs, num_tokens_from_string, safe_filename
|
||||
|
||||
# Constants
|
||||
|
||||
@@ -313,23 +312,35 @@ def run_agent_logic(agent_config, input_data):
|
||||
)
|
||||
from application.utils import calculate_doc_token_budget
|
||||
|
||||
source = agent_config.get("source")
|
||||
retriever = agent_config.get("retriever", "classic")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = db.dereference(source)
|
||||
source = str(source_doc["_id"])
|
||||
retriever = source_doc.get("retriever", agent_config.get("retriever"))
|
||||
else:
|
||||
source = {}
|
||||
source = {"active_docs": source}
|
||||
chunks = int(agent_config.get("chunks", 2))
|
||||
# agent_config is a PG row dict: ``source_id`` is a UUID, and the
|
||||
# retriever/chunks live on the source row. Resolve the source row for
# its retriever/chunks if the agent points at one.
|
||||
source_id = agent_config.get("source_id") or agent_config.get("source")
|
||||
source_active = {}
|
||||
if source_id:
|
||||
with db_readonly() as conn:
|
||||
src_row = SourcesRepository(conn).get(
|
||||
str(source_id),
|
||||
agent_config.get("user_id") or agent_config.get("user"),
|
||||
)
|
||||
if src_row:
|
||||
source_active = str(src_row["id"])
|
||||
retriever = src_row.get("retriever", retriever)
|
||||
source = {"active_docs": source_active}
|
||||
chunks = int(agent_config.get("chunks", 2) or 2)
|
||||
prompt_id = agent_config.get("prompt_id", "default")
|
||||
user_api_key = agent_config["key"]
|
||||
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
|
||||
agent_id = (
|
||||
str(agent_config.get("id"))
|
||||
if agent_config.get("id")
|
||||
else (str(agent_config.get("_id")) if agent_config.get("_id") else None)
|
||||
)
|
||||
agent_type = agent_config.get("agent_type", "classic")
|
||||
decoded_token = {"sub": agent_config.get("user")}
|
||||
owner = agent_config.get("user_id") or agent_config.get("user")
|
||||
decoded_token = {"sub": owner}
|
||||
json_schema = agent_config.get("json_schema")
|
||||
prompt = get_prompt(prompt_id, db["prompts"])
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
# Determine model_id: check agent's default_model_id, fallback to system default
|
||||
agent_default_model = agent_config.get("default_model_id", "")
|
||||
@@ -545,7 +556,7 @@ def ingest_worker(
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
id = ObjectId()
|
||||
id = uuid.uuid4()
|
||||
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
@@ -609,9 +620,11 @@ def reingest_source_worker(self, source_id, user):
|
||||
meta={"current": 10, "status": "Initializing re-ingestion scan"},
|
||||
)
|
||||
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user)
|
||||
if not source:
|
||||
raise ValueError(f"Source {source_id} not found or access denied")
|
||||
source_id = str(source["id"])
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -860,16 +873,16 @@ def reingest_source_worker(self, source_id, user):
|
||||
directory_structure, file_name_map
|
||||
)
|
||||
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{
|
||||
"$set": {
|
||||
now = datetime.datetime.now()
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
source_id, user,
|
||||
{
|
||||
"directory_structure": directory_structure,
|
||||
"date": datetime.datetime.now(),
|
||||
"date": now,
|
||||
"tokens": total_tokens,
|
||||
}
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating directory_structure in DB: {e}", exc_info=True
|
||||
@@ -912,9 +925,9 @@ def remote_worker(
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
):
|
||||
full_path = os.path.join(directory, user, name_job)
|
||||
if not os.path.exists(full_path):
|
||||
os.makedirs(full_path)
|
||||
safe_user = safe_filename(user)
|
||||
full_path = os.path.join(directory, safe_user, uuid.uuid4().hex)
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
try:
|
||||
logging.info("Initializing remote loader with type: %s", loader)
|
||||
@@ -1003,13 +1016,13 @@ def remote_worker(
|
||||
)
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = ObjectId()
|
||||
id = uuid.uuid4()
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
if not doc_id:
|
||||
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
id = str(doc_id)
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||
|
||||
@@ -1030,7 +1043,19 @@ def remote_worker(
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
last_sync_now = datetime.datetime.now()
|
||||
file_data["last_sync"] = last_sync_now
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
src = repo.get_any(str(id), user)
|
||||
if src is not None:
|
||||
repo.update(str(src["id"]), user, {"date": last_sync_now})
|
||||
except Exception as upd_err:
|
||||
logging.warning(
|
||||
f"Failed to update last_sync for source {id}: {upd_err}"
|
||||
)
|
||||
upload_index(full_path, file_data)
|
||||
except Exception as e:
|
||||
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
|
||||
@@ -1079,23 +1104,34 @@ def sync(
|
||||
|
||||
|
||||
def sync_worker(self, frequency):
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
sync_counts = Counter()
|
||||
sources = sources_collection.find()
|
||||
for doc in sources:
|
||||
if doc.get("sync_frequency") == frequency:
|
||||
name = doc.get("name")
|
||||
user = doc.get("user")
|
||||
source_type = doc.get("type")
|
||||
source_data = doc.get("remote_data")
|
||||
retriever = doc.get("retriever")
|
||||
doc_id = str(doc.get("_id"))
|
||||
resp = sync(
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
with db_readonly() as conn:
|
||||
result = conn.execute(
|
||||
sql_text(
|
||||
"SELECT id, name, user_id, type, remote_data, retriever "
|
||||
"FROM sources WHERE sync_frequency = :freq"
|
||||
),
|
||||
{"freq": frequency},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
for row in rows:
|
||||
doc = dict(row._mapping)
|
||||
name = doc.get("name")
|
||||
user = doc.get("user_id")
|
||||
source_type = doc.get("type")
|
||||
source_data = doc.get("remote_data")
|
||||
retriever = doc.get("retriever")
|
||||
doc_id = str(doc.get("id"))
|
||||
resp = sync(
|
||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||
)
|
||||
sync_counts["total_sync_count"] += 1
|
||||
sync_counts[
|
||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||
] += 1
|
||||
return {
|
||||
key: sync_counts[key]
|
||||
for key in ["total_sync_count", "sync_success", "sync_failure"]
|
||||
@@ -1107,10 +1143,6 @@ def attachment_worker(self, file_info, user):
|
||||
Process and store a single attachment without vectorization.
|
||||
"""
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
filename = file_info["filename"]
|
||||
attachment_id = file_info["attachment_id"]
|
||||
relative_path = file_info["path"]
|
||||
@@ -1158,30 +1190,20 @@ def attachment_worker(self, file_info, user):
|
||||
|
||||
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
|
||||
doc_id = ObjectId(attachment_id)
|
||||
attachments_collection.insert_one(
|
||||
{
|
||||
"_id": doc_id,
|
||||
"user": user,
|
||||
"path": relative_path,
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
"token_count": token_count,
|
||||
"mime_type": mime_type,
|
||||
"date": datetime.datetime.now(),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
|
||||
dual_write(
|
||||
AttachmentsRepository,
|
||||
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type, mid=attachment_id: repo.create(
|
||||
u, fn, p, mime_type=mt, legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
# The upload route produces a UUID-shaped ``attachment_id`` (stored
|
||||
# in the storage path) but the PG ``attachments.id`` is generated
|
||||
# by the DB. Keep ``attachment_id`` as the caller-visible handle
|
||||
# used for the storage path, and stash it in ``legacy_mongo_id``
|
||||
# so the attachment row is resolvable via that handle too.
|
||||
with db_session() as conn:
|
||||
AttachmentsRepository(conn).create(
|
||||
user, filename, relative_path,
|
||||
mime_type=mime_type,
|
||||
content=content,
|
||||
token_count=token_count,
|
||||
metadata=metadata,
|
||||
legacy_mongo_id=str(attachment_id),
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Stored attachment with ID: {attachment_id}", extra={"user": user}
|
||||
@@ -1218,14 +1240,25 @@ def agent_webhook_worker(self, agent_id, payload):
|
||||
Returns:
|
||||
dict: Information about the processed webhook.
|
||||
"""
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
agent_config = agents_collection.find_one({"_id": agent_oid})
|
||||
with db_readonly() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
agent_config = None
|
||||
if looks_like_uuid(str(agent_id)):
|
||||
# Access without user scoping — webhooks authenticate via
|
||||
# the incoming token, not a user context.
|
||||
from sqlalchemy import text as sql_text
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
result = conn.execute(
|
||||
sql_text("SELECT * FROM agents WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": str(agent_id)},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is not None:
|
||||
agent_config = row_to_dict(row)
|
||||
if agent_config is None:
|
||||
agent_config = repo.get_by_legacy_id(str(agent_id))
|
||||
if not agent_config:
|
||||
raise ValueError(f"Agent with ID {agent_id} not found.")
|
||||
input_data = json.dumps(payload)
|
||||
@@ -1371,14 +1404,14 @@ def ingest_connector(
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = ObjectId()
|
||||
id = uuid.uuid4()
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
if not doc_id:
|
||||
logging.error(
|
||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
||||
)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
id = str(doc_id)
|
||||
else:
|
||||
raise ValueError(f"Invalid operation_mode: {operation_mode}")
|
||||
|
||||
@@ -1412,6 +1445,21 @@ def ingest_connector(
|
||||
else:
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
|
||||
if operation_mode == "sync":
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
src = repo.get_any(str(id), user)
|
||||
if src is not None:
|
||||
repo.update(
|
||||
str(src["id"]), user,
|
||||
{"date": file_data["last_sync"]},
|
||||
)
|
||||
except Exception as upd_err:
|
||||
logging.warning(
|
||||
f"Failed to update last_sync for source {id}: {upd_err}"
|
||||
)
|
||||
|
||||
upload_index(vector_store_path, file_data)
|
||||
|
||||
# Ensure we mark the task as complete
|
||||
|
||||
@@ -17,7 +17,7 @@ services:
      # Override URLs to use docker service names
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    ports:
      - "7091:7091"
    volumes:
@@ -25,8 +25,10 @@ services:
      - ../application/inputs:/app/application/inputs
      - ../application/vectors:/app/application/vectors
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  worker:
    build: ../application
@@ -37,25 +39,34 @@ services:
      # Override URLs to use docker service names
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - API_URL=http://backend:7091
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  redis:
    image: redis:6-alpine
    ports:
      - 6379:6379

  mongo:
    image: mongo:6
  postgres:
    image: postgres:16-alpine
    environment:
      - POSTGRES_USER=docsgpt
      - POSTGRES_PASSWORD=docsgpt
      - POSTGRES_DB=docsgpt
    ports:
      - 27017:27017
      - "5432:5432"
    volumes:
      - mongodb_data_container:/data/db

      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U docsgpt -d docsgpt"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  mongodb_data_container:
  postgres_data:

@@ -6,14 +6,21 @@ services:
    ports:
      - 6379:6379

  mongo:
    image: mongo:6
  postgres:
    image: postgres:16-alpine
    environment:
      - POSTGRES_USER=docsgpt
      - POSTGRES_PASSWORD=docsgpt
      - POSTGRES_DB=docsgpt
    ports:
      - 27017:27017
      - "5432:5432"
    volumes:
      - mongodb_data_container:/data/db

      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U docsgpt -d docsgpt"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  mongodb_data_container:
  postgres_data:

@@ -21,8 +21,8 @@ services:
    environment:
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - CACHE_REDIS_URL=redis://redis:6379/2
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    ports:
      - "7091:7091"
    volumes:
@@ -30,8 +30,10 @@ services:
      - ../application/inputs:/app/inputs
      - ../application/vectors:/app/vectors
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  worker:
@@ -43,28 +45,39 @@ services:
    environment:
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - API_URL=http://backend:7091
      - CACHE_REDIS_URL=redis://redis:6379/2
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    volumes:
      - ../application/indexes:/app/indexes
      - ../application/inputs:/app/inputs
      - ../application/vectors:/app/vectors
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  redis:
    image: redis:6-alpine
    ports:
      - 6379:6379

  mongo:
    image: mongo:6
  postgres:
    image: postgres:16-alpine
    environment:
      - POSTGRES_USER=docsgpt
      - POSTGRES_PASSWORD=docsgpt
      - POSTGRES_DB=docsgpt
    ports:
      - 27017:27017
      - "5432:5432"
    volumes:
      - mongodb_data_container:/data/db
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U docsgpt -d docsgpt"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  mongodb_data_container:
  postgres_data:

@@ -15,12 +15,21 @@ services:
    ports:
      - 6379:6379

  mongo:
    image: mongo:6
  postgres:
    image: postgres:16-alpine
    environment:
      - POSTGRES_USER=docsgpt
      - POSTGRES_PASSWORD=docsgpt
      - POSTGRES_DB=docsgpt
    ports:
      - 27017:27017
      - "5432:5432"
    volumes:
      - mongodb_data_container:/data/db
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U docsgpt -d docsgpt"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  mongodb_data_container:
  postgres_data:

@@ -22,8 +22,8 @@ services:
      # Override URLs to use docker service names
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - CACHE_REDIS_URL=redis://redis:6379/2
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    ports:
      - "7091:7091"
    volumes:
@@ -31,8 +31,10 @@ services:
      - ../application/inputs:/app/inputs
      - ../application/vectors:/app/vectors
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  worker:
    user: root
@@ -44,28 +46,40 @@ services:
      # Override URLs to use docker service names
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - API_URL=http://backend:7091
      - CACHE_REDIS_URL=redis://redis:6379/2
      - POSTGRES_URI=postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
    volumes:
      - ../application/indexes:/app/indexes
      - ../application/inputs:/app/inputs
      - ../application/vectors:/app/vectors
    depends_on:
      - redis
      - mongo
      redis:
        condition: service_started
      postgres:
        condition: service_healthy

  redis:
    image: redis:6-alpine
    ports:
      - 6379:6379

  mongo:
    image: mongo:6
  postgres:
    image: postgres:16-alpine
    environment:
      - POSTGRES_USER=docsgpt
      - POSTGRES_PASSWORD=docsgpt
      - POSTGRES_DB=docsgpt
    ports:
      - 27017:27017
      - "5432:5432"
    volumes:
      - mongodb_data_container:/data/db
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U docsgpt -d docsgpt"]
      interval: 5s
      timeout: 5s
      retries: 10

volumes:
  mongodb_data_container:
  postgres_data:

@@ -12,6 +12,19 @@ spec:
      labels:
        app: docsgpt-api
    spec:
      initContainers:
        # Block pod start until Postgres accepts connections. The `postgres-init`
        # Job is responsible for running alembic migrations; this container only
        # waits for the server to be reachable.
        - name: wait-for-postgres
          image: postgres:16-alpine
          command:
            - sh
            - -c
            - |
              until pg_isready -h postgres -p 5432 -U docsgpt -d docsgpt; do
                echo "Waiting for postgres..."; sleep 2;
              done
      containers:
        - name: docsgpt-api
          image: arc53/docsgpt
@@ -32,6 +45,18 @@ spec:
              value: "application/app.py"
            - name: DEPLOYMENT_TYPE
              value: "cloud"
            - name: POSTGRES_URI
              valueFrom:
                secretKeyRef:
                  name: docsgpt-secrets
                  key: POSTGRES_URI
            # Disable in-app auto-bootstrap. The `postgres-init` Job under
            # deployment/k8s/jobs/ owns schema creation and Alembic migrations,
            # so application pods must not race with it on rollout.
            - name: AUTO_MIGRATE
              value: "false"
            - name: AUTO_CREATE_DB
              value: "false"
---
apiVersion: apps/v1
kind: Deployment
@@ -47,6 +72,16 @@ spec:
      labels:
        app: docsgpt-worker
    spec:
      initContainers:
        - name: wait-for-postgres
          image: postgres:16-alpine
          command:
            - sh
            - -c
            - |
              until pg_isready -h postgres -p 5432 -U docsgpt -d docsgpt; do
                echo "Waiting for postgres..."; sleep 2;
              done
      containers:
        - name: docsgpt-worker
          image: arc53/docsgpt
@@ -64,6 +99,18 @@ spec:
          env:
            - name: API_URL
              value: "http://<your-api-endpoint>"
            - name: POSTGRES_URI
              valueFrom:
                secretKeyRef:
                  name: docsgpt-secrets
                  key: POSTGRES_URI
            # Disable in-app auto-bootstrap. The `postgres-init` Job under
            # deployment/k8s/jobs/ owns schema creation and Alembic migrations,
            # so application pods must not race with it on rollout.
            - name: AUTO_MIGRATE
              value: "false"
            - name: AUTO_CREATE_DB
              value: "false"
---
apiVersion: apps/v1
kind: Deployment
@@ -95,4 +142,4 @@ spec:
          - name: VITE_API_HOST
            value: "http://<your-api-endpoint>"
          - name: VITE_API_STREAMING
            value: "true"
            value: "true"

79 deployment/k8s/deployments/postgres-deploy.yaml Normal file
@@ -0,0 +1,79 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: postgres-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 5Gi # Adjust size as needed

---

apiVersion: apps/v1
kind: Deployment
metadata:
  name: postgres
spec:
  replicas: 1
  selector:
    matchLabels:
      app: postgres
  template:
    metadata:
      labels:
        app: postgres
    spec:
      containers:
        - name: postgres
          image: postgres:16-alpine
          ports:
            - containerPort: 5432
          env:
            - name: POSTGRES_USER
              value: "docsgpt"
            - name: POSTGRES_PASSWORD
              value: "docsgpt"
            - name: POSTGRES_DB
              value: "docsgpt"
            - name: PGDATA
              value: "/var/lib/postgresql/data/pgdata"
          resources:
            limits:
              memory: "1Gi"
              cpu: "1"
            requests:
              memory: "256Mi"
              cpu: "100m"
          volumeMounts:
            - name: postgres-data
              mountPath: /var/lib/postgresql/data
          readinessProbe:
            exec:
              command:
                - pg_isready
                - -U
                - docsgpt
                - -d
                - docsgpt
            initialDelaySeconds: 5
            periodSeconds: 5
            timeoutSeconds: 3
            failureThreshold: 6
          livenessProbe:
            exec:
              command:
                - pg_isready
                - -U
                - docsgpt
                - -d
                - docsgpt
            initialDelaySeconds: 30
            periodSeconds: 15
            timeoutSeconds: 5
            failureThreshold: 3
      volumes:
        - name: postgres-data
          persistentVolumeClaim:
            claimName: postgres-pvc
@@ -3,6 +3,15 @@ kind: Secret
metadata:
  name: docsgpt-secrets
type: Opaque
# Notes:
# - POSTGRES_URI below decodes to:
#   postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
#   This matches the default Postgres Deployment/Service in this kustomization
#   and mirrors the compose-level default (see deployment/docker-compose.yaml).
# - If you still need the MongoDB-backed vector store (VECTOR_STORE=mongodb),
#   manually add a MONGO_URI key below (base64-encoded) and apply the
#   opt-in manifests under deployment/k8s/optional-mongo/. Example:
#   MONGO_URI: <base64 of mongodb://mongodb-service:27017/docsgpt?retryWrites=true&w=majority>
data:
  LLM_PROVIDER: ZG9jc2dwdA==
  INTERNAL_KEY: aW50ZXJuYWw=
@@ -10,6 +19,7 @@ data:
  CELERY_RESULT_BACKEND: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA==
  QDRANT_URL: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA==
  QDRANT_PORT: NjM3OQ==
  MONGO_URI: bW9uZ29kYjovL21vbmdvZGItc2VydmljZToyNzAxNy9kb2NzZ3B0P3JldHJ5V3JpdGVzPXRydWUmdz1tYWpvcml0eQ==
  mongo-user: bW9uZ28tdXNlcg==
  mongo-password: bW9uZ28tcGFzc3dvcmQ=
  # postgresql://docsgpt:docsgpt@postgres:5432/docsgpt
  POSTGRES_URI: cG9zdGdyZXNxbDovL2RvY3NncHQ6ZG9jc2dwdEBwb3N0Z3Jlczo1NDMyL2RvY3NncHQ=
  postgres-user: ZG9jc2dwdA==
  postgres-password: ZG9jc2dwdA==

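The secret values above are base64-encoded. One way to produce a replacement value (for example, to point `POSTGRES_URI` at your own server) is the standard `base64` utility; the URI below is just the kustomization default, so substitute your own credentials.

```bash
# Encode a connection string for use in docsgpt-secrets.yaml.
# -n keeps a trailing newline out of the encoded value.
echo -n 'postgresql://docsgpt:docsgpt@postgres:5432/docsgpt' | base64
```
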
45 deployment/k8s/jobs/postgres-init-job.yaml Normal file
@@ -0,0 +1,45 @@
# One-shot migrator: runs `python scripts/db/init_postgres.py` (alembic
# upgrade head), then exits. The backend and worker Deployments rely on
# an init container that blocks until Postgres is reachable, but this Job
# ensures the schema is migrated before application pods start serving.
#
# Re-apply after upgrades to run new migrations. Safe to re-run: idempotent.
apiVersion: batch/v1
kind: Job
metadata:
  name: postgres-init
spec:
  backoffLimit: 6
  template:
    metadata:
      labels:
        app: postgres-init
    spec:
      restartPolicy: OnFailure
      initContainers:
        - name: wait-for-postgres
          image: postgres:16-alpine
          command:
            - sh
            - -c
            - |
              until pg_isready -h postgres -p 5432 -U docsgpt -d docsgpt; do
                echo "Waiting for postgres..."; sleep 2;
              done
      containers:
        - name: postgres-init
          image: arc53/docsgpt
          command: ["python", "scripts/db/init_postgres.py"]
          envFrom:
            - secretRef:
                name: docsgpt-secrets
          env:
            - name: FLASK_APP
              value: "application/app.py"
          resources:
            limits:
              memory: "1Gi"
              cpu: "1"
            requests:
              memory: "256Mi"
              cpu: "100m"
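Because a Job's pod template cannot be updated in place, "re-apply after upgrades" in practice usually means deleting the finished Job and creating it again. A sketch of that flow, using the `postgres-init` name and manifest path from the file above:

```bash
# Re-run the one-shot migration Job before rolling out a new application image.
kubectl delete job postgres-init --ignore-not-found
kubectl apply -f deployment/k8s/jobs/postgres-init-job.yaml
kubectl wait --for=condition=complete job/postgres-init --timeout=300s
kubectl logs job/postgres-init
```
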
20 deployment/k8s/kustomization.yaml Normal file
@@ -0,0 +1,20 @@
# Default DocsGPT Kubernetes stack.
#
# Uses Postgres as the user-data store (mirrors
# deployment/docker-compose.yaml). MongoDB manifests are opt-in and live
# under deployment/k8s/optional-mongo/ — they are NOT included here and
# will not be applied by `kubectl apply -k deployment/k8s/`.
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization

resources:
  - docsgpt-secrets.yaml
  - deployments/postgres-deploy.yaml
  - deployments/qdrant-deploy.yaml
  - deployments/redis-deploy.yaml
  - deployments/docsgpt-deploy.yaml
  - services/postgres-service.yaml
  - services/qdrant-service.yaml
  - services/redis-service.yaml
  - services/docsgpt-service.yaml
  - jobs/postgres-init-job.yaml
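As the header comment notes, the default stack is applied with kustomize. A minimal sketch of bringing it up and checking that the migration Job and pods settle:

```bash
kubectl apply -k deployment/k8s/
kubectl get jobs,pods
```
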
23 deployment/k8s/optional-mongo/README.md Normal file
@@ -0,0 +1,23 @@
# Optional: MongoDB manifests

These manifests are **opt-in**. The default DocsGPT install uses Postgres
for user data (see `deployment/k8s/deployments/postgres-deploy.yaml`).

Apply the manifests in this directory only if you run DocsGPT with the
MongoDB-backed vector store (`VECTOR_STORE=mongodb`) and need an
in-cluster MongoDB, or if you are intentionally running on the legacy
MongoDB user-data store during the Postgres migration window.

Mirrors `deployment/optional/` for compose — not applied by the default
`kubectl apply -k deployment/k8s/`.

## Usage

```bash
kubectl apply -f deployment/k8s/optional-mongo/deployments/mongo-deploy.yaml
kubectl apply -f deployment/k8s/optional-mongo/services/mongo-service.yaml
```

Then extend `docsgpt-secrets.yaml` with a base64-encoded `MONGO_URI`
pointing at `mongodb://mongodb-service:27017/docsgpt?retryWrites=true&w=majority`
(or your Atlas/external URI) before re-applying the secret.
12 deployment/k8s/services/postgres-service.yaml Normal file
@@ -0,0 +1,12 @@
apiVersion: v1
kind: Service
metadata:
  name: postgres
spec:
  selector:
    app: postgres
  ports:
    - protocol: TCP
      port: 5432
      targetPort: 5432
  type: ClusterIP
@@ -115,10 +115,10 @@ Once an agent is created, you can:

## Seeding Premade Agents from YAML

You can bootstrap a fresh DocsGPT deployment with a curated set of agents by seeding them directly into MongoDB.
You can bootstrap a fresh DocsGPT deployment with a curated set of agents by seeding them directly into the user-data store (Postgres).

1. **Customize the configuration** – edit `application/seed/config/premade_agents.yaml` (or copy from `application/seed/config/agents_template.yaml`) to describe the agents you want to provision. Each entry lets you define prompts, tools, and optional data sources.
2. **Ensure dependencies are running** – MongoDB must be reachable using the credentials in `.env`, and a Celery worker should be available if any agent sources need to be ingested via `ingest_remote`.
2. **Ensure dependencies are running** – Postgres must be reachable using `POSTGRES_URI` from `.env` (schema applied via `python scripts/db/init_postgres.py`), and a Celery worker should be available if any agent sources need to be ingested via `ingest_remote`.
3. **Execute the seeder** – run `python -m application.seed.commands init`. Add `--force` when you need to reseed an existing environment.

The seeder keeps templates under the `system` user so they appear in the UI for anyone to clone or customize. Environment variable placeholders such as `${MY_TOKEN}` inside tool configs are resolved during the seeding process.

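Putting step 3 together, a minimal seeding run looks like this (both commands come from the steps above; use `--force` only when you intend to reseed):

```bash
# First-time seeding of the premade agents.
python -m application.seed.commands init

# Reseed an environment that has already been initialized.
python -m application.seed.commands init --force
```
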
@@ -3,17 +3,19 @@ title: Setting Up a Development Environment
description: Guide to setting up a development environment for DocsGPT, including backend and frontend setup.
---

import { Callout } from 'nextra/components'

# Setting Up a Development Environment

This guide will walk you through setting up a development environment for DocsGPT. This setup allows you to modify and test the application's backend and frontend components.

## 1. Spin Up MongoDB and Redis
## 1. Spin Up Postgres and Redis

For development purposes, you can quickly start MongoDB and Redis containers, which are the primary database and caching systems used by DocsGPT. We provide a dedicated Docker Compose file, `docker-compose-dev.yaml`, located in the `deployment` directory, that includes only these essential services.
For development purposes, you can quickly start Postgres and Redis containers. Postgres is the user-data store for DocsGPT (conversations, agents, prompts, sources, attachments, workflows, logs, and token usage), and Redis is used as the cache and Celery broker. We provide a dedicated Docker Compose file, `docker-compose-dev.yaml`, located in the `deployment` directory, that includes only these essential services. The backend applies the Alembic schema automatically on first boot (`AUTO_MIGRATE=true` / `AUTO_CREATE_DB=true` ship enabled), so no separate migration step is required. You can still run `python scripts/db/init_postgres.py` explicitly if you prefer.

You can find the `docker-compose-dev.yaml` file [here](https://github.com/arc53/DocsGPT/blob/main/deployment/docker-compose-dev.yaml).

**Steps to start MongoDB and Redis:**
**Steps to start Postgres and Redis:**

1. Navigate to the root directory of your DocsGPT repository in your terminal.

@@ -24,7 +26,16 @@ You can find the `docker-compose-dev.yaml` file [here](https://github.com/arc53/
docker compose -f deployment/docker-compose-dev.yaml up -d
```

These commands will start MongoDB and Redis in detached mode, running in the background.
These commands will start Postgres and Redis in detached mode, running in the background. When the Flask backend boots against the fresh Postgres instance, it will automatically create the database (if missing) and apply the current Alembic schema.

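Before starting the backend, it can help to confirm that both containers are up and that the Postgres healthcheck is passing; the compose file path is the same one used above:

```bash
docker compose -f deployment/docker-compose-dev.yaml ps
docker compose -f deployment/docker-compose-dev.yaml logs postgres
```
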
<Callout type="info" emoji="ℹ️">
  MongoDB is no longer required for a default DocsGPT install. If you
  specifically want to use MongoDB Atlas as the vector store
  (`VECTOR_STORE=mongodb`), start it on the side via
  `deployment/docker-compose.mongo.yaml`. For migrating an existing
  Mongo-based install to Postgres, see
  [PostgreSQL for User Data](/Deploying/Postgres-Migration).
</Callout>

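If you do opt into the Mongo-backed vector store, the side-car compose file named in the callout can be started alongside the dev services, for example:

```bash
docker compose -f deployment/docker-compose.mongo.yaml up -d
```
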
## 2. Run the Backend

@@ -3,6 +3,8 @@ title: DocsGPT Settings
description: Configure your DocsGPT application by understanding the basic settings.
---

import { Callout } from 'nextra/components'

# DocsGPT Settings

DocsGPT is highly configurable, allowing you to tailor it to your specific needs and preferences. You can control various aspects of the application, from choosing the Large Language Model (LLM) provider to selecting embedding models and vector stores.
@@ -239,6 +241,47 @@ SAGEMAKER_REGION=us-east-1

Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.

## User-Data Storage (Postgres)

DocsGPT stores user data — conversations, agents, prompts, sources, attachments, workflows, logs, and token usage — in **PostgreSQL**. The backend connects via a single setting:

| Setting | Description | Default |
| --- | --- | --- |
| `POSTGRES_URI` | SQLAlchemy-compatible Postgres URI. Any standard `postgresql://` form works — DocsGPT normalizes it internally to the `psycopg` v3 dialect. | — |
| `AUTO_CREATE_DB` | On startup, connect to the server's `postgres` maintenance DB and issue `CREATE DATABASE` if the target is missing. Requires `CREATEDB` or superuser. No-op when the database already exists. Disable in production. | `true` |
| `AUTO_MIGRATE` | On startup, run `alembic upgrade head` against the target database. Idempotent and serialized across workers via `alembic_version`. Disable in production in favor of an explicit migration step. | `true` |

Example:

```env
POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
# Append ?sslmode=require for managed providers that enforce SSL.
```

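A quick way to confirm the URI is reachable before starting the backend is a `psql` one-liner; `psql` accepts the same `postgresql://` connection URI, so this is purely a connectivity check using the values from the example above:

```bash
psql "postgresql://docsgpt:docsgpt@localhost:5432/docsgpt" -c "SELECT 1;"
```
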
With the defaults, the app applies the schema automatically on first
boot. To run it explicitly instead (e.g., in CI/CD or a k8s `Job`):

```bash
python scripts/db/init_postgres.py
```

The default Docker Compose file bundles a `postgres` service, and the
app auto-bootstraps the database on boot, so containerized deployments
need no manual migration step. See
[PostgreSQL for User Data](/Deploying/Postgres-Migration#production-hardening)
for the recommended production flow (both flags `false`, migrations
gated by CI/CD).

<Callout type="info" emoji="ℹ️">
  `MONGO_URI` is **opt-in**. It is only consulted when you select the
  MongoDB Atlas vector-store backend (`VECTOR_STORE=mongodb`) or when
  running the one-shot `scripts/db/backfill.py` migration from a legacy
  Mongo-based install. Installing the optional Mongo client libraries
  requires `pip install 'pymongo>=4.6'`. See
  [PostgreSQL for User Data](/Deploying/Postgres-Migration) for the
  migration path.
</Callout>

## Exploring More Settings

These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:

@@ -1,114 +1,151 @@
---
title: PostgreSQL for User Data
description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace.
description: PostgreSQL is the user-data store for DocsGPT. Covers auto-bootstrap, production hardening, and the one-shot migration from legacy MongoDB deployments.
---

import { Callout } from 'nextra/components'

# PostgreSQL for User Data

DocsGPT is progressively moving user data (conversations, agents, prompts,
preferences, etc.) from MongoDB to PostgreSQL, one collection at a time.
Each collection is guarded by a feature flag so you can opt in and roll
back instantly. MongoDB stays the source of truth until you cut over
reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …)
are unaffected.
DocsGPT stores conversations, agents, prompts, sources, attachments,
workflows, logs, and token usage in **PostgreSQL**. MongoDB is no longer
required.

<Callout type="info" emoji="ℹ️">
  Which collections are available today is in the [Status](#status)
  table below. That table is the only part of this page that changes
  release to release.
  Vector stores are independent — `VECTOR_STORE` can still be `pgvector`,
  `faiss`, `qdrant`, `milvus`, `elasticsearch`, or `mongodb`.
</Callout>

## Setup
## Quickstart

1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS,
   Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and
   `citext` extensions, both standard contrib modules available
   everywhere.
Three common paths. Each assumes Postgres 13+ and the default env vars
`AUTO_MIGRATE=true` / `AUTO_CREATE_DB=true` (both ship enabled).

2. **Create a database and role** (skip if your managed provider gave
   you these):
### Docker Compose

   ```sql
   CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt';
   CREATE DATABASE docsgpt OWNER docsgpt;
   ```
The bundled compose file ships a `postgres` service. App boot handles the
rest — no sidecar, no init job.

3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works —
   DocsGPT normalizes it internally.
```bash
cd deployment && docker compose up
```

   ```bash
   POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
   # Append ?sslmode=require for managed providers that enforce SSL.
   ```
### Managed Postgres (Neon, RDS, Supabase, Cloud SQL)

4. **Apply the schema** (idempotent — safe to re-run):
Point `POSTGRES_URI` at the provider-given URI. The app applies the
schema on first boot.

   ```bash
   python scripts/db/init_postgres.py
   ```
```bash
export POSTGRES_URI="postgresql://user:pass@host/docsgpt?sslmode=require"
flask --app application/app.py run --host=0.0.0.0 --port=7091
```

## Migrating data
### Bare-metal Postgres

Two global flags, no per-collection knobs — every collection marked ✅
in the [Status](#status) table is handled automatically.
Run Postgres locally and point `POSTGRES_URI` at the default superuser.
First boot creates both the database and the schema.

1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo
   remains source of truth. Set the flag in `.env` and restart:
```bash
export POSTGRES_URI="postgresql://postgres@localhost/docsgpt"
flask --app application/app.py run --host=0.0.0.0 --port=7091
```

   ```bash
   USE_POSTGRES=true
   ```
Prefer a dedicated non-superuser role? Create it once as superuser — the
app never creates roles.

2. **Backfill existing data.** Idempotent — re-run any time to re-sync
   drifted rows. Without arguments, backfills every registered table;
   pass `--tables` to limit.
```sql
CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt' CREATEDB;
-- Then: POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost/docsgpt
```

   ```bash
   python scripts/db/backfill.py --dry-run # preview everything
   python scripts/db/backfill.py # real run, everything
   python scripts/db/backfill.py --tables users # only specific tables
   ```
## How auto-bootstrap works

3. **Cut over reads** once you trust the Postgres state:
Two env vars control startup behavior. Both default to `true` in the
app and are idempotent.

   ```bash
   READ_POSTGRES=true
   ```
| Setting | Effect | Requires |
| --- | --- | --- |
| `AUTO_CREATE_DB` | If the target database is missing, connects to the server's `postgres` maintenance DB and issues `CREATE DATABASE`. | `CREATEDB` privilege (or superuser) |
| `AUTO_MIGRATE` | Runs `alembic upgrade head` against the target database. | Table-owner or superuser on the target DB |

Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
keeps Postgres up to date so you can flip back and forth.
Concurrent workers serialize through `alembic_version`, so rolling
restarts are safe. If the role lacks the required privilege, startup
fails fast with a clear error rather than silently skipping.

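To see which revision the bootstrap actually applied, you can ask Alembic for the current revision or read the `alembic_version` table directly; the config path and table name are the ones mentioned on this page:

```bash
alembic -c application/alembic.ini current
psql "$POSTGRES_URI" -c "SELECT version_num FROM alembic_version;"
```
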
<Callout type="info" emoji="ℹ️">
  Convenient in dev. In production, disable both and run migrations as
  an explicit step — see [Production hardening](#production-hardening).
</Callout>

## Production hardening

Set both flags to `false` in prod and run migrations as a gated,
auditable step before rolling out the app.

```env
AUTO_MIGRATE=false
AUTO_CREATE_DB=false
```

Run migrations from your CI/CD pipeline, a Kubernetes `Job`, or an
init-container ahead of the app rollout:

```bash
python scripts/db/init_postgres.py
# equivalently:
alembic -c application/alembic.ini upgrade head
```

The reasoning: the app's runtime role shouldn't carry DDL privileges,
migrations should gate each rollout, and an explicit step is
auditable — implicit first-boot bootstrap is fine for dev but muddies
prod deploys.

<Callout type="warning" emoji="⚠️">
  Don't decommission MongoDB until every collection you use is fully
  cut over. During the migration window, Mongo is still required.
  Migrations are not reversible by the app. Always back up production
  Postgres before running `alembic upgrade head` on a new release.
</Callout>

## Status
## Migrating from MongoDB

_Last updated: 2026-04-10_
One-shot, offline, app stopped. The app itself will create the
Postgres schema when it boots — you only need to run the data copy.

| Collection | Status |
|---|---|
| `users` | ✅ Phase 1 |
| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |
```bash
pip install -r application/requirements.txt
pip install 'pymongo>=4.6'

Schemas for **every** row above already exist after `init_postgres.py`
runs. What's landing progressively is the application-level dual-write
wiring and the backfill logic for each collection. Once a collection
is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
picks it up automatically — no per-collection config change.
export POSTGRES_URI="postgresql://docsgpt:docsgpt@localhost:5432/docsgpt"
export MONGO_URI="mongodb://user:pass@host:27017/docsgpt"

python scripts/db/backfill.py --dry-run # preview
python scripts/db/backfill.py # real run
# or: python scripts/db/backfill.py --tables users,agents
```

Then unset `MONGO_URI` and start the backend — nothing consults Mongo
in the default path anymore. The backfill is idempotent (per-table
`ON CONFLICT` upserts, event-log tables deduped via `mongo_id`), so
re-running is safe and re-syncs any drifted rows. Keep Mongo online
until you've verified Postgres is complete; decommission afterwards
unless you still use it as a vector store.

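One way to sanity-check the copy before decommissioning Mongo is to compare per-table row counts against the source collections. A sketch; `conversations` is only an example taken from the list above, and `mongosh` is needed only while Mongo is still running:

```bash
psql "$POSTGRES_URI" -c "SELECT count(*) FROM conversations;"
mongosh "$MONGO_URI" --quiet --eval 'db.conversations.countDocuments()'
```
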
<Callout type="warning" emoji="⚠️">
  No dual-write window and no runtime flag — on the current version,
  Postgres is the only user-data store the backend reads or writes.
</Callout>

## Troubleshooting

- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
  `CREATE DATABASE` statements from step 2 as a Postgres superuser.
- **`relation "..." does not exist`** — schema not applied. Either let
  the app bootstrap it (`AUTO_MIGRATE=true`) or run
  `python scripts/db/init_postgres.py`.
- **`permission denied to create database`** — the role lacks
  `CREATEDB`. As superuser: `ALTER ROLE <name> CREATEDB;`. Or create
  the database manually and set `AUTO_CREATE_DB=false`.
- **`role "docsgpt" does not exist`** — roles are never auto-created.
  As superuser: `CREATE ROLE docsgpt LOGIN PASSWORD '...';`.
- **SSL errors on a managed provider** — append `?sslmode=require` to
  `POSTGRES_URI`.
- **Dual-write warnings in the logs** — expected to be non-fatal. Mongo
  is source of truth, so the user-facing request succeeds. Re-run the
  backfill to re-sync whichever rows drifted.
- **`ModuleNotFoundError: pymongo`** — `pip install 'pymongo>=4.6'`
  (only needed for the one-shot Mongo backfill).

@@ -15,7 +15,7 @@ This diagram provides a bird's-eye view of the DocsGPT architecture, illustratin
flowchart LR
    User["User"] --> Frontend["Frontend (React/Vite)"]
    Frontend --> Backend["Backend API (Flask)"]
    Backend --> LLM["LLM Integration Layer"] & VectorStore["Vector Stores"] & TaskQueue["Task Queue (Celery)"] & Databases["Databases (MongoDB, Redis)"]
    Backend --> LLM["LLM Integration Layer"] & VectorStore["Vector Stores"] & TaskQueue["Task Queue (Celery)"] & Databases["Databases (Postgres, Redis)"]
    LLM -- Cloud APIs / Local Engines --> InferenceEngine["Inference Engine"]
    VectorStore -- Document Embeddings --> Indexes[("Indexes")]
    TaskQueue -- Asynchronous Tasks --> DocumentIngestion["Document Ingestion"]
@@ -86,10 +86,11 @@ flowchart LR
* Improves application responsiveness by offloading heavy tasks.
* Enhances scalability and reliability through distributed task processing.

### 7. Databases (MongoDB, Redis)
### 7. Databases (Postgres, Redis)

* **Technology:** MongoDB and Redis.
* **Responsibility:** Databases are used for persistent data storage and caching. MongoDB stores structured data such as conversations, documents, user settings, and API keys. Redis is used as a cache, as well as a message broker for Celery.
* **Technology:** PostgreSQL and Redis.
* **Responsibility:** Databases are used for persistent data storage and caching. PostgreSQL stores structured user data such as conversations, agents, prompts, sources, attachments, workflows, logs, token usage, user settings, and API keys. Redis is used as a cache and as the message broker/result backend for Celery.
* **Note:** MongoDB is no longer used for user data. It remains an **optional** backend for the vector store (`VECTOR_STORE=mongodb`, i.e. Mongo Atlas Vector Search) and as the source database for the one-shot `scripts/db/backfill.py` migration from legacy installs.

## Request Flow Diagram

@@ -135,7 +136,7 @@ graph LR
            RedisPod[Redis Pod]
        end
        subgraph Node 3
            MongoDBPod[MongoDB Pod]
            PostgresPod[Postgres Pod]
            VectorStorePod[Vector Store Pod]
        end
    end
@@ -145,12 +146,12 @@ graph LR
        docsgpt-api-service --> BackendAPIPod
        BackendAPIPod --> CeleryWorkerPod
        BackendAPIPod --> RedisPod
        BackendAPIPod --> MongoDBPod
        BackendAPIPod --> PostgresPod
        BackendAPIPod --> VectorStorePod
        CeleryWorkerPod --> RedisPod
        BackendAPIPod --> InferenceEngine[(Inference Engine)]
        VectorStorePod --> Indexes[(Indexes)]
        MongoDBPod --> Data[(Data)]
        PostgresPod --> Data[(Data)]
        RedisPod --> Cache[(Cache)]
    end
    User[User] --> LoadBalancer

@@ -31,7 +31,7 @@ DOCLING_OCR_ATTACHMENTS_ENABLED=false
### Attachment flow (Chat-only file context)

1. Files are uploaded through `/api/store_attachment`.
2. Celery task `attachment_worker` parses and stores the attachment in MongoDB (`attachments` collection).
2. Celery task `attachment_worker` parses and stores the attachment in Postgres (`attachments` table).
3. OCR in this path is controlled by `DOCLING_OCR_ATTACHMENTS_ENABLED`.
4. Attachments are not vectorized and are not added to the source index.
5. During answer generation, selected attachment IDs are loaded and passed directly to the LLM pipeline.

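For a quick manual test of step 1, the upload endpoint can be exercised with `curl`. This is only a sketch: the multipart field name and the absence of auth headers are assumptions, so check the `/api/store_attachment` route for the exact contract.

```bash
# Hypothetical field name and local port; adjust to the actual route signature.
curl -X POST http://localhost:7091/api/store_attachment \
  -F "file=@./example.pdf"
```
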
@@ -7,4 +7,3 @@ idna==3.7
python-dotenv==1.0.1
sniffio==1.3.1
slack-bolt==1.21.0
bson==0.5.10

Some files were not shown because too many files have changed in this diff.