Compare commits

...

75 Commits

Author SHA1 Message Date
Alex
cb30a24e05 feat: fixes on pg2 2026-04-12 13:51:29 +01:00
Alex
530761d08c feat: pg-2 2026-04-12 13:35:32 +01:00
Alex
73fbc28744 Merge pull request #2376 from arc53/pg-1
Pg 1
2026-04-12 12:44:12 +01:00
Alex
b5b6538762 fix: tests 2026-04-12 12:35:23 +01:00
Alex
a9761061fc fix: mini issues 2026-04-12 12:24:58 +01:00
Alex
9388996a15 fix: ruff 2026-04-12 12:18:31 +01:00
Alex
875868b7e5 fix: comment 2026-04-12 12:16:19 +01:00
Alex
502819ae52 feat: pg migration, more tables 2026-04-12 12:15:59 +01:00
Alex
cada1a44fc Merge pull request #2373 from siiddhantt/feat/confluence-connector
feat: add Confluence connector for data ingestion
2026-04-12 11:33:49 +01:00
Alex
6192767451 fix: sanitize attachment filenames, drop dateutil dep, add connector docs 2026-04-12 11:32:24 +01:00
Alex
5c3e6eca54 Merge pull request #2375 from arc53/pg
feat: init pg migration
2026-04-12 10:44:31 +01:00
Alex
59d9d4ac50 fix: comments in settings 2026-04-12 10:42:46 +01:00
Alex
3931ccccee Merge pull request #2374 from ManishMadan2882/main
UX: Conversation scroll experience
2026-04-12 00:37:11 +01:00
Alex
55717043f6 fix: vale 2026-04-12 00:29:23 +01:00
Alex
ececcb8b17 feat: init pg migration 2026-04-12 00:07:24 +01:00
ManishMadan2882
420e9d3dd5 (feat) conversation: scroll experience 2026-04-10 19:27:07 +05:30
Siddhant Rai
749eed3d0b feat: add Confluence integration with authentication and file loading capabilities
- Enhanced settings.py to include Confluence client ID and secret
- Created ConfluenceAuth class for handling authentication with Confluence
- Implemented ConfluenceLoader class for loading data from Confluence
- Updated connector_creator.py to register Confluence as a connector
- Added confluence.svg asset for UI representation
- Modified ConnectorAuth component to support Confluence connection
- Updated FilePicker component to include Confluence as a file source
- Added localization support for Confluence in multiple languages (de, en, es, jp, ru, zh-TW, zh)
- Enhanced Upload component to handle Confluence file selection
- Updated ingestor types to include Confluence and its configuration
2026-04-10 19:10:35 +05:30
Alex
bd03a513e3 Merge pull request #2372 from arc53/fast-ebook
feat: faster ebook parsing
2026-04-09 18:38:13 +01:00
Alex
fcdb4fb5e8 feat: faster ebook parsing 2026-04-09 18:31:06 +01:00
Alex
e787c896eb upd Security.md 2026-04-08 12:49:20 +01:00
Alex
23aeaff5db Merge pull request #2362 from arc53/v1-mini-improvements
feat: history overwrite
2026-04-06 15:02:32 +01:00
Alex
689dd79597 fix: lang 2026-04-06 14:57:51 +01:00
Alex
0c15af90b1 feat: history overwrite 2026-04-06 14:42:01 +01:00
Alex
cdd6ff6557 chore: bump deps 2026-04-04 12:45:34 +01:00
Alex
72b3d94453 fix: tests 2026-04-03 18:30:46 +01:00
Alex
7e88d09e5d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 18:26:37 +01:00
Alex
74a4a237dc fix: bump deps 2026-04-03 18:26:29 +01:00
Alex
c3f01c6619 Merge pull request #2347 from ManishMadan2882/main
Minor frontend updates
2026-04-03 18:17:27 +01:00
Alex
6b408823d4 fix: mini theme color edits 2026-04-03 18:16:07 +01:00
Alex
3fc81ac5d8 fix: clean error 2026-04-03 18:08:38 +01:00
Alex
2652f8a5b0 fix: chatwoot 2026-04-03 18:04:49 +01:00
Alex
d711eefe96 patch: agent usage limits 2026-04-03 18:03:31 +01:00
Alex
79206f3919 fix: harden faiss 2026-04-03 17:57:49 +01:00
Alex
de971d9452 fix: validate mcp url 2026-04-03 17:52:48 +01:00
Alex
1b4d5ca0dd patch: mcp identity 2026-04-03 17:40:22 +01:00
Alex
81989e8258 fix: patch /v1/models 2026-04-03 17:37:09 +01:00
Alex
dc262d1698 patch: error 2026-04-03 17:30:23 +01:00
Alex
69f9c93869 patch: s3 2026-04-03 17:28:09 +01:00
Alex
74bf80b25c patch: sharing convos 2026-04-03 17:20:06 +01:00
Alex
d9a92a7208 feat: improve setup scripts 2026-04-03 17:15:21 +01:00
Alex
02e93d993d patch: available tools 2026-04-03 17:12:36 +01:00
Alex
6b6495f48c patch: key 2026-04-03 17:06:35 +01:00
Alex
249dd9ce37 patch: paths 2026-04-03 16:45:03 +01:00
Alex
9134ab0478 Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 16:40:50 +01:00
Alex
10ef68c9d0 Revise vulnerability reporting process
Updated vulnerability reporting instructions to use GitHub's private reporting flow.
2026-04-03 16:36:10 +01:00
Alex
7d65cf1c2b chore: bump deps 2026-04-03 16:35:10 +01:00
Alex
13c6cc59c1 Merge pull request #2349 from arc53/messages-format
Messages format
2026-04-03 16:26:57 +01:00
Alex
6381f7dd4e fix: remove bad tests 2026-04-03 16:20:15 +01:00
Alex
e6ac4008fe feat: better tool names for llms 2026-04-03 15:35:50 +01:00
Alex
1af09f114d fix: tool mapping 2026-04-03 13:32:55 +01:00
Alex
be7da983e7 fix: remove internal tools when creating tools and better Approval gate UX 2026-04-03 10:36:48 +01:00
Alex
8b9e595d85 fix: structure improvements of messages 2026-04-01 14:58:44 +01:00
Alex
398f3acc8d fix: clean error 2026-04-01 13:01:02 +01:00
Alex
e04baa7ed8 feat: tests and approval gate 2026-04-01 12:49:32 +01:00
Alex
e5586b6f20 feat: fronted connection to api 2026-04-01 10:55:54 +01:00
Alex
addf57cab7 feat: compatible api 2026-03-31 23:10:09 +01:00
ManishMadan2882
648b3f1d20 (fix) lint/fe 2026-04-01 03:30:44 +05:30
ManishMadan2882
a75a9e23f9 (feat:fe) minor good things 2026-04-01 03:19:03 +05:30
Alex
73256389cf feat: client side tools 2026-03-31 22:20:55 +01:00
Alex
d609efca49 feat: continuation messages 2026-03-31 21:30:24 +01:00
Alex
772860b667 fix: mini fe changes 2026-03-31 11:59:38 +01:00
Alex
ea2fd8b04a chore: remove unused deps 2026-03-31 11:57:01 +01:00
Alex
2c73deac20 deps upgrades 2026-03-31 11:32:55 +01:00
Alex
47f3907e5e Merge pull request #2340 from arc53/coverage-3
chore: more tests
2026-03-31 00:50:46 +01:00
Alex
81532ada2a Merge pull request #2318 from siiddhantt/feat/standardize-css
feat: update styles and improve accessibility across frontend
2026-03-30 23:26:45 +01:00
ManishMadan2882
43f71374e5 (chore:fe) lint-fix 2026-03-30 23:26:11 +05:30
Siddhant Rai
3b66a3176c fix: improve option matching logic in Dropdown component + selected style 2026-03-30 18:36:35 +05:30
Siddhant Rai
9a6a55b6da Merge branch 'main' into feat/standardize-css 2026-03-30 13:14:43 +05:30
Siddhant Rai
12a8368216 fix: merge conflicts 2026-03-30 13:12:24 +05:30
Siddhant Rai
193ca6fd63 fix: lint errors + redundant css 2026-03-28 15:02:16 +05:30
Siddhant Rai
174dee0fe6 fix: inconsistencies with prev color patterns 2026-03-26 18:32:57 +05:30
Siddhant Rai
844167ba06 Merge branch 'feat/standardize-css' of https://github.com/siiddhantt/DocsGPT into feat/standardize-css 2026-03-25 19:36:35 +05:30
Siddhant Rai
6fa3acb1ca style: standardize colors across components according to figma 2026-03-25 19:36:32 +05:30
Alex
9fd063266b Mini fixes 2026-03-24 01:40:29 +00:00
Siddhant Rai
324a8cd4cf refactor: update styles and improve accessibility across frontend
- Updated text colors to use foreground and muted-foreground for better contrast.
- Replaced hardcoded colors with theme-based classes for consistency.
- Enhanced input fields with icons for improved usability.
- Adjusted button styles for a more cohesive design.
- Refactored search input components to use consistent styling and layout.
- Improved layout and spacing in various components for better user experience.
- Updated tool and source titles and subtitles for clarity.
2026-03-20 17:10:27 +05:30
275 changed files with 16389 additions and 2935 deletions

View File

@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt

View File

@@ -1,46 +1,80 @@
Ollama
Qdrant
Milvus
Chatwoot
Nextra
VSCode
npm
LLMs
Agentic
Anthropic's
api
APIs
Groq
SGLang
LMDeploy
OAuth
Vite
LLM
JSONPath
UIs
Atlassian
automations
autoescaping
Autoescaping
backfill
backfills
bool
boolean
brave_web_search
chatbot
Chatwoot
config
configs
uncomment
qdrant
vectorstore
CSVs
dev
diarization
Docling
docsgpt
llm
docstrings
Entra
env
enqueues
EOL
ESLint
feedbacks
Figma
GPUs
Groq
hardcode
hardcoding
Idempotency
JSONPath
kubectl
Lightsail
enqueues
chatbot
VSCode's
Shareability
feedbacks
automations
llama_cpp
llm
LLM
LLMs
LMDeploy
Milvus
Mixtral
namespace
namespaces
needs_auth
Nextra
Novita
npm
OAuth
Ollama
opencode
parsable
passthrough
PDFs
pgvector
Postgres
Premade
Signup
Pydantic
pytest
Qdrant
qdrant
Repo
repo
env
URl
agentic
llama_cpp
parsable
Sanitization
SDKs
boolean
bool
hardcode
EOL
SGLang
Shareability
Signup
Supabase
UIs
uncomment
URl
vectorstore
Vite
VSCode
VSCode's
widget's

3
.gitignore vendored
View File

@@ -108,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -181,5 +183,6 @@ application/vectors/
node_modules/
.vscode/settings.json
.vscode/sftp.json
/models/
model/

View File

@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT
[*.{md,mdx}]
BasedOnStyles = DocsGPT

View File

@@ -2,13 +2,21 @@
## Supported Versions
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on Github.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
security@arc53.com
Then click **Report a vulnerability**.
Alternatively, email us at: security@arc53.com
We aim to acknowledge reports within 48 hours.
## Incident Handling
Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.

View File

@@ -1,7 +1,8 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
@@ -113,6 +115,153 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def gen_continuation(
self,
messages: List[Dict],
tools_dict: Dict,
pending_tool_calls: List[Dict],
tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
"""Resume generation after tool actions are resolved.
Processes the client-provided *tool_actions* (approvals, denials,
or client-side results), appends the resulting messages, then
hands back to the LLM to continue the conversation.
Args:
messages: The saved messages array from the pause point.
tools_dict: The saved tools dictionary.
pending_tool_calls: The pending tool call descriptors from the pause.
tool_actions: Client-provided actions resolving the pending calls.
"""
self._prepare_tools(tools_dict)
actions_by_id = {a["call_id"]: a for a in tool_actions}
# Build a single assistant message containing all tool calls so
# the message history matches the format LLM providers expect
# (one assistant message with N tool_calls, followed by N tool results).
tc_objects: List[Dict[str, Any]] = []
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
args_str = (
json.dumps(args) if isinstance(args, dict) else (args or "{}")
)
tc_obj: Dict[str, Any] = {
"id": call_id,
"type": "function",
"function": {
"name": pending["name"],
"arguments": args_str,
},
}
if pending.get("thought_signature"):
tc_obj["thought_signature"] = pending["thought_signature"]
tc_objects.append(tc_obj)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tc_objects,
})
# Now process each pending call and append tool result messages
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
if action.get("decision") == "approved":
# Execute the tool server-side
tc = ToolCall(
id=call_id,
name=pending["name"],
arguments=(
json.dumps(args) if isinstance(args, dict) else args
),
)
tool_gen = self._execute_tool_action(tools_dict, tc)
tool_response = None
while True:
try:
event = next(tool_gen)
yield event
except StopIteration as e:
tool_response, _ = e.value
break
messages.append(
self.llm_handler.create_tool_message(tc, tool_response)
)
elif action.get("decision") == "denied":
comment = action.get("comment", "")
denial = (
f"Tool execution denied by user. Reason: {comment}"
if comment
else "Tool execution denied by user."
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, denial)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"status": "denied",
},
}
elif "result" in action:
result = action["result"]
result_str = (
json.dumps(result)
if not isinstance(result, str)
else result
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, result_str)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"result": (
result_str[:50] + "..."
if len(result_str) > 50
else result_str
),
"status": "completed",
},
}
# Resume the LLM loop with the updated messages
llm_response = self._llm_gen(messages)
yield from self._handle_response(
llm_response, tools_dict, messages, None
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
@@ -267,28 +416,35 @@ class BaseAgent(ABC):
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
messages.append({"role": "user", "content": query})
return messages

View File

@@ -593,16 +593,22 @@ class ResearchAgent(BaseAgent):
)
result = result_str
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_content]}
import json as _json
args_str = (
_json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {"name": call.name, "arguments": args_str},
}],
})
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)

View File

@@ -1,6 +1,7 @@
import logging
import uuid
from typing import Dict, List, Optional
from collections import Counter
from typing import Dict, List, Optional, Tuple
from bson.objectid import ObjectId
@@ -31,12 +32,23 @@ class ToolExecutor:
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context."""
"""Load tool configs from DB based on user context.
If *client_tools* have been set on this executor, they are
automatically merged into the returned dict.
"""
if self.user_api_key:
return self._get_tools_by_api_key(self.user_api_key)
return self._get_user_tools(self.user or "local")
tools = self._get_tools_by_api_key(self.user_api_key)
else:
tools = self._get_user_tools(self.user or "local")
if self.client_tools:
self.merge_client_tools(tools, self.client_tools)
return tools
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
@@ -65,29 +77,123 @@ class ToolExecutor:
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas."""
return [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
def merge_client_tools(
self, tools_dict: Dict, client_tools: List[Dict]
) -> Dict:
"""Merge client-provided tool definitions into tools_dict.
Client tools use the standard function-calling format::
[{"type": "function", "function": {"name": "get_weather",
"description": "...", "parameters": {...}}}]
They are stored in *tools_dict* with ``client_side: True`` so that
:meth:`check_pause` returns a pause signal instead of trying to
execute them server-side.
Args:
tools_dict: The mutable server tools dict (will be modified in place).
client_tools: List of tool definitions in function-calling format.
Returns:
The updated *tools_dict* (same reference, for convenience).
"""
for i, ct in enumerate(client_tools):
func = ct.get("function", ct) # tolerate bare {"name":..} too
name = func.get("name", f"clienttool{i}")
tool_id = f"ct{i}"
tools_dict[tool_id] = {
"name": name,
"client_side": True,
"actions": [
{
"name": name,
"description": func.get("description", ""),
"active": True,
"parameters": func.get("parameters", {}),
}
],
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
return tools_dict
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas.
Action names are kept clean for the LLM:
- Unique action names appear as-is (e.g. ``get_weather``).
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
``search_2``).
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
can be routed back to the correct ``(tool_id, action_name)`` without
brittle string splitting.
"""
# Pass 1: collect entries and count action name occurrences
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
name_counts: Counter = Counter()
for tool_id, tool in tools_dict.items():
is_api = tool["name"] == "api_tool"
is_client = tool.get("client_side", False)
if is_api and "actions" not in tool.get("config", {}):
continue
if not is_api and "actions" not in tool:
continue
actions = (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
if is_api
else tool["actions"]
)
if action.get("active", True)
]
for action in actions:
if not action.get("active", True):
continue
entries.append((tool_id, action["name"], action, is_client))
name_counts[action["name"]] += 1
# Pass 2: assign LLM-visible names and build mappings
self._name_to_tool = {}
self._tool_to_name = {}
collision_counters: Dict[str, int] = {}
all_llm_names: set = set()
result = []
for tool_id, action_name, action, is_client in entries:
if name_counts[action_name] == 1:
llm_name = action_name
else:
counter = collision_counters.get(action_name, 1)
candidate = f"{action_name}_{counter}"
# Skip if candidate collides with a unique action name
while candidate in all_llm_names or (
candidate in name_counts and name_counts[candidate] == 1
):
counter += 1
candidate = f"{action_name}_{counter}"
collision_counters[action_name] = counter + 1
llm_name = candidate
all_llm_names.add(llm_name)
self._name_to_tool[llm_name] = (tool_id, action_name)
self._tool_to_name[(tool_id, action_name)] = llm_name
if is_client:
params = action.get("parameters", {})
else:
params = self._build_tool_parameters(action)
result.append({
"type": "function",
"function": {
"name": llm_name,
"description": action.get("description", ""),
"parameters": params,
},
})
return result
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
@@ -104,23 +210,81 @@ class ToolExecutor:
params["required"].append(k)
return params
def check_pause(
self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
"""Check if a tool call requires pausing for approval or client execution.
Returns a dict describing the pending action if pause is needed, None otherwise.
"""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
llm_name = getattr(call, "name", "")
if tool_id is None or action_name is None or tool_id not in tools_dict:
return None # Will be handled as error by execute()
tool_data = tools_dict[tool_id]
# Client-side tools
if tool_data.get("client_side"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "requires_client_execution",
"thought_signature": getattr(call, "thought_signature", None),
}
# Approval required
if tool_data["name"] == "api_tool":
action_data = tool_data.get("config", {}).get("actions", {}).get(
action_name, {}
)
else:
action_data = next(
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
{},
)
if action_data.get("require_approval"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "awaiting_approval",
"thought_signature": getattr(call, "thought_signature", None),
}
return None
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name)
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
llm_name = getattr(call, "name", "unknown")
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"action_name": llm_name,
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
@@ -133,7 +297,7 @@ class ToolExecutor:
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"action_name": llm_name,
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
@@ -144,7 +308,7 @@ class ToolExecutor:
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"action_name": llm_name,
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}

View File

@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
class Tool(ABC):
internal: bool = False
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass

View File

@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)
if response.status_code == 200:
return {
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)
if response.status_code == 200:
return {

View File

@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
response = requests.get(url, timeout=100)
if response.status_code == 200:
data = response.json()
if currency.upper() in data:

View File

@@ -20,6 +20,8 @@ class InternalSearchTool(Tool):
- list_files action: browse the file/folder structure
"""
internal = True
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []

View File

@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -61,7 +62,8 @@ class MCPTool(Tool):
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
raw_url = config.get("server_url", "")
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
@@ -87,6 +89,18 @@ class MCPTool(Tool):
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@staticmethod
def _validate_server_url(server_url: str) -> str:
"""Validate server_url to prevent SSRF to internal networks.
Raises:
ValueError: If the URL points to a private/internal address.
"""
try:
return validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
@@ -108,8 +122,9 @@ class MCPTool(Tool):
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(

View File

@@ -71,7 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data)
response = requests.post(url, headers=headers, data=data, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):

View File

@@ -1,6 +1,6 @@
import logging
import psycopg2
import psycopg
from application.agents.tools.base import Tool
@@ -33,7 +33,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
"response_data": response_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
return {
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
"schema": schema_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
return {

View File

@@ -31,14 +31,14 @@ class TelegramTool(Tool):
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):

View File

@@ -36,6 +36,8 @@ class ThinkTool(Tool):
The reasoning content is captured in tool_call data for transparency.
"""
internal = True
def __init__(self, config=None):
pass

View File

@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type):
def __init__(self, llm_type, name_mapping=None):
self.llm_type = llm_type
self.name_mapping = name_mapping
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
@@ -16,22 +17,33 @@ class ToolActionParser:
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _resolve_via_mapping(self, call_name):
"""Look up (tool_id, action_name) from the name mapping if available."""
if self.name_mapping and call_name in self.name_mapping:
return self.name_mapping[call_name]
return None
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
@@ -45,19 +57,24 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."

View File

@@ -19,7 +19,7 @@ class ToolManager:
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)

52
application/alembic.ini Normal file
View File

@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
# alembic -c application/alembic.ini upgrade head
[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os
# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.
The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""
import sys
from logging.config import fileConfig
from pathlib import Path
# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from alembic import context # noqa: E402
from sqlalchemy import engine_from_config, pool # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.models import metadata as target_metadata # noqa: E402
config = context.config
# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
if config.config_file_name is not None:
fileConfig(config.config_file_name)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
url = config.get_main_option("sqlalchemy.url")
if not url:
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode against a live connection."""
if not config.get_main_option("sqlalchemy.url"):
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,462 @@
"""0001 initial schema — user-level tables migrated from MongoDB.
Creates every table described in §2.2 of ``migration-postgres.md``: tiers 1,
2, and 3 in one shot. The schema is small enough that splitting the baseline
across multiple revisions would only cost clarity.
Subsequent migrations will add columns / tables incrementally. This file is
hand-written raw DDL rather than Core ``op.create_table`` calls because the
DDL uses several Postgres-specific features (``CITEXT``, partial indexes,
``text_pattern_ops``, JSONB defaults) that are clearer in SQL than in
Alembic's Python API.
Revision ID: 0001_initial
Revises:
Create Date: 2026-04-10
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ------------------------------------------------------------------
# Extensions
# ------------------------------------------------------------------
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
# ------------------------------------------------------------------
# Tier 1: leaf tables, no FKs into other migrated tables
# ------------------------------------------------------------------
op.execute("""
CREATE TABLE users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL UNIQUE,
agent_preferences JSONB NOT NULL
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX users_user_id_idx ON users (user_id);")
op.execute("""
CREATE TABLE prompts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
op.execute("""
CREATE TABLE user_tools (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
custom_name TEXT,
display_name TEXT,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
op.execute("""
CREATE TABLE token_usage (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
api_key TEXT,
agent_id UUID, -- FK added later in this migration
prompt_tokens INTEGER NOT NULL DEFAULT 0,
generated_tokens INTEGER NOT NULL DEFAULT 0,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, timestamp DESC);")
op.execute("CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, timestamp DESC);")
op.execute("CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, timestamp DESC);")
op.execute("""
CREATE TABLE user_logs (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
endpoint TEXT,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
data JSONB
);
""")
op.execute("CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, timestamp DESC);")
op.execute("""
CREATE TABLE feedback (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL, -- FK added later in this migration
user_id TEXT NOT NULL,
question_index INTEGER NOT NULL,
feedback_text TEXT,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX feedback_conv_idx ON feedback (conversation_id);")
# Append-only debug/error log. The Mongo doc has both `_id` (auto) and an
# `id` field (the activity id). Here the serial PK owns `id`; the
# application-level identifier is renamed to `activity_id`.
op.execute("""
CREATE TABLE stack_logs (
id BIGSERIAL PRIMARY KEY,
activity_id TEXT NOT NULL,
endpoint TEXT,
level TEXT,
user_id TEXT,
api_key TEXT,
query TEXT,
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX stack_logs_timestamp_idx ON stack_logs (timestamp DESC);")
op.execute("CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, timestamp DESC);")
op.execute("CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, timestamp DESC);")
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
# ------------------------------------------------------------------
# Tier 2: FK-bearing tables
# ------------------------------------------------------------------
op.execute("""
CREATE TABLE agent_folders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
op.execute("""
CREATE TABLE sources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT, -- NULL for system/template sources
name TEXT NOT NULL,
type TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
op.execute("""
CREATE TABLE agents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
agent_type TEXT,
status TEXT NOT NULL,
key CITEXT UNIQUE,
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
chunks INTEGER,
retriever TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
json_schema JSONB,
models JSONB,
default_model_id TEXT,
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
token_limit INTEGER,
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
request_limit INTEGER,
shared BOOLEAN NOT NULL DEFAULT false,
incoming_webhook_token CITEXT UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
last_used_at TIMESTAMPTZ
);
""")
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
# Backfill the token_usage.agent_id FK now that agents exists.
op.execute("""
ALTER TABLE token_usage
ADD CONSTRAINT token_usage_agent_fk
FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;
""")
op.execute("""
CREATE TABLE attachments (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
filename TEXT NOT NULL,
upload_path TEXT NOT NULL,
mime_type TEXT,
size BIGINT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
op.execute("""
CREATE TABLE memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
path TEXT NOT NULL,
content TEXT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("""
CREATE UNIQUE INDEX memories_user_tool_path_uidx
ON memories (user_id, tool_id, path);
""")
op.execute("""
CREATE INDEX memories_path_prefix_idx
ON memories (user_id, tool_id, path text_pattern_ops);
""")
op.execute("""
CREATE TABLE todos (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
op.execute("""
CREATE TABLE notes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX notes_user_tool_idx ON notes (user_id, tool_id);")
op.execute("""
CREATE TABLE connector_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
provider TEXT NOT NULL,
session_data JSONB NOT NULL,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("""
CREATE INDEX connector_sessions_user_provider_idx
ON connector_sessions (user_id, provider);
""")
op.execute("""
CREATE INDEX connector_sessions_expiry_idx
ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;
""")
# ------------------------------------------------------------------
# Tier 3: conversations, pending_tool_state, workflows
# ------------------------------------------------------------------
op.execute("""
CREATE TABLE conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
name TEXT,
api_key TEXT,
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
shared_token TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
op.execute("""
CREATE INDEX conversations_shared_token_idx
ON conversations (shared_token) WHERE shared_token IS NOT NULL;
""")
op.execute("""
CREATE TABLE conversation_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
position INTEGER NOT NULL,
prompt TEXT,
response TEXT,
thought TEXT,
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
attachments UUID[] NOT NULL DEFAULT '{}',
model_id TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
feedback JSONB,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("""
CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx
ON conversation_messages (conversation_id, position);
""")
# Backfill the feedback.conversation_id FK now that conversations exists.
op.execute("""
ALTER TABLE feedback
ADD CONSTRAINT feedback_conv_fk
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE;
""")
op.execute("""
CREATE TABLE shared_conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
chunks INTEGER,
is_promptable BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
# Paused-tool continuation state. The Mongo version relies on a TTL index;
# Postgres has no native TTL, so a Celery beat task (added in Phase 3)
# deletes rows where expires_at < now() once a minute. The unique
# constraint on (conversation_id, user_id) matches the existing upsert
# semantics.
op.execute("""
CREATE TABLE pending_tool_state (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
messages JSONB NOT NULL,
pending_tool_calls JSONB NOT NULL,
tools_dict JSONB NOT NULL,
tool_schemas JSONB NOT NULL,
agent_config JSONB NOT NULL,
client_tools JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
expires_at TIMESTAMPTZ NOT NULL
);
""")
op.execute("""
CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx
ON pending_tool_state (conversation_id, user_id);
""")
op.execute("""
CREATE INDEX pending_tool_state_expires_idx
ON pending_tool_state (expires_at);
""")
# Workflows
op.execute("""
CREATE TABLE workflows (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
""")
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
op.execute("""
CREATE TABLE workflow_nodes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
node_type TEXT NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb
);
""")
op.execute("""
CREATE INDEX workflow_nodes_workflow_version_idx
ON workflow_nodes (workflow_id, graph_version);
""")
op.execute("""
CREATE TABLE workflow_edges (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
from_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
to_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
config JSONB NOT NULL DEFAULT '{}'::jsonb
);
""")
op.execute("""
CREATE INDEX workflow_edges_workflow_version_idx
ON workflow_edges (workflow_id, graph_version);
""")
op.execute("""
CREATE TABLE workflow_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
status TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
ended_at TIMESTAMPTZ,
result JSONB
);
""")
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
def downgrade() -> None:
# Reverse dependency order. CASCADE would handle FKs anyway, but explicit
# is clearer for anyone reading the migration.
op.execute("DROP TABLE IF EXISTS workflow_runs CASCADE;")
op.execute("DROP TABLE IF EXISTS workflow_edges CASCADE;")
op.execute("DROP TABLE IF EXISTS workflow_nodes CASCADE;")
op.execute("DROP TABLE IF EXISTS workflows CASCADE;")
op.execute("DROP TABLE IF EXISTS pending_tool_state CASCADE;")
op.execute("DROP TABLE IF EXISTS shared_conversations CASCADE;")
op.execute("DROP TABLE IF EXISTS conversation_messages CASCADE;")
op.execute("DROP TABLE IF EXISTS conversations CASCADE;")
op.execute("DROP TABLE IF EXISTS connector_sessions CASCADE;")
op.execute("DROP TABLE IF EXISTS notes CASCADE;")
op.execute("DROP TABLE IF EXISTS todos CASCADE;")
op.execute("DROP TABLE IF EXISTS memories CASCADE;")
op.execute("DROP TABLE IF EXISTS attachments CASCADE;")
op.execute("DROP TABLE IF EXISTS agents CASCADE;")
op.execute("DROP TABLE IF EXISTS sources CASCADE;")
op.execute("DROP TABLE IF EXISTS agent_folders CASCADE;")
op.execute("DROP TABLE IF EXISTS stack_logs CASCADE;")
op.execute("DROP TABLE IF EXISTS feedback CASCADE;")
op.execute("DROP TABLE IF EXISTS user_logs CASCADE;")
op.execute("DROP TABLE IF EXISTS token_usage CASCADE;")
op.execute("DROP TABLE IF EXISTS user_tools CASCADE;")
op.execute("DROP TABLE IF EXISTS prompts CASCADE;")
op.execute("DROP TABLE IF EXISTS users CASCADE;")
# Extensions are intentionally left in place — they may be shared with
# pgvector or other extensions already enabled on the cluster.

View File

@@ -0,0 +1,57 @@
"""0002 add unique constraints for notes and connector_sessions.
The memories table already has ``memories_user_tool_path_uidx`` from the
0001 baseline. Notes and connector_sessions were missing unique constraints
that their repository upsert logic depends on.
Before creating the indexes, duplicate rows are cleaned up — keeping only
the row with the latest ``id`` (UUID, lexicographic max) per group.
Revision ID: 0002_add_unique_constraints
Revises: 0001_initial
Create Date: 2026-04-12
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_add_unique_constraints"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Deduplicate notes: keep one row per (user_id, tool_id)
op.execute("""
DELETE FROM notes
WHERE id NOT IN (
SELECT DISTINCT ON (user_id, tool_id) id
FROM notes
ORDER BY user_id, tool_id, created_at DESC
);
""")
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS notes_user_tool_uidx "
"ON notes (user_id, tool_id);"
)
# Deduplicate connector_sessions: keep one row per (user_id, provider)
op.execute("""
DELETE FROM connector_sessions
WHERE id NOT IN (
SELECT DISTINCT ON (user_id, provider) id
FROM connector_sessions
ORDER BY user_id, provider, created_at DESC
);
""")
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS connector_sessions_user_provider_uidx "
"ON connector_sessions (user_id, provider);"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS connector_sessions_user_provider_uidx;")
op.execute("DROP INDEX IF EXISTS notes_user_tool_uidx;")

View File

@@ -74,57 +74,76 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
}
if structured_info:
result.update(structured_info)
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
@@ -39,7 +40,16 @@ class BaseAnswerResource:
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
@@ -177,6 +187,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -207,8 +218,19 @@ class BaseAnswerResource:
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
for line in agent.gen(query=question):
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
@@ -244,15 +266,21 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "error":
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -262,6 +290,93 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -354,6 +469,18 @@ class BaseAnswerResource:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_logs import UserLogsRepository
dual_write(
UserLogsRepository,
lambda repo, d=log_data: repo.insert(
user_id=d.get("user"),
endpoint="stream_answer",
data=d,
),
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
@@ -425,8 +552,13 @@ class BaseAnswerResource:
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
conversation_id = ""
response_full = ""
source_log_docs = []
@@ -435,6 +567,7 @@ class BaseAnswerResource:
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
@@ -453,11 +586,22 @@ class BaseAnswerResource:
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"], None
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -465,18 +609,30 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
result["extra"] = {"pending_tool_calls": pending_tool_calls}
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
result["extra"] = {"structured": True, "schema": schema_info}
return result
def error_stream_generate(self, err_response):

View File

@@ -79,7 +79,47 @@ class StreamResource(Resource, BaseAnswerResource):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(

View File

@@ -1,5 +1,6 @@
"""Message reconstruction utilities for compression."""
import json
import logging
import uuid
from typing import Dict, List, Optional
@@ -49,28 +50,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
@@ -180,28 +188,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
rebuilt_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
rebuilt_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:

View File

@@ -0,0 +1,141 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to MongoDB so the client can
resume later by sending tool_actions.
"""
import datetime
import logging
from typing import Any, Dict, List, Optional
from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
if isinstance(obj, ObjectId):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
"""Manages pending tool-call state in MongoDB."""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.collection = db["pending_tool_state"]
self._ensure_indexes()
def _ensure_indexes(self):
try:
self.collection.create_index(
"expires_at", expireAfterSeconds=0
)
self.collection.create_index(
[("conversation_id", 1), ("user", 1)], unique=True
)
except Exception:
# Indexes may already exist or mongomock doesn't support TTL
pass
def save_state(
self,
conversation_id: str,
user: str,
messages: List[Dict],
pending_tool_calls: List[Dict],
tools_dict: Dict,
tool_schemas: List[Dict],
agent_config: Dict,
client_tools: Optional[List[Dict]] = None,
) -> str:
"""Save execution state for later continuation.
Args:
conversation_id: The conversation this state belongs to.
user: Owner user ID.
messages: Full messages array at the pause point.
pending_tool_calls: Tool calls awaiting client action.
tools_dict: Serializable tools configuration dict.
tool_schemas: LLM-formatted tool schemas (agent.tools).
agent_config: Config needed to recreate the agent on resume.
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID of the saved state document.
"""
now = datetime.datetime.now(datetime.timezone.utc)
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
doc = {
"conversation_id": conversation_id,
"user": user,
"messages": _make_serializable(messages),
"pending_tool_calls": _make_serializable(pending_tool_calls),
"tools_dict": _make_serializable(tools_dict),
"tool_schemas": _make_serializable(tool_schemas),
"agent_config": _make_serializable(agent_config),
"client_tools": _make_serializable(client_tools) if client_tools else None,
"created_at": now,
"expires_at": expires_at,
}
# Upsert — only one pending state per conversation per user
result = self.collection.replace_one(
{"conversation_id": conversation_id, "user": user},
doc,
upsert=True,
)
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
return state_id
def load_state(
self, conversation_id: str, user: str
) -> Optional[Dict[str, Any]]:
"""Load pending continuation state.
Returns:
The state dict, or None if no pending state exists.
"""
doc = self.collection.find_one(
{"conversation_id": conversation_id, "user": user}
)
if not doc:
return None
doc["_id"] = str(doc["_id"])
return doc
def delete_state(self, conversation_id: str, user: str) -> bool:
"""Delete pending state after successful resumption.
Returns:
True if a document was deleted.
"""
result = self.collection.delete_one(
{"conversation_id": conversation_id, "user": user}
)
if result.deleted_count:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
return result.deleted_count > 0

View File

@@ -112,6 +112,7 @@ class StreamProcessor:
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
self._agent_data: Optional[Dict[str, Any]] = None
def initialize(self):
"""Initialize all required components for processing"""
@@ -359,22 +360,29 @@ class StreamProcessor:
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
"""Configure the source based on agent data.
if api_key:
agent_data = self._get_data_from_api_key(api_key)
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
@@ -387,11 +395,24 @@ class StreamProcessor:
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
"""Return True if a real document source is configured for retrieval."""
active_docs = self.source.get("active_docs") if self.source else None
if not active_docs:
return False
if active_docs == "default":
return False
return True
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
@@ -433,48 +454,39 @@ class StreamProcessor:
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
data_key = self._get_data_from_api_key(effective_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self._agent_data = self._get_data_from_api_key(effective_key)
if self._agent_data.get("_id"):
self.agent_id = str(self._agent_data.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"prompt_id": self._agent_data.get("prompt_id", "default"),
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"models": data_key.get("models", []),
"json_schema": self._agent_data.get("json_schema"),
"default_model_id": self._agent_data.get("default_model_id", ""),
"models": self._agent_data.get("models", []),
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
self.initial_user_id = self._agent_data.get("user")
self.decoded_token = {"sub": self._agent_data.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": data_key.get("user")}
self.decoded_token = {"sub": self._agent_data.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
if self._agent_data.get("workflow"):
self.agent_config["workflow"] = self._agent_data["workflow"]
self.agent_config["workflow_owner"] = self._agent_data.get("user")
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
@@ -497,14 +509,45 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
try:
chunks = int(self._agent_data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -528,6 +571,9 @@ class StreamProcessor:
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
if not self._has_active_docs():
logger.info("Pre-fetch skipped: no active docs configured")
return None, None
try:
retriever = self.create_retriever()
logger.info(
@@ -771,6 +817,121 @@ class StreamProcessor:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def resume_from_tool_actions(
self,
tool_actions: list,
conversation_id: str,
):
"""Resume a paused agent from saved continuation state.
Loads the pending state from MongoDB, recreates the agent with
the saved configuration, and returns an agent ready to call
``gen_continuation()``.
Args:
tool_actions: Client-provided actions (approvals / results).
conversation_id: The conversation being resumed.
Returns:
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
"""
from application.api.answer.services.continuation_service import (
ContinuationService,
)
from application.agents.agent_creator import AgentCreator
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
cont_service = ContinuationService()
state = cont_service.load_state(conversation_id, self.initial_user_id)
if not state:
raise ValueError("No pending tool state found for this conversation")
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
tool_schemas = state.get("tool_schemas", [])
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
agent_id = agent_config.get("agent_id")
prompt = agent_config.get("prompt", "")
json_schema = agent_config.get("json_schema")
retriever_config = agent_config.get("retriever_config")
# Recreate dependencies
system_api_key = api_key or get_api_key_for_provider(llm_name)
llm = LLMCreator.create_llm(
llm_name,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.initial_user_id,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = conversation_id
# Restore client tools so they stay available for subsequent LLM calls
saved_client_tools = state.get("client_tools")
if saved_client_tools:
tool_executor.client_tools = saved_client_tools
# Re-merge into tools_dict (they may have been stripped during serialization)
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
agent_type = agent_config.get("agent_type", "ClassicAgent")
# Map class names back to agent creator keys
type_map = {
"ClassicAgent": "classic",
"AgenticAgent": "agentic",
"ResearchAgent": "research",
"WorkflowAgent": "workflow",
}
agent_key = type_map.get(agent_type, "classic")
agent_kwargs = {
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
"prompt": prompt,
"chat_history": [],
"decoded_token": self.decoded_token,
"json_schema": json_schema,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
if agent_key in ("agentic", "research") and retriever_config:
agent_kwargs["retriever_config"] = retriever_config
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
agent.conversation_id = conversation_id
agent.initial_user_id = self.initial_user_id
agent.tools = tool_schemas
# Store config for the route layer
self.model_id = model_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
self,
docs_together: Optional[str] = None,
@@ -795,15 +956,23 @@ class StreamProcessor:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
# Allow API callers to override the system prompt when the agent
# has opted in via allow_system_prompt_override.
if (
self.agent_config.get("allow_system_prompt_override", False)
and self.data.get("system_prompt_override")
):
rendered_prompt = self.data["system_prompt_override"]
else:
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
@@ -841,6 +1010,10 @@ class StreamProcessor:
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools so they get merged in get_tools()
client_tools = self.data.get("client_tools")
if client_tools:
tool_executor.client_tools = client_tools
# Base agent kwargs
agent_kwargs = {

View File

@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])

View File

@@ -13,6 +13,8 @@ from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
@@ -83,6 +85,10 @@ class AgentFolders(Resource):
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
dual_write(
AgentFoldersRepository,
lambda repo, u=user, n=data["name"]: repo.create(u, n),
)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
@@ -167,6 +173,10 @@ class AgentFolder(Resource):
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
dual_write(
AgentFoldersRepository,
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
)
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)

View File

@@ -23,6 +23,9 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.users import UsersRepository
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
@@ -73,6 +76,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -96,6 +100,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -220,6 +225,12 @@ def build_agent_document(
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
if "allow_system_prompt_override" in allowed_fields:
base_doc["allow_system_prompt_override"] = (
data.get("allow_system_prompt_override") == "True"
if isinstance(data.get("allow_system_prompt_override"), str)
else bool(data.get("allow_system_prompt_override", False))
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@@ -292,6 +303,9 @@ class GetAgent(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -373,6 +387,9 @@ class GetAgents(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
for agent in agents
if "source" in agent
@@ -450,6 +467,10 @@ class CreateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -491,9 +512,9 @@ class CreateAgent(Resource):
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -603,6 +624,17 @@ class CreateAgent(Resource):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
dual_write(
AgentsRepository,
lambda repo, u=user, a=new_agent: repo.create(
u, a.get("name", ""), a.get("status", "draft"),
key=a.get("key"), description=a.get("description"),
retriever=a.get("retriever"), chunks=a.get("chunks"),
tools=a.get("tools"), models=a.get("models"),
shared=a.get("shared", False),
incoming_webhook_token=a.get("incoming_webhook_token"),
),
)
except Exception as err:
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -674,6 +706,10 @@ class UpdateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -765,6 +801,7 @@ class UpdateAgent(Resource):
"default_model_id",
"folder_id",
"workflow",
"allow_system_prompt_override",
]
for field in allowed_fields:
@@ -872,9 +909,9 @@ class UpdateAgent(Resource):
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
else:
@@ -983,6 +1020,13 @@ class UpdateAgent(Resource):
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
elif field == "allow_system_prompt_override":
raw_value = data.get("allow_system_prompt_override", False)
update_fields[field] = (
raw_value == "True"
if isinstance(raw_value, str)
else bool(raw_value)
)
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -1153,6 +1197,10 @@ class DeleteAgent(Resource):
deleted_agent = agents_collection.find_one_and_delete(
{"_id": ObjectId(agent_id), "user": user}
)
dual_write(
AgentsRepository,
lambda repo, aid=agent_id, u=user: repo.delete(aid, u),
)
if not deleted_agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
@@ -1220,6 +1268,9 @@ class PinnedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids)
)
list_pinned_agents = [
{
"id": str(agent["_id"]),
@@ -1351,12 +1402,18 @@ class PinAgent(Resource):
{"user_id": user_id},
{"$pull": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid)
)
action = "unpinned"
else:
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid)
)
action = "pinned"
except Exception as err:
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
@@ -1402,6 +1459,9 @@ class RemoveSharedAgent(Resource):
}
},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid)
)
return make_response(jsonify({"success": True, "action": "removed"}), 200)
except Exception as err:

View File

@@ -18,6 +18,8 @@ from application.api.user.base import (
user_tools_collection,
users_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
@@ -105,6 +107,9 @@ class SharedAgent(Resource):
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -139,6 +144,9 @@ class SharedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [

View File

@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
@@ -629,6 +633,10 @@ class ServeImage(Resource):
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(

View File

@@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
@@ -132,6 +134,9 @@ def ensure_user_doc(user_id):
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
return user_doc

View File

@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir, prompts_collection
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.prompts import PromptsRepository
from application.utils import check_required_fields
prompts_ns = Namespace(
@@ -49,6 +51,10 @@ class CreatePrompt(Resource):
}
)
new_id = str(resp.inserted_id)
dual_write(
PromptsRepository,
lambda repo, u=user, n=data["name"], c=data["content"]: repo.create(u, n, c),
)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -149,6 +155,10 @@ class DeletePrompt(Resource):
return missing_fields
try:
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user: repo.delete(pid, u),
)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -185,6 +195,10 @@ class UpdatePrompt(Resource):
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update(pid, u, n, c),
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -57,7 +57,7 @@ class ShareConversation(Resource):
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
{"_id": ObjectId(conversation_id), "user": user}
)
if conversation is None:
return make_response(

View File

@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid file path",
}
),
400,
)
full_path = f"{source_file_path}/{file_path}"
# Remove from storage

View File

@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields
@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
return auth_credentials
def _validate_mcp_server_url(config: dict) -> None:
"""Validate the server_url in an MCP config to prevent SSRF.
Raises:
ValueError: If the URL is missing or points to a blocked address.
"""
server_url = (config.get("server_url") or "").strip()
if not server_url:
raise ValueError("server_url is required")
try:
validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid server URL: {exc}") from exc
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
safe_result = {
k: v
for k, v in result.items()
if k in ("success", "requires_oauth", "auth_url")
}
return make_response(jsonify(safe_result), 200)
if not result.get("success") and "message" in result:
if not result.get("success"):
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
result["message"] = "Connection test failed"
return make_response(
jsonify(
{
"success": False,
"message": "Connection test failed",
"tools_count": 0,
}
),
200,
)
return make_response(jsonify(result), 200)
safe_result = {
"success": True,
"message": result.get("message", "Connection successful"),
"tools_count": result.get("tools_count", 0),
"tools": result.get("tools", []),
}
return make_response(jsonify(safe_result), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server test request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
mcp_config = config.copy()
@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(

View File

@@ -8,6 +8,9 @@ from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
@@ -130,6 +133,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
if not request.decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
@@ -236,6 +241,16 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
try:
if data["name"] == "mcp_tool":
server_url = (data.get("config", {}).get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
@@ -281,6 +296,13 @@ class CreateTool(Resource):
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
dual_write(
UserToolsRepository,
lambda repo, u=user, t=new_tool: repo.create(
u, t["name"], config=t.get("config"),
custom_name=t.get("customName"), display_name=t.get("displayName"),
),
)
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -421,6 +443,16 @@ class UpdateToolConfig(Resource):
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
@@ -558,6 +590,10 @@ class DeleteTool(Resource):
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
dual_write(
UserToolsRepository,
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404

View File

@@ -0,0 +1,3 @@
from application.api.v1.routes import v1_bp
__all__ = ["v1_bp"]

View File

@@ -0,0 +1,333 @@
"""Standard chat completions API routes.
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
follow the widely-adopted chat completions protocol so external tools
(opencode, continue, etc.) can connect to DocsGPT agents.
"""
import json
import logging
import time
import traceback
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
from application.api.answer.routes.base import BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
from application.api.v1.translator import (
translate_request,
translate_response,
translate_stream_event,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
def _extract_bearer_token() -> Optional[str]:
"""Extract API key from Authorization: Bearer header."""
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
return auth[7:].strip()
return None
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
return db["agents"].find_one({"key": api_key})
except Exception:
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
"""Return agent name for display as model name."""
if agent:
return agent.get("name", api_key)
return api_key
class _V1AnswerHelper(BaseAnswerResource):
"""Thin wrapper to access complete_stream / process_response_stream."""
pass
@v1_bp.route("/chat/completions", methods=["POST"])
def chat_completions():
"""Handle POST /v1/chat/completions."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
data = request.get_json()
if not data or not data.get("messages"):
return make_response(
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
400,
)
is_stream = data.get("stream", False)
agent_doc = _lookup_agent(api_key)
model_name = _get_model_name(agent_doc, api_key)
try:
internal_data = translate_request(data, api_key)
except Exception as e:
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity.
agent_user = agent_doc.get("user") if agent_doc else None
decoded_token = {"sub": agent_user or "api_key_user"}
try:
processor = StreamProcessor(internal_data, decoded_token)
if internal_data.get("tool_actions"):
# Continuation mode
conversation_id = internal_data.get("conversation_id")
if not conversation_id:
return make_response(
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
400,
)
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
internal_data["tool_actions"], conversation_id
)
continuation = {
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
}
question = ""
else:
# Normal mode
question = internal_data.get("question", "")
agent = processor.build_agent(question)
continuation = None
if not processor.decoded_token:
return make_response(
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
401,
)
helper = _V1AnswerHelper()
usage_error = helper.check_usage(processor.agent_config)
if usage_error:
return usage_error
should_save_conversation = bool(internal_data.get("save_conversation", False))
if is_stream:
return Response(
_stream_response(
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
else:
return _non_stream_response(
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
)
except ValueError as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
except Exception as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)
def _stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Generator[str, None, None]:
"""Generate translated SSE chunks for streaming response."""
completion_id = f"chatcmpl-{int(time.time())}"
internal_stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
for line in internal_stream:
if not line.strip():
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")
if conv_id:
completion_id = f"chatcmpl-{conv_id}"
# Translate to standard format
translated = translate_stream_event(event_data, completion_id, model_name)
for chunk in translated:
yield chunk
def _non_stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Response:
"""Collect full response and return as single JSON."""
stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
result = helper.process_response_stream(stream)
if result["error"]:
return make_response(
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
500,
)
extra = result.get("extra")
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
response = translate_response(
conversation_id=result["conversation_id"],
answer=result["answer"] or "",
sources=result["sources"],
tool_calls=result["tool_calls"],
thought=result["thought"] or "",
model_name=model_name,
pending_tool_calls=pending,
)
return make_response(jsonify(response), 200)
@v1_bp.route("/models", methods=["GET"])
def list_models():
"""Handle GET /v1/models — return agents as models."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
# Find the agent for this api_key
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)
user = agent.get("user")
# Return all agents belonging to this user
user_agents = list(agents_collection.find({"user": user}))
models = []
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
model_id = str(ag.get("_id") or ag.get("id") or "")
models.append({
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": ag.get("name", ""),
"description": ag.get("description", ""),
})
return make_response(
jsonify({"object": "list", "data": models}),
200,
)
except Exception as e:
logger.error(f"/v1/models error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)

View File

@@ -0,0 +1,433 @@
"""Translate between standard chat completions format and DocsGPT internals.
This module handles:
- Request translation (chat completions -> DocsGPT internal format)
- Response translation (DocsGPT response -> chat completions format)
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
"""
import json
import time
from typing import Any, Dict, List, Optional
def _get_client_tool_name(tc: Dict) -> str:
"""Return the original tool name for client-facing responses.
For client-side tools the ``tool_name`` field carries the name the
client originally registered. Fall back to ``action_name`` (which
is now the clean LLM-visible name) or ``name``.
"""
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
# ---------------------------------------------------------------------------
# Request translation
# ---------------------------------------------------------------------------
def is_continuation(messages: List[Dict]) -> bool:
"""Check if messages represent a tool-call continuation.
A continuation is detected when the last message(s) have ``role: "tool"``
immediately after an assistant message with ``tool_calls``.
"""
if not messages:
return False
# Walk backwards: if we see tool messages before hitting a non-tool, non-assistant message
# and there's an assistant message with tool_calls, it's a continuation.
i = len(messages) - 1
while i >= 0 and messages[i].get("role") == "tool":
i -= 1
if i < 0:
return False
return (
messages[i].get("role") == "assistant"
and bool(messages[i].get("tool_calls"))
)
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
"""Extract tool results from trailing tool messages for continuation.
Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``.
"""
results = []
for msg in reversed(messages):
if msg.get("role") != "tool":
break
call_id = msg.get("tool_call_id", "")
content = msg.get("content", "")
if isinstance(content, str):
try:
content = json.loads(content)
except (json.JSONDecodeError, TypeError):
pass
results.append({"call_id": call_id, "result": content})
results.reverse()
return results
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
"""Try to extract conversation_id from the assistant message before tool results.
The conversation_id may be stored in a custom field on the assistant message
from a previous response cycle.
"""
for msg in reversed(messages):
if msg.get("role") == "assistant":
# Check docsgpt extension
return msg.get("docsgpt", {}).get("conversation_id")
return None
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
"""Extract the first system message content from the messages array.
Returns None if no system message is present.
"""
for msg in messages:
if msg.get("role") == "system":
return msg.get("content", "")
return None
def convert_history(messages: List[Dict]) -> List[Dict]:
"""Convert chat completions messages array to DocsGPT history format.
DocsGPT history is a list of ``{prompt, response}`` dicts.
Excludes the last user message (that becomes the ``question``).
"""
history = []
i = 0
while i < len(messages):
msg = messages[i]
if msg.get("role") == "system":
i += 1
continue
if msg.get("role") == "user":
# Look ahead for assistant response
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
content = messages[i + 1].get("content") or ""
history.append({
"prompt": msg.get("content", ""),
"response": content,
})
i += 2
continue
# Last user message without response — skip (it's the question)
i += 1
continue
i += 1
return history
def translate_request(
data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""Translate a chat completions request to DocsGPT internal format.
Args:
data: The incoming request body.
api_key: Agent API key from the Authorization header.
Returns:
Dict suitable for passing to ``StreamProcessor``.
"""
messages = data.get("messages", [])
# Check for continuation (tool results after assistant tool_calls)
if is_continuation(messages):
tool_actions = extract_tool_results(messages)
conversation_id = extract_conversation_id(messages)
if not conversation_id:
conversation_id = data.get("conversation_id")
result = {
"conversation_id": conversation_id,
"tool_actions": tool_actions,
"api_key": api_key,
}
# Carry tools forward for next iteration
if data.get("tools"):
result["client_tools"] = data["tools"]
return result
# Normal request — extract question from last user message
question = ""
for msg in reversed(messages):
if msg.get("role") == "user":
question = msg.get("content", "")
break
history = convert_history(messages)
system_prompt_override = extract_system_prompt(messages)
docsgpt = data.get("docsgpt", {})
result = {
"question": question,
"api_key": api_key,
"history": json.dumps(history),
# Conversations are NOT persisted by default on the v1 endpoint.
# Callers opt in via ``docsgpt.save_conversation: true``.
"save_conversation": bool(docsgpt.get("save_conversation", False)),
}
if system_prompt_override is not None:
result["system_prompt_override"] = system_prompt_override
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]
# DocsGPT extensions
if docsgpt.get("attachments"):
result["attachments"] = docsgpt["attachments"]
return result
# ---------------------------------------------------------------------------
# Response translation (non-streaming)
# ---------------------------------------------------------------------------
def translate_response(
conversation_id: str,
answer: str,
sources: Optional[List[Dict]],
tool_calls: Optional[List[Dict]],
thought: str,
model_name: str,
pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
"""Translate DocsGPT response to chat completions format.
Args:
conversation_id: The DocsGPT conversation ID.
answer: The assistant's text response.
sources: RAG retrieval sources.
tool_calls: Completed tool call results.
thought: Reasoning/thinking tokens.
model_name: Model/agent identifier.
pending_tool_calls: Pending client-side tool calls (if paused).
Returns:
Dict in the standard chat completions response format.
"""
created = int(time.time())
completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
# Build message
message: Dict[str, Any] = {"role": "assistant"}
if pending_tool_calls:
# Tool calls pending — return them for client execution
message["content"] = None
message["tool_calls"] = [
{
"id": tc.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc),
"arguments": (
json.dumps(tc["arguments"])
if isinstance(tc.get("arguments"), dict)
else tc.get("arguments", "{}")
),
},
}
for tc in pending_tool_calls
]
finish_reason = "tool_calls"
else:
message["content"] = answer
if thought:
message["reasoning_content"] = thought
finish_reason = "stop"
result: Dict[str, Any] = {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": model_name,
"choices": [
{
"index": 0,
"message": message,
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
# DocsGPT extensions
docsgpt: Dict[str, Any] = {}
if conversation_id:
docsgpt["conversation_id"] = conversation_id
if sources:
docsgpt["sources"] = sources
if tool_calls:
docsgpt["tool_calls"] = tool_calls
if docsgpt:
result["docsgpt"] = docsgpt
return result
# ---------------------------------------------------------------------------
# Streaming event translation
# ---------------------------------------------------------------------------
def _make_chunk(
completion_id: str,
model_name: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> str:
"""Build a single SSE chunk in the standard streaming format."""
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(chunk)}\n\n"
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
"""Build a DocsGPT extension SSE chunk."""
return f"data: {json.dumps({'docsgpt': data})}\n\n"
def translate_stream_event(
event_data: Dict[str, Any],
completion_id: str,
model_name: str,
) -> List[str]:
"""Translate a DocsGPT SSE event dict to standard streaming chunks.
May return 0, 1, or 2 chunks per input event. For example, a completed
tool call produces both a docsgpt extension chunk and nothing on the
standard side (since server-side tool calls aren't surfaced in standard
format).
Args:
event_data: Parsed DocsGPT event dict.
completion_id: The completion ID for this response.
model_name: Model/agent identifier.
Returns:
List of SSE-formatted strings to send to the client.
"""
event_type = event_data.get("type")
chunks: List[str] = []
if event_type == "answer":
chunks.append(
_make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
)
elif event_type == "thought":
chunks.append(
_make_chunk(
completion_id, model_name,
{"reasoning_content": event_data.get("thought", "")},
)
)
elif event_type == "source":
chunks.append(
_make_docsgpt_chunk({
"type": "source",
"sources": event_data.get("source", []),
})
)
elif event_type == "tool_call":
tc_data = event_data.get("data", {})
status = tc_data.get("status")
if status == "requires_client_execution":
# Standard: stream as tool_calls delta
args = tc_data.get("arguments", {})
args_str = json.dumps(args) if isinstance(args, dict) else str(args)
chunks.append(
_make_chunk(completion_id, model_name, {
"tool_calls": [{
"index": 0,
"id": tc_data.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc_data),
"arguments": args_str,
},
}],
})
)
elif status == "awaiting_approval":
# Extension: approval needed
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif status in ("completed", "pending", "error", "denied", "skipped"):
# Extension: tool call progress
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif event_type == "tool_calls_pending":
# Standard: finish_reason = tool_calls
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
)
# Also emit as docsgpt extension
chunks.append(
_make_docsgpt_chunk({
"type": "tool_calls_pending",
"pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
})
)
elif event_type == "end":
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="stop")
)
chunks.append("data: [DONE]\n\n")
elif event_type == "id":
chunks.append(
_make_docsgpt_chunk({
"type": "id",
"conversation_id": event_data.get("id", ""),
})
)
elif event_type == "error":
# Emit as standard error (non-standard but widely supported)
error_data = {
"error": {
"message": event_data.get("error", "An error occurred"),
"type": "server_error",
}
}
chunks.append(f"data: {json.dumps(error_data)}\n\n")
elif event_type == "structured_answer":
chunks.append(
_make_chunk(
completion_id, model_name,
{"content": event_data.get("answer", "")},
)
)
# Skip: tool_calls (redundant), research_plan, research_progress
return chunks

View File

@@ -17,6 +17,7 @@ from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
@@ -36,6 +37,7 @@ app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,

View File

@@ -1,6 +1,6 @@
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
from celery.signals import setup_logging, worker_process_init
def make_celery(app_name=__name__):
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
setup_logging()
@worker_process_init.connect
def _dispose_db_engine_on_fork(*args, **kwargs):
"""Dispose the SQLAlchemy engine pool in each forked Celery worker.
SQLAlchemy connection pools are not fork-safe: file descriptors shared
between the parent and a forked worker will corrupt the pool. Disposing
on ``worker_process_init`` gives every worker its own fresh pool on
first use.
Imported lazily so Celery workers that don't touch Postgres (or where
``POSTGRES_URI`` is unset) don't fail at startup.
"""
try:
from application.storage.db.engine import dispose_engine
except Exception:
return
dispose_engine()
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -0,0 +1,89 @@
"""Normalize user-supplied Postgres URIs for different drivers.
DocsGPT has two Postgres connection strings pointing at potentially
different databases:
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
dialect prefix is an invalid URI from its point of view.
The two fields therefore need opposite normalization so operators don't
have to know which driver a given field feeds. Each normalizer also
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
psycopg2 is no longer in the project.
This module is deliberately separate from ``application/core/settings.py``
so the Settings class stays focused on field declarations, and the
URI-rewriting logic can be unit-tested without triggering ``.env``
file loading from importing Settings.
"""
from __future__ import annotations
def _rewrite_uri_prefixes(v, rewrites):
"""Shared URI prefix rewriter used by both normalizers below.
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
applies the first matching rewrite, and passes unrecognised input
through so downstream consumers (SQLAlchemy, libpq) can produce
their own error messages rather than us silently eating a
misconfiguration.
"""
if v is None:
return None
if not isinstance(v, str):
return v
v = v.strip()
if not v or v.lower() == "none":
return None
for prefix, target in rewrites:
if v.startswith(prefix):
return target + v[len(prefix):]
return v
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
("postgresql+psycopg2://", "postgresql+psycopg://"),
("postgresql://", "postgresql+psycopg://"),
("postgres://", "postgresql+psycopg://"),
)
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
("postgresql+psycopg2://", "postgresql://"),
("postgresql+psycopg://", "postgresql://"),
)
def normalize_postgres_uri(v):
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
Accepts the forms operators naturally write (``postgres://``,
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
Unknown schemes pass through unchanged so SQLAlchemy can produce its
own dialect-not-found error.
"""
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
def normalize_pgvector_connection_string(v):
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
Strips the SQLAlchemy dialect prefix if the operator accidentally
copied their POSTGRES_URI value here — libpq can't parse it.
User-friendly forms (``postgres://``, ``postgresql://``) pass
through unchanged since libpq accepts them natively.
"""
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)

View File

@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from application.core.db_uri import ( # noqa: E402
normalize_pgvector_connection_string,
normalize_postgres_uri,
)
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
@@ -22,6 +28,11 @@ class Settings(BaseSettings):
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
# User-data Postgres DB.
POSTGRES_URI: Optional[str] = None
# MongoDB→Postgres migration: dual-write to Postgres (Mongo stays source of truth)
USE_POSTGRES: bool = False
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
@@ -59,6 +70,10 @@ class Settings(BaseSettings):
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# Confluence Cloud integration
CONFLUENCE_CLIENT_ID: Optional[str] = None
CONFLUENCE_CLIENT_SECRET: Optional[str] = None
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -117,7 +132,10 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
# PGVector vectorstore config. Write the URI in whichever form you
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
# dialect form (``postgresql+psycopg://``) are all accepted and
# normalized internally for ``psycopg.connect()``.
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
@@ -156,6 +174,16 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):
return normalize_postgres_uri(v)
@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
@classmethod
def _normalize_pgvector_connection_string_validator(cls, v):
return normalize_pgvector_connection_string(v)
@field_validator(
"API_KEY",
"OPENAI_API_KEY",

View File

@@ -167,6 +167,8 @@ class GoogleLLM(BaseLLM):
return "\n".join(parts)
return ""
import json as _json
for message in messages:
role = message.get("role")
content = message.get("content")
@@ -180,9 +182,66 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
# Standard format: assistant message with tool_calls array
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls and role == "model":
for tc in msg_tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
try:
args = _json.loads(args)
except (_json.JSONDecodeError, TypeError):
args = {}
cleaned_args = self._remove_null_values(args)
thought_sig = tc.get("thought_signature")
if thought_sig:
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=func.get("name", ""),
args=cleaned_args,
),
thoughtSignature=thought_sig,
)
)
else:
parts.append(
types.Part.from_function_call(
name=func.get("name", ""),
args=cleaned_args,
)
)
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
continue
# Standard format: tool message with tool_call_id
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
result_content = content
if isinstance(result_content, str):
try:
result_content = _json.loads(result_content)
except (_json.JSONDecodeError, TypeError):
pass
# Google expects function_response name — extract from tool_call_id context
# We use a placeholder name since Google API doesn't require exact match
parts.append(
types.Part.from_function_response(
name="tool_result",
response={"result": result_content},
)
)
cleaned_messages.append(types.Content(role="model", parts=parts))
continue
if role == "tool":
role = "model"
if role and content is not None:
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
@@ -191,15 +250,11 @@ class GoogleLLM(BaseLLM):
if "text" in item:
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Remove null values from args to avoid API errors
# Legacy format support
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
@@ -210,7 +265,6 @@ class GoogleLLM(BaseLLM):
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],

View File

@@ -1,3 +1,4 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
@@ -315,10 +316,34 @@ class LLMHandler(ABC):
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
# Standard format: tool_calls array on assistant message
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls:
for tc in msg_tool_calls:
call_id = tc.get("id") or str(uuid.uuid4())
func = tc.get("function", {})
args = func.get("arguments")
if isinstance(args, str):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": func.get("name"),
"arguments": args,
"result": None,
"status": "called",
"call_id": call_id,
}
continue
# Legacy format: function_call/function_response in content list
if isinstance(content, list):
has_fc = False
for item in content:
if "function_call" in item:
has_fc = True
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
@@ -329,37 +354,30 @@ class LLMHandler(ABC):
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
if has_fc:
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
# Standard format: tool_call_id on tool message
call_id = message.get("tool_call_id")
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
# Legacy: function_response in content list
elif isinstance(content, list):
for item in content:
if "function_response" in item:
legacy_id = item["function_response"].get("call_id")
if legacy_id and legacy_id in current_tool_calls:
current_tool_calls[legacy_id]["result"] = tool_text
current_tool_calls[legacy_id]["status"] = "completed"
break
elif call_id is None and queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
@@ -648,6 +666,13 @@ class LLMHandler(ABC):
"""
Execute tool calls and update conversation history.
When a tool requires approval or client-side execution, it is
collected as a pending action instead of being executed. The
generator returns ``(updated_messages, pending_actions)`` where
*pending_actions* is ``None`` when every tool was executed
normally, or a list of dicts describing actions the client must
resolve before the LLM loop can continue.
Args:
agent: The agent instance
tool_calls: List of tool calls to execute
@@ -655,9 +680,11 @@ class LLMHandler(ABC):
messages: Current conversation history
Returns:
Updated messages list
Tuple of (updated_messages, pending_actions).
pending_actions is None if all tools executed, otherwise a list.
"""
updated_messages = messages.copy()
pending_actions: List[Dict] = []
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
@@ -763,6 +790,29 @@ class LLMHandler(ABC):
# Set flag on agent
agent.context_limit_reached = True
break
# ---- Pause check: approval / client-side execution ----
llm_class = agent.llm.__class__.__name__
pause_info = agent.tool_executor.check_pause(
tools_dict, call, llm_class
)
if pause_info:
# Yield pause event so the client knows this tool is waiting
yield {
"type": "tool_call",
"data": {
"tool_name": pause_info["tool_name"],
"call_id": pause_info["call_id"],
"action_name": pause_info.get("llm_name", pause_info["name"]),
"arguments": pause_info["arguments"],
"status": pause_info["pause_type"],
},
}
pending_actions.append(pause_info)
# Do NOT add messages for pending tools here.
# They will be added on resume to keep call/result pairs together.
continue
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -772,25 +822,30 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [function_call_content],
}
)
# Standard internal format: assistant message with tool_calls array
args_str = (
json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
tool_call_obj = {
"id": call_id,
"type": "function",
"function": {
"name": call.name,
"arguments": args_str,
},
}
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
tool_call_obj["thought_signature"] = call.thought_signature
updated_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call_obj],
})
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
@@ -802,16 +857,15 @@ class LLMHandler(ABC):
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
mapping = agent.tool_executor._name_to_tool
if call.name in mapping:
resolved_tool_id, _ = mapping[call.name]
tool_name = tools_dict.get(resolved_tool_id, {}).get(
"name", "unknown_tool"
)
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
@@ -823,7 +877,7 @@ class LLMHandler(ABC):
"status": "error",
},
}
return updated_messages
return updated_messages, pending_actions if pending_actions else None
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
@@ -851,8 +905,22 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return ""
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
)
@@ -913,10 +981,23 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
tool_calls = {}
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit

View File

@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
"""Create a tool result message in the standard internal format."""
import json as _json
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
return {
"role": "model",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
"role": "tool",
"tool_call_id": tool_call.id,
"content": content,
}
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create OpenAI-style tool message."""
"""Create a tool result message in the standard internal format."""
import json as _json
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
return {
"role": "tool",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
"call_id": tool_call.id,
}
}
],
"tool_call_id": tool_call.id,
"content": content,
}
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -91,16 +91,52 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
# Standard format: assistant message with tool_calls (passthrough)
tool_calls = message.get("tool_calls")
if tool_calls and role == "assistant":
cleaned_tcs = []
for tc in tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, dict):
args = json.dumps(self._remove_null_values(args))
elif isinstance(args, str):
try:
parsed = json.loads(args)
args = json.dumps(self._remove_null_values(parsed))
except (json.JSONDecodeError, TypeError):
pass
cleaned_tcs.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": func.get("name", ""), "arguments": args},
})
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": cleaned_tcs,
})
continue
# Standard format: tool message with tool_call_id (passthrough)
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
cleaned_messages.append({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content if isinstance(content, str) else json.dumps(content),
})
continue
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
# Legacy format support: function_call / function_response
if "function_call" in item:
# Function calls need their own message
args = item["function_call"]["args"]
if isinstance(args, str):
try:
@@ -116,28 +152,20 @@ class OpenAILLM(BaseLLM):
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
})
elif "function_response" in item:
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
cleaned_messages.append({
"role": "tool",
"tool_call_id": item["function_response"]["call_id"],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
})
elif isinstance(item, dict):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
@@ -145,10 +173,7 @@ class OpenAILLM(BaseLLM):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
else:

View File

@@ -157,5 +157,21 @@ def _log_to_mongodb(
user_logs_collection.insert_one(log_entry)
logging.debug(f"Logged activity to MongoDB: {activity_id}")
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.stack_logs import StackLogsRepository
dual_write(
StackLogsRepository,
lambda repo, e=log_entry: repo.insert(
activity_id=e["id"],
endpoint=e.get("endpoint"),
level=e.get("level"),
user_id=e.get("user"),
api_key=e.get("api_key"),
query=e.get("query"),
stacks=e.get("stacks"),
),
)
except Exception as e:
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)

View File

@@ -0,0 +1,4 @@
from .auth import ConfluenceAuth
from .loader import ConfluenceLoader
__all__ = ["ConfluenceAuth", "ConfluenceLoader"]

View File

@@ -0,0 +1,216 @@
import datetime
import logging
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
logger = logging.getLogger(__name__)
class ConfluenceAuth(BaseConnectorAuth):
SCOPES = [
"read:page:confluence",
"read:space:confluence",
"read:attachment:confluence",
"read:me",
"offline_access",
]
AUTH_URL = "https://auth.atlassian.com/authorize"
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
ME_URL = "https://api.atlassian.com/me"
def __init__(self):
self.client_id = settings.CONFLUENCE_CLIENT_ID
self.client_secret = settings.CONFLUENCE_CLIENT_SECRET
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
if not self.client_id or not self.client_secret:
raise ValueError(
"Confluence OAuth credentials not configured. "
"Please set CONFLUENCE_CLIENT_ID and CONFLUENCE_CLIENT_SECRET in settings."
)
def get_authorization_url(self, state: Optional[str] = None) -> str:
params = {
"audience": "api.atlassian.com",
"client_id": self.client_id,
"scope": " ".join(self.SCOPES),
"redirect_uri": self.redirect_uri,
"state": state,
"response_type": "code",
"prompt": "consent",
}
return f"{self.AUTH_URL}?{urlencode(params)}"
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
if not authorization_code:
raise ValueError("Authorization code is required")
response = requests.post(
self.TOKEN_URL,
json={
"grant_type": "authorization_code",
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": authorization_code,
"redirect_uri": self.redirect_uri,
},
headers={"Content-Type": "application/json"},
timeout=30,
)
response.raise_for_status()
token_data = response.json()
access_token = token_data.get("access_token")
if not access_token:
raise ValueError("OAuth flow did not return an access token")
refresh_token = token_data.get("refresh_token")
if not refresh_token:
raise ValueError("OAuth flow did not return a refresh token")
expires_in = token_data.get("expires_in", 3600)
expiry = (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expires_in)
).isoformat()
cloud_id = self._fetch_cloud_id(access_token)
user_info = self._fetch_user_info(access_token)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_uri": self.TOKEN_URL,
"scopes": self.SCOPES,
"expiry": expiry,
"cloud_id": cloud_id,
"user_info": {
"name": user_info.get("display_name", ""),
"email": user_info.get("email", ""),
},
}
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
if not refresh_token:
raise ValueError("Refresh token is required")
response = requests.post(
self.TOKEN_URL,
json={
"grant_type": "refresh_token",
"client_id": self.client_id,
"client_secret": self.client_secret,
"refresh_token": refresh_token,
},
headers={"Content-Type": "application/json"},
timeout=30,
)
response.raise_for_status()
token_data = response.json()
access_token = token_data.get("access_token")
new_refresh_token = token_data.get("refresh_token", refresh_token)
expires_in = token_data.get("expires_in", 3600)
expiry = (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expires_in)
).isoformat()
cloud_id = self._fetch_cloud_id(access_token)
return {
"access_token": access_token,
"refresh_token": new_refresh_token,
"token_uri": self.TOKEN_URL,
"scopes": self.SCOPES,
"expiry": expiry,
"cloud_id": cloud_id,
}
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
if not token_info:
return True
expiry = token_info.get("expiry")
if not expiry:
return bool(token_info.get("access_token"))
try:
expiry_dt = datetime.datetime.fromisoformat(expiry)
now = datetime.datetime.now(datetime.timezone.utc)
return now >= expiry_dt - datetime.timedelta(seconds=60)
except Exception:
return True
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
from application.core.mongo_db import MongoDB
from application.core.settings import settings as app_settings
mongo = MongoDB.get_client()
db = mongo[app_settings.MONGO_DB_NAME]
session = db["connector_sessions"].find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
token_info = session.get("token_info")
if not token_info:
raise ValueError("Session missing token information")
required = ["access_token", "refresh_token", "cloud_id"]
missing = [f for f in required if not token_info.get(f)]
if missing:
raise ValueError(f"Missing required token fields: {missing}")
return token_info
def sanitize_token_info(
self, token_info: Dict[str, Any], **extra_fields
) -> Dict[str, Any]:
return super().sanitize_token_info(
token_info,
cloud_id=token_info.get("cloud_id"),
**extra_fields,
)
def _fetch_cloud_id(self, access_token: str) -> str:
response = requests.get(
self.RESOURCES_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
timeout=30,
)
response.raise_for_status()
resources = response.json()
if not resources:
raise ValueError("No accessible Confluence sites found for this account")
return resources[0]["id"]
def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
try:
response = requests.get(
self.ME_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
timeout=30,
)
response.raise_for_status()
return response.json()
except Exception as e:
logger.warning("Could not fetch user info: %s", e)
return {}

View File

@@ -0,0 +1,416 @@
import functools
import logging
import os
from typing import Any, Dict, List, Optional
import requests
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.confluence.auth import ConfluenceAuth
from application.parser.schema.base import Document
logger = logging.getLogger(__name__)
API_V2 = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki/api/v2"
DOWNLOAD_BASE = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki"
SUPPORTED_ATTACHMENT_TYPES = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/msword": ".doc",
"application/vnd.ms-powerpoint": ".ppt",
"application/vnd.ms-excel": ".xls",
"text/plain": ".txt",
"text/csv": ".csv",
"text/html": ".html",
"text/markdown": ".md",
"application/json": ".json",
"application/epub+zip": ".epub",
"image/jpeg": ".jpg",
"image/png": ".png",
}
def _retry_on_auth_failure(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in (401, 403):
logger.info(
"Auth failure in %s, refreshing token and retrying", func.__name__
)
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info["access_token"]
self.refresh_token = new_token_info.get(
"refresh_token", self.refresh_token
)
self._persist_refreshed_tokens(new_token_info)
except Exception as refresh_err:
raise ValueError(
f"Authentication failed and could not be refreshed: {refresh_err}"
) from e
return func(self, *args, **kwargs)
raise
return wrapper
class ConfluenceLoader(BaseConnectorLoader):
def __init__(self, session_token: str):
self.auth = ConfluenceAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.access_token = token_info["access_token"]
self.refresh_token = token_info["refresh_token"]
self.cloud_id = token_info["cloud_id"]
self.base_url = API_V2.format(cloud_id=self.cloud_id)
self.download_base = DOWNLOAD_BASE.format(cloud_id=self.cloud_id)
self.next_page_token = None
def _headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self.access_token}",
"Accept": "application/json",
}
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings as app_settings
sanitized = self.auth.sanitize_token_info(token_info)
mongo = MongoDB.get_client()
db = mongo[app_settings.MONGO_DB_NAME]
db["connector_sessions"].update_one(
{"session_token": self.session_token},
{"$set": {"token_info": sanitized}},
)
except Exception as e:
logger.warning("Failed to persist refreshed tokens: %s", e)
@_retry_on_auth_failure
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
folder_id = inputs.get("folder_id")
file_ids = inputs.get("file_ids", [])
limit = inputs.get("limit", 100)
list_only = inputs.get("list_only", False)
page_token = inputs.get("page_token")
search_query = inputs.get("search_query")
self.next_page_token = None
if file_ids:
return self._load_pages_by_ids(file_ids, list_only, search_query)
if folder_id:
return self._list_pages_in_space(
folder_id, limit, list_only, page_token, search_query
)
return self._list_spaces(limit, page_token, search_query)
@_retry_on_auth_failure
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
config = source_config or getattr(self, "config", {})
file_ids = config.get("file_ids", [])
folder_ids = config.get("folder_ids", [])
files_downloaded = 0
os.makedirs(local_dir, exist_ok=True)
if isinstance(file_ids, str):
file_ids = [file_ids]
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for page_id in file_ids:
if self._download_page(page_id, local_dir):
files_downloaded += 1
files_downloaded += self._download_page_attachments(page_id, local_dir)
for space_id in folder_ids:
files_downloaded += self._download_space(space_id, local_dir)
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "confluence",
"config_used": config,
}
def _list_spaces(
self, limit: int, cursor: Optional[str], search_query: Optional[str]
) -> List[Document]:
documents: List[Document] = []
params: Dict[str, Any] = {"limit": min(limit, 250)}
if cursor:
params["cursor"] = cursor
response = requests.get(
f"{self.base_url}/spaces",
headers=self._headers(),
params=params,
timeout=30,
)
response.raise_for_status()
data = response.json()
for space in data.get("results", []):
name = space.get("name", "")
if search_query and search_query.lower() not in name.lower():
continue
documents.append(
Document(
text="",
doc_id=space["id"],
extra_info={
"file_name": name,
"mime_type": "folder",
"size": None,
"created_time": space.get("createdAt"),
"modified_time": None,
"source": "confluence",
"is_folder": True,
"space_key": space.get("key"),
},
)
)
next_link = data.get("_links", {}).get("next")
self.next_page_token = self._extract_cursor(next_link)
return documents
def _list_pages_in_space(
self,
space_id: str,
limit: int,
list_only: bool,
cursor: Optional[str],
search_query: Optional[str],
) -> List[Document]:
documents: List[Document] = []
params: Dict[str, Any] = {"limit": min(limit, 250)}
if cursor:
params["cursor"] = cursor
response = requests.get(
f"{self.base_url}/spaces/{space_id}/pages",
headers=self._headers(),
params=params,
timeout=30,
)
response.raise_for_status()
data = response.json()
for page in data.get("results", []):
title = page.get("title", "")
if search_query and search_query.lower() not in title.lower():
continue
doc = self._page_to_document(
page, load_content=not list_only, space_id=space_id
)
if doc:
documents.append(doc)
next_link = data.get("_links", {}).get("next")
self.next_page_token = self._extract_cursor(next_link)
return documents
def _load_pages_by_ids(
self, page_ids: List[str], list_only: bool, search_query: Optional[str]
) -> List[Document]:
documents: List[Document] = []
for page_id in page_ids:
try:
params: Dict[str, str] = {}
if not list_only:
params["body-format"] = "storage"
response = requests.get(
f"{self.base_url}/pages/{page_id}",
headers=self._headers(),
params=params,
timeout=30,
)
response.raise_for_status()
page = response.json()
title = page.get("title", "")
if search_query and search_query.lower() not in title.lower():
continue
doc = self._page_to_document(page, load_content=not list_only)
if doc:
documents.append(doc)
except Exception as e:
logger.error("Error loading page %s: %s", page_id, e)
return documents
def _page_to_document(
self,
page: Dict[str, Any],
load_content: bool = False,
space_id: Optional[str] = None,
) -> Optional[Document]:
page_id = page.get("id")
title = page.get("title", "Unknown")
version = page.get("version", {})
modified_time = version.get("createdAt") if isinstance(version, dict) else None
created_time = page.get("createdAt")
resolved_space_id = space_id or page.get("spaceId")
text = ""
if load_content:
body = page.get("body", {})
storage = body.get("storage", {}) if isinstance(body, dict) else {}
text = storage.get("value", "") if isinstance(storage, dict) else ""
return Document(
text=text,
doc_id=str(page_id),
extra_info={
"file_name": title,
"mime_type": "text/html",
"size": len(text) if text else None,
"created_time": created_time,
"modified_time": modified_time,
"source": "confluence",
"is_folder": False,
"page_id": str(page_id),
"space_id": resolved_space_id,
"cloud_id": self.cloud_id,
},
)
def _download_page(self, page_id: str, local_dir: str) -> bool:
try:
response = requests.get(
f"{self.base_url}/pages/{page_id}",
headers=self._headers(),
params={"body-format": "storage"},
timeout=30,
)
response.raise_for_status()
page = response.json()
title = page.get("title", page_id)
safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in title)
body = page.get("body", {}).get("storage", {}).get("value", "")
file_path = os.path.join(local_dir, f"{safe_name}.html")
with open(file_path, "w", encoding="utf-8") as f:
f.write(body)
return True
except Exception as e:
logger.error("Error downloading page %s: %s", page_id, e)
return False
def _download_page_attachments(self, page_id: str, local_dir: str) -> int:
downloaded = 0
try:
cursor = None
while True:
params: Dict[str, Any] = {"limit": 100}
if cursor:
params["cursor"] = cursor
response = requests.get(
f"{self.base_url}/pages/{page_id}/attachments",
headers=self._headers(),
params=params,
timeout=30,
)
response.raise_for_status()
data = response.json()
for att in data.get("results", []):
media_type = att.get("mediaType", "")
if media_type not in SUPPORTED_ATTACHMENT_TYPES:
continue
download_link = att.get("_links", {}).get("download")
if not download_link:
continue
raw_name = att.get("title", att.get("id", "attachment"))
file_name = "".join(
c if c.isalnum() or c in " -_." else "_"
for c in os.path.basename(raw_name)
) or "attachment"
file_path = os.path.join(local_dir, file_name)
url = f"{self.download_base}{download_link}"
file_resp = requests.get(
url, headers=self._headers(), timeout=60, stream=True
)
file_resp.raise_for_status()
with open(file_path, "wb") as f:
for chunk in file_resp.iter_content(chunk_size=8192):
f.write(chunk)
downloaded += 1
next_link = data.get("_links", {}).get("next")
cursor = self._extract_cursor(next_link)
if not cursor:
break
except Exception as e:
logger.error("Error downloading attachments for page %s: %s", page_id, e)
return downloaded
def _download_space(self, space_id: str, local_dir: str) -> int:
downloaded = 0
cursor = None
while True:
params: Dict[str, Any] = {"limit": 250}
if cursor:
params["cursor"] = cursor
try:
response = requests.get(
f"{self.base_url}/spaces/{space_id}/pages",
headers=self._headers(),
params=params,
timeout=30,
)
response.raise_for_status()
data = response.json()
except Exception as e:
logger.error("Error listing pages in space %s: %s", space_id, e)
break
for page in data.get("results", []):
page_id = page.get("id")
if self._download_page(str(page_id), local_dir):
downloaded += 1
downloaded += self._download_page_attachments(str(page_id), local_dir)
next_link = data.get("_links", {}).get("next")
cursor = self._extract_cursor(next_link)
if not cursor:
break
return downloaded
@staticmethod
def _extract_cursor(next_link: Optional[str]) -> Optional[str]:
if not next_link:
return None
from urllib.parse import parse_qs, urlparse
parsed = urlparse(next_link)
cursors = parse_qs(parsed.query).get("cursor")
return cursors[0] if cursors else None

View File

@@ -1,5 +1,7 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.confluence.auth import ConfluenceAuth
from application.parser.connectors.confluence.loader import ConfluenceLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
@@ -13,11 +15,13 @@ class ConnectorCreator:
"""
connectors = {
"confluence": ConfluenceLoader,
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"confluence": ConfluenceAuth,
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}

View File

@@ -205,7 +205,7 @@ class SharePointLoader(BaseConnectorLoader):
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
response = requests.get(url, headers=self._get_headers(), params=params)
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
response.raise_for_status()
file_metadata = response.json()
@@ -236,9 +236,9 @@ class SharePointLoader(BaseConnectorLoader):
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
else:
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
response = requests.get(search_url, headers=self._get_headers(), params=params)
response = requests.get(search_url, headers=self._get_headers(), params=params, timeout=100)
else:
response = requests.get(url, headers=self._get_headers(), params=params)
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
response.raise_for_status()
@@ -307,7 +307,8 @@ class SharePointLoader(BaseConnectorLoader):
response = requests.get(
f"{self.GRAPH_API_BASE}/me/drive",
headers=self._get_headers(),
params={'$select': 'webUrl'}
params={'$select': 'webUrl'},
timeout=100,
)
response.raise_for_status()
return response.json().get('webUrl')
@@ -352,7 +353,7 @@ class SharePointLoader(BaseConnectorLoader):
headers = self._get_headers()
headers["Content-Type"] = "application/json"
response = requests.post(url, headers=headers, json=body)
response = requests.post(url, headers=headers, json=body, timeout=100)
response.raise_for_status()
results = response.json()
@@ -472,7 +473,7 @@ class SharePointLoader(BaseConnectorLoader):
try:
url = f"{self._get_item_url(file_id)}/content"
response = requests.get(url, headers=self._get_headers())
response = requests.get(url, headers=self._get_headers(), timeout=100)
response.raise_for_status()
try:
@@ -491,7 +492,7 @@ class SharePointLoader(BaseConnectorLoader):
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file'}
response = requests.get(url, headers=self._get_headers(), params=params)
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
response.raise_for_status()
metadata = response.json()
@@ -507,7 +508,7 @@ class SharePointLoader(BaseConnectorLoader):
full_path = os.path.join(local_dir, file_name)
download_url = f"{self._get_item_url(file_id)}/content"
download_response = requests.get(download_url, headers=self._get_headers())
download_response = requests.get(download_url, headers=self._get_headers(), timeout=100)
download_response.raise_for_status()
with open(full_path, 'wb') as f:
@@ -527,7 +528,7 @@ class SharePointLoader(BaseConnectorLoader):
params = {'$top': 1000}
while url:
response = requests.get(url, headers=self._get_headers(), params=params)
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
response.raise_for_status()
results = response.json()
@@ -609,7 +610,7 @@ class SharePointLoader(BaseConnectorLoader):
try:
url = self._get_item_url(folder_id)
params = {'$select': 'id,name'}
response = requests.get(url, headers=self._get_headers(), params=params)
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
response.raise_for_status()
folder_metadata = response.json()

View File

@@ -24,7 +24,7 @@ class PDFParser(BaseParser):
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
response = requests.post(doc2md_service, files=files, timeout=100)
data = response.json()["markdown"]
return data

View File

@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import ebooklib
from ebooklib import epub
from fast_ebook import epub
except ImportError:
raise ValueError("`EbookLib` is required to read Epub files.")
try:
import html2text
except ImportError:
raise ValueError("`html2text` is required to parse Epub files.")
raise ValueError("`fast-ebook` is required to read Epub files.")
text_list = []
book = epub.read_epub(file, options={"ignore_ncx": True})
# Iterate through all chapters.
for item in book.get_items():
# Chapters are typically located in epub documents items.
if item.get_type() == ebooklib.ITEM_DOCUMENT:
text_list.append(
html2text.html2text(item.get_content().decode("utf-8"))
)
text = "\n".join(text_list)
book = epub.read_epub(file)
text = book.to_markdown()
return text

View File

@@ -24,7 +24,7 @@ class ImageParser(BaseParser):
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
response = requests.post(doc2md_service, files=files, timeout=100)
data = response.json()["markdown"]
else:
data = ""

View File

@@ -77,7 +77,7 @@ class GitHubLoader(BaseRemote):
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
"""Make a request with retry logic for rate limiting"""
for attempt in range(max_retries):
response = requests.get(url, headers=self.headers)
response = requests.get(url, headers=self.headers, timeout=100)
if response.status_code == 200:
return response

View File

@@ -1,9 +1,10 @@
anthropic==0.75.0
boto3==1.42.17
alembic>=1.13,<2
anthropic==0.88.0
boto3==1.42.83
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.0
cryptography==46.0.3
celery==5.6.3
cryptography==46.0.6
dataclasses-json==0.6.7
defusedxml==0.7.1
docling>=2.16.0
@@ -11,89 +12,84 @@ rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
ddgs>=8.0.0
ebooklib==0.20
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
elevenlabs==2.27.0
Flask==3.1.2
fast-ebook
elevenlabs==2.41.0
Flask==3.1.3
faiss-cpu==1.13.2
fastmcp==2.14.1
fastmcp==3.2.0
flask-restx==1.3.2
google-genai==1.54.0
google-api-python-client==2.187.0
google-auth-httplib2==0.3.0
google-auth-oauthlib==1.2.3
google-genai==1.69.0
google-api-python-client==2.193.0
google-auth-httplib2==0.3.1
google-auth-oauthlib==1.3.1
gTTS==2.5.4
gunicorn==23.0.0
html2text==2025.4.15
javalang==0.13.0
gunicorn==25.3.0
jinja2==3.1.6
jiter==0.12.0
jmespath==1.0.1
jiter==0.13.0
jmespath==1.1.0
joblib==1.5.3
jsonpatch==1.33
jsonpointer==3.0.0
kombu==5.6.1
langchain==1.2.0
kombu==5.6.2
langchain==1.2.3
langchain-community==0.4.1
langchain-core==1.2.5
langchain-openai==1.1.6
langchain-text-splitters==1.1.0
langsmith==0.5.1
langchain-core==1.2.23
langchain-openai==1.1.12
langchain-text-splitters==1.1.1
langsmith==0.7.23
lazy-object-proxy==1.12.0
lxml==6.0.2
markupsafe==3.0.3
marshmallow>=3.18.0,<5.0.0
mpmath==1.3.0
multidict==6.7.0
msal==1.34.0
multidict==6.7.1
msal==1.35.1
mypy-extensions==1.1.0
networkx==3.6.1
numpy==2.4.0
openai==2.14.0
numpy==2.4.4
openai==2.30.0
openapi3-parser==1.1.22
orjson==3.11.5
packaging==24.2
pandas==2.3.3
orjson==3.11.7
packaging==26.0
pandas==3.0.2
openpyxl==3.1.5
pathable==0.4.4
pathable==0.5.0
pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<3.0.0
prance==25.4.8.0
portalocker>=2.7.0,<4.0.0
prompt-toolkit==3.0.52
protobuf==6.33.2
psycopg2-binary==2.9.11
protobuf==7.34.1
psycopg[binary,pool]>=3.1,<4
py==1.11.0
pydantic
pydantic-core
pydantic-settings
pymongo==4.15.5
pypdf==6.5.0
pymongo==4.16.0
pypdf==6.9.2
python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
redis==7.1.0
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2025.11.3
requests==2.32.5
regex==2026.4.4
requests==2.33.1
retry==0.9.2
sentence-transformers==5.2.0
sentence-transformers==5.3.0
sqlalchemy>=2.0,<3
tiktoken==0.12.0
tokenizers==0.22.1
torch==2.9.1
tqdm==4.67.1
transformers==4.57.3
tokenizers==0.22.2
torch==2.11.0
tqdm==4.67.3
transformers==5.4.0
typing-extensions==4.15.0
typing-inspect==0.9.0
tzdata==2025.3
urllib3==2.6.3
vine==5.1.0
wcwidth==0.2.14
wcwidth==0.6.0
werkzeug>=3.1.0
yarl==1.22.0
yarl==1.23.0
markdownify==1.2.2
tldextract==5.3.0
websockets==15.0.1
tldextract==5.3.1
websockets==16.0

View File

@@ -0,0 +1,10 @@
"""PostgreSQL storage layer for user-level data.
This package holds the SQLAlchemy Core engine, metadata, repositories, and
migration infrastructure for the user-data Postgres database. It is separate
from ``application/vectorstore/pgvector.py`` — the two may point at the same
cluster or at different clusters depending on operator configuration.
Repository modules are added in later phases
as individual collections are ported.
"""

View File

@@ -0,0 +1,39 @@
"""Common helpers shared by all repositories.
Repositories are thin wrappers around SQLAlchemy Core query construction.
They take a ``Connection`` on call and return plain ``dict`` rows during the
Mongo→Postgres cutover so that call sites don't have to change shape. Once
cutover is complete, a follow-up phase may migrate repo return types to
Pydantic DTOs (tracked in the migration plan as a post-migration item).
"""
from typing import Any, Mapping
from uuid import UUID
def row_to_dict(row: Any) -> dict:
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
During the migration window, API responses and downstream code still
expect a string ``_id`` field (matching the Mongo shape). This helper
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
existing serializers keep working unchanged.
Args:
row: A SQLAlchemy ``Row`` object, or ``None``.
Returns:
A plain dict, or an empty dict if ``row`` is ``None``.
"""
if row is None:
return {}
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
out = dict(mapping)
if "id" in out and out["id"] is not None:
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
out["_id"] = out["id"]
return out

View File

@@ -0,0 +1,67 @@
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
migration.
The helper:
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
call sites add literally zero work.
* Opens a transactional connection from the user-data SQLAlchemy engine.
* Instantiates the caller's repository class on that connection.
* Runs the caller's operation.
* Swallows and logs any exception. **Mongo remains the source of truth
during the dual-write window** — a Postgres-side failure must never
break a user-facing request. Drift that builds up from swallowed
failures is caught separately by re-running the backfill script.
Call sites look like::
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
A single parameterised helper rather than one function per collection
means a new collection just needs its repository class — no new helper
function, no new feature flag. The whole helper is deleted at Phase 5
when the migration is complete.
"""
from __future__ import annotations
import logging
from typing import Callable, TypeVar
from application.core.settings import settings
logger = logging.getLogger(__name__)
_Repo = TypeVar("_Repo")
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
"""Mirror a Mongo write into Postgres via ``repo_cls``, best-effort.
No-op when ``settings.USE_POSTGRES`` is false. Any exception
(connection pool exhaustion, migration drift, SQL error) is logged
and swallowed so the caller's primary Mongo write remains the source
of truth.
Args:
repo_cls: The repository class to instantiate (e.g. ``UsersRepository``).
fn: A callable that takes the instantiated repository and performs
the desired write.
"""
if not settings.USE_POSTGRES:
return
try:
# Lazy import so modules that import dual_write don't pay the
# SQLAlchemy import cost when the flag is off.
from application.storage.db.engine import get_engine
with get_engine().begin() as conn:
fn(repo_cls(conn))
except Exception:
logger.warning(
"Postgres dual-write failed for %s — Mongo write already committed",
repo_cls.__name__,
exc_info=True,
)

View File

@@ -0,0 +1,73 @@
"""SQLAlchemy Core engine factory for the user-data Postgres database.
The engine is lazily constructed on first use and cached as a module-level
singleton. Repositories and the Alembic env module both obtain connections
through this factory, so pool tuning lives in one place.
``POSTGRES_URI`` can be written in any of the common Postgres URI forms::
postgres://user:pass@host:5432/docsgpt
postgresql://user:pass@host:5432/docsgpt
Both are accepted and normalized internally to the psycopg3 dialect
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
don't need to know about SQLAlchemy dialect prefixes.
"""
from typing import Optional
from sqlalchemy import Engine, create_engine
from application.core.settings import settings
_engine: Optional[Engine] = None
def _resolve_uri() -> str:
"""Return the Postgres URI for user-data tables.
Raises:
RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
reach this path without a configured URI have a setup bug — the
error message points them at the right setting.
"""
if not settings.POSTGRES_URI:
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
return settings.POSTGRES_URI
def get_engine() -> Engine:
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
Returns:
A SQLAlchemy ``Engine`` configured with a pooled connection to
Postgres via psycopg3.
"""
global _engine
if _engine is None:
_engine = create_engine(
_resolve_uri(),
pool_size=10,
max_overflow=20,
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
pool_recycle=1800,
future=True,
)
return _engine
def dispose_engine() -> None:
"""Dispose the pooled connections and reset the singleton.
Called from the Celery ``worker_process_init`` signal so each forked
worker gets a fresh pool instead of sharing file descriptors with the
parent process (which corrupts the pool on fork).
"""
global _engine
if _engine is not None:
_engine.dispose()
_engine = None

View File

@@ -0,0 +1,232 @@
"""SQLAlchemy Core metadata for the user-data Postgres database.
Tables are added here one at a time as repositories are built during the
MongoDB→Postgres migration. The baseline schema in the Alembic migration
(``application/alembic/versions/0001_initial.py``) is the source of truth
for DDL; the ``Table`` definitions below must match it column-for-column.
If the two drift, migrations win — update this file to match.
"""
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
MetaData,
UniqueConstraint,
Table,
Text,
func,
)
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
metadata = MetaData()
# --- Phase 1, Tier 1 --------------------------------------------------------
users_table = Table(
"users",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False, unique=True),
Column(
"agent_preferences",
JSONB,
nullable=False,
server_default='{"pinned": [], "shared_with_me": []}',
),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
prompts_table = Table(
"prompts",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("content", Text, nullable=False),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
user_tools_table = Table(
"user_tools",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("custom_name", Text),
Column("display_name", Text),
Column("config", JSONB, nullable=False, server_default="{}"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
token_usage_table = Table(
"token_usage",
metadata,
Column("id", BigInteger, primary_key=True, autoincrement=True),
Column("user_id", Text),
Column("api_key", Text),
Column("agent_id", UUID(as_uuid=True)),
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
Column("generated_tokens", Integer, nullable=False, server_default="0"),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
user_logs_table = Table(
"user_logs",
metadata,
Column("id", BigInteger, primary_key=True, autoincrement=True),
Column("user_id", Text),
Column("endpoint", Text),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("data", JSONB),
)
feedback_table = Table(
"feedback",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("conversation_id", UUID(as_uuid=True), nullable=False),
Column("user_id", Text, nullable=False),
Column("question_index", Integer, nullable=False),
Column("feedback_text", Text),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
stack_logs_table = Table(
"stack_logs",
metadata,
Column("id", BigInteger, primary_key=True, autoincrement=True),
Column("activity_id", Text, nullable=False),
Column("endpoint", Text),
Column("level", Text),
Column("user_id", Text),
Column("api_key", Text),
Column("query", Text),
Column("stacks", JSONB, nullable=False, server_default="[]"),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# --- Phase 2, Tier 2 --------------------------------------------------------
agent_folders_table = Table(
"agent_folders",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("description", Text),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
sources_table = Table(
"sources",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text),
Column("name", Text, nullable=False),
Column("type", Text),
Column("metadata", JSONB, nullable=False, server_default="{}"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
agents_table = Table(
"agents",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("description", Text),
Column("agent_type", Text),
Column("status", Text, nullable=False),
Column("key", Text, unique=True),
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
Column("chunks", Integer),
Column("retriever", Text),
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
Column("tools", JSONB, nullable=False, server_default="[]"),
Column("json_schema", JSONB),
Column("models", JSONB),
Column("default_model_id", Text),
Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
Column("token_limit", Integer),
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
Column("request_limit", Integer),
Column("shared", Boolean, nullable=False, server_default="false"),
Column("incoming_webhook_token", Text, unique=True),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("last_used_at", DateTime(timezone=True)),
)
attachments_table = Table(
"attachments",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("filename", Text, nullable=False),
Column("upload_path", Text, nullable=False),
Column("mime_type", Text),
Column("size", BigInteger),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
memories_table = Table(
"memories",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
Column("path", Text, nullable=False),
Column("content", Text, nullable=False),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
)
todos_table = Table(
"todos",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
Column("title", Text, nullable=False),
Column("completed", Boolean, nullable=False, server_default="false"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
notes_table = Table(
"notes",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
Column("title", Text, nullable=False),
Column("content", Text, nullable=False),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
UniqueConstraint("user_id", "tool_id", name="notes_user_tool_uidx"),
)
connector_sessions_table = Table(
"connector_sessions",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("provider", Text, nullable=False),
Column("session_data", JSONB, nullable=False),
Column("expires_at", DateTime(timezone=True)),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
)

View File

@@ -0,0 +1,11 @@
"""Repositories for the user-data Postgres database.
Each module in this package exposes exactly one repository class. Repository
methods take a ``Connection`` (either as a constructor argument or as a
method argument) and return plain ``dict`` rows via
``application.storage.db.base_repository.row_to_dict`` during the
MongoDB→Postgres cutover, so call sites don't have to change shape.
Repositories are added one collection at a time, matching the phased
rollout in ``migration-postgres.md``.
"""

View File

@@ -0,0 +1,88 @@
"""Repository for the ``agent_folders`` table."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class AgentFoldersRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, name: str, *, description: Optional[str] = None) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO agent_folders (user_id, name, description)
VALUES (:user_id, :name, :description)
RETURNING *
"""
),
{"user_id": user_id, "name": name, "description": description},
)
return row_to_dict(result.fetchone())
def get(self, folder_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": folder_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, folder_id: str, user_id: str, fields: dict) -> bool:
allowed = {"name", "description"}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return False
params: dict = {"id": folder_id, "user_id": user_id}
if "name" in filtered and "description" in filtered:
params["name"] = filtered["name"]
params["description"] = filtered["description"]
result = self._conn.execute(
text(
"UPDATE agent_folders "
"SET name = :name, description = :description, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
params,
)
elif "name" in filtered:
params["name"] = filtered["name"]
result = self._conn.execute(
text(
"UPDATE agent_folders "
"SET name = :name, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
params,
)
else:
params["description"] = filtered["description"]
result = self._conn.execute(
text(
"UPDATE agent_folders "
"SET description = :description, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
params,
)
return result.rowcount > 0
def delete(self, folder_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": folder_id, "user_id": user_id},
)
return result.rowcount > 0

View File

@@ -0,0 +1,154 @@
"""Repository for the ``agents`` table.
This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:
- create, update, delete
- find by key (API key lookup)
- find by webhook token
- list for user, list templates
- folder assignment
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, func, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import agents_table
class AgentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
values: dict = {"user_id": user_id, "name": name, "status": status}
_ALLOWED = {
"description", "agent_type", "key", "retriever",
"default_model_id", "incoming_webhook_token",
"source_id", "prompt_id", "folder_id",
"chunks", "token_limit", "request_limit",
"limited_token_mode", "limited_request_mode", "shared",
"tools", "json_schema", "models",
}
for col, val in kwargs.items():
if col not in _ALLOWED or val is None:
continue
if col in ("tools", "json_schema", "models"):
values[col] = json.dumps(val)
elif col in ("chunks", "token_limit", "request_limit"):
values[col] = int(val)
elif col in ("limited_token_mode", "limited_request_mode", "shared"):
values[col] = bool(val)
elif col in ("source_id", "prompt_id", "folder_id"):
values[col] = str(val)
else:
values[col] = val
stmt = pg_insert(agents_table).values(**values).returning(agents_table)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get(self, agent_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": agent_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def find_by_key(self, key: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE key = :key"),
{"key": key},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def find_by_webhook_token(self, token: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
{"token": token},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def list_templates(self) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE user_id = 'system' ORDER BY name"),
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, agent_id: str, user_id: str, fields: dict) -> bool:
allowed = {
"name", "description", "agent_type", "status", "key", "source_id",
"chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
"default_model_id", "folder_id", "limited_token_mode", "token_limit",
"limited_request_mode", "request_limit", "shared",
"incoming_webhook_token", "last_used_at",
}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return False
values: dict = {}
for col, val in filtered.items():
if col in ("tools", "json_schema", "models"):
values[col] = json.dumps(val) if not isinstance(val, str) else val
elif col in ("source_id", "prompt_id", "folder_id"):
values[col] = str(val) if val else None
else:
values[col] = val
values["updated_at"] = func.now()
t = agents_table
stmt = (
t.update()
.where(t.c.id == agent_id)
.where(t.c.user_id == user_id)
.values(**values)
)
result = self._conn.execute(stmt)
return result.rowcount > 0
def delete(self, agent_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": agent_id, "user_id": user_id},
)
return result.rowcount > 0
def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
self._conn.execute(
text(
"""
UPDATE agents SET folder_id = CAST(:folder_id AS uuid), updated_at = now()
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
"""
),
{"id": agent_id, "user_id": user_id, "folder_id": folder_id},
)
def clear_folder_for_all(self, folder_id: str, user_id: str) -> None:
"""Remove folder assignment from all agents in a folder (used on folder delete)."""
self._conn.execute(
text(
"UPDATE agents SET folder_id = NULL, updated_at = now() "
"WHERE folder_id = CAST(:folder_id AS uuid) AND user_id = :user_id"
),
{"folder_id": folder_id, "user_id": user_id},
)

View File

@@ -0,0 +1,51 @@
"""Repository for the ``attachments`` table."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class AttachmentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, filename: str, upload_path: str, *,
mime_type: Optional[str] = None, size: Optional[int] = None) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
RETURNING *
"""
),
{
"user_id": user_id,
"filename": filename,
"upload_path": upload_path,
"mime_type": mime_type,
"size": size,
},
)
return row_to_dict(result.fetchone())
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM attachments WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": attachment_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]

View File

@@ -0,0 +1,65 @@
"""Repository for the ``connector_sessions`` table.
Covers operations across connector routes and tools:
- upsert session data
- find session by user + provider
- find session by token
- delete session
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class ConnectorSessionsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def upsert(self, user_id: str, provider: str, session_data: dict) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO connector_sessions (user_id, provider, session_data)
VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
ON CONFLICT (user_id, provider)
DO UPDATE SET session_data = EXCLUDED.session_data
RETURNING *
"""
),
{
"user_id": user_id,
"provider": provider,
"session_data": json.dumps(session_data),
},
)
return row_to_dict(result.fetchone())
def get_by_user_provider(self, user_id: str, provider: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"
),
{"user_id": user_id, "provider": provider},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def delete(self, user_id: str, provider: str) -> bool:
result = self._conn.execute(
text("DELETE FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"),
{"user_id": user_id, "provider": provider},
)
return result.rowcount > 0

View File

@@ -0,0 +1,57 @@
"""Repository for the ``feedback`` table.
The ``feedback_collection`` global is declared in ``base.py`` but currently
has zero direct call sites in the application code (all feedback writes go
through ``conversation_messages.feedback`` JSONB field on the conversations
collection). The table exists for when feedback is denormalized into its own
rows. This repository provides the append-only insert and basic reads
needed for that future.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class FeedbackRepository:
"""Postgres-backed replacement for Mongo ``feedback_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
conversation_id: str,
user_id: str,
question_index: int,
feedback_text: Optional[str] = None,
) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text)
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text)
RETURNING *
"""
),
{
"conversation_id": conversation_id,
"user_id": user_id,
"question_index": question_index,
"feedback_text": feedback_text,
},
)
return row_to_dict(result.fetchone())
def list_for_conversation(self, conversation_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM feedback WHERE conversation_id = CAST(:cid AS uuid) ORDER BY question_index"
),
{"cid": conversation_id},
)
return [row_to_dict(r) for r in result.fetchall()]

View File

@@ -0,0 +1,97 @@
"""Repository for the ``memories`` table.
Covers the operations in ``application/agents/tools/memory.py``:
- upsert (create/overwrite file)
- find by path (view file)
- find by path prefix (view directory, regex scan)
- delete by path / path prefix
- rename (update path)
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class MemoriesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def upsert(self, user_id: str, tool_id: str, path: str, content: str) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO memories (user_id, tool_id, path, content)
VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
ON CONFLICT (user_id, tool_id, path)
DO UPDATE SET content = EXCLUDED.content, updated_at = now()
RETURNING *
"""
),
{"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
)
return row_to_dict(result.fetchone())
def get_by_path(self, user_id: str, tool_id: str, path: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM memories WHERE user_id = :user_id "
"AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
),
{"user_id": user_id, "tool_id": tool_id, "path": path},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM memories WHERE user_id = :user_id "
"AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
),
{"user_id": user_id, "tool_id": tool_id, "prefix": prefix + "%"},
)
return [row_to_dict(r) for r in result.fetchall()]
def delete_by_path(self, user_id: str, tool_id: str, path: str) -> int:
result = self._conn.execute(
text(
"DELETE FROM memories WHERE user_id = :user_id "
"AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
),
{"user_id": user_id, "tool_id": tool_id, "path": path},
)
return result.rowcount
def delete_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> int:
result = self._conn.execute(
text(
"DELETE FROM memories WHERE user_id = :user_id "
"AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
),
{"user_id": user_id, "tool_id": tool_id, "prefix": prefix + "%"},
)
return result.rowcount
def delete_all(self, user_id: str, tool_id: str) -> int:
result = self._conn.execute(
text(
"DELETE FROM memories WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
),
{"user_id": user_id, "tool_id": tool_id},
)
return result.rowcount
def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
result = self._conn.execute(
text(
"UPDATE memories SET path = :new_path, updated_at = now() "
"WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) AND path = :old_path"
),
{"user_id": user_id, "tool_id": tool_id, "old_path": old_path, "new_path": new_path},
)
return result.rowcount > 0

View File

@@ -0,0 +1,62 @@
"""Repository for the ``notes`` table.
Covers the operations in ``application/agents/tools/notes.py``.
Note: the Mongo schema stores a single ``note`` text field per (user_id, tool_id),
while the Postgres schema has ``title`` + ``content``. During dual-write,
title is set to a default and content holds the note text.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class NotesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def upsert(self, user_id: str, tool_id: str, title: str, content: str) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO notes (user_id, tool_id, title, content)
VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
ON CONFLICT (user_id, tool_id)
DO UPDATE SET content = EXCLUDED.content, title = EXCLUDED.title, updated_at = now()
RETURNING *
"""
),
{"user_id": user_id, "tool_id": tool_id, "title": title, "content": content},
)
return row_to_dict(result.fetchone())
def get_for_user_tool(self, user_id: str, tool_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
),
{"user_id": user_id, "tool_id": tool_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get(self, note_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM notes WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": note_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def delete(self, user_id: str, tool_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
),
{"user_id": user_id, "tool_id": tool_id},
)
return result.rowcount > 0

View File

@@ -0,0 +1,103 @@
"""Repository for the ``prompts`` table.
Covers every operation the legacy Mongo code performs on
``prompts_collection``:
1. ``insert_one`` in prompts/routes.py (create)
2. ``find`` by user in prompts/routes.py (list)
3. ``find_one`` by id+user in prompts/routes.py (get single)
4. ``find_one`` by id only in stream_processor.py (get content for rendering)
5. ``update_one`` in prompts/routes.py (update name+content)
6. ``delete_one`` in prompts/routes.py (delete)
7. ``find_one`` + ``insert_one`` in seeder.py (upsert by user+name+content)
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class PromptsRepository:
"""Postgres-backed replacement for Mongo ``prompts_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, name: str, content: str) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO prompts (user_id, name, content)
VALUES (:user_id, :name, :content)
RETURNING *
"""
),
{"user_id": user_id, "name": name, "content": content},
)
return row_to_dict(result.fetchone())
def get(self, prompt_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": prompt_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
"""Fetch prompt content by ID without user scoping.
Used only by stream_processor to render a prompt whose owner is
not known at call time. Do NOT use in user-facing routes.
"""
result = self._conn.execute(
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
{"id": prompt_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM prompts WHERE user_id = :user_id ORDER BY created_at"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, prompt_id: str, user_id: str, name: str, content: str) -> None:
self._conn.execute(
text(
"""
UPDATE prompts
SET name = :name, content = :content, updated_at = now()
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
"""
),
{"id": prompt_id, "user_id": user_id, "name": name, "content": content},
)
def delete(self, prompt_id: str, user_id: str) -> None:
self._conn.execute(
text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": prompt_id, "user_id": user_id},
)
def find_or_create(self, user_id: str, name: str, content: str) -> dict:
"""Return existing prompt matching (user, name, content), or create one.
Used by the seeder to avoid duplicating template prompts.
"""
result = self._conn.execute(
text(
"SELECT * FROM prompts WHERE user_id = :user_id AND name = :name AND content = :content"
),
{"user_id": user_id, "name": name, "content": content},
)
row = result.fetchone()
if row is not None:
return row_to_dict(row)
return self.create(user_id, name, content)

View File

@@ -0,0 +1,80 @@
"""Repository for the ``sources`` table."""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, func, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import sources_table
class SourcesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, name: str, *, user_id: Optional[str] = None,
type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO sources (user_id, name, type, metadata)
VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
RETURNING *
"""
),
{
"user_id": user_id,
"name": name,
"type": type,
"metadata": json.dumps(metadata or {}),
},
)
return row_to_dict(result.fetchone())
def get(self, source_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": source_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM sources WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, source_id: str, user_id: str, fields: dict) -> None:
allowed = {"name", "type", "metadata"}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return
values: dict = {}
for col, val in filtered.items():
if col == "metadata":
values[col] = json.dumps(val) if isinstance(val, dict) else val
else:
values[col] = val
values["updated_at"] = func.now()
t = sources_table
stmt = (
t.update()
.where(t.c.id == source_id)
.where(t.c.user_id == user_id)
.values(**values)
)
self._conn.execute(stmt)
def delete(self, source_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": source_id, "user_id": user_id},
)
return result.rowcount > 0

View File

@@ -0,0 +1,58 @@
"""Repository for the ``stack_logs`` table.
Covers the single operation the legacy Mongo code performs:
1. ``insert_one`` in logging.py ``_log_to_mongodb`` — append-only debug/error
activity log. The Mongo collection is ``stack_logs``; the Mongo variable
inside ``_log_to_mongodb`` is misleadingly named ``user_logs_collection``.
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Optional
from sqlalchemy import Connection, text
class StackLogsRepository:
"""Postgres-backed replacement for Mongo ``stack_logs`` collection."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def insert(
self,
*,
activity_id: str,
endpoint: Optional[str] = None,
level: Optional[str] = None,
user_id: Optional[str] = None,
api_key: Optional[str] = None,
query: Optional[str] = None,
stacks: Optional[list] = None,
timestamp: Optional[datetime] = None,
) -> None:
self._conn.execute(
text(
"""
INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
VALUES (
:activity_id, :endpoint, :level, :user_id, :api_key, :query,
CAST(:stacks AS jsonb),
COALESCE(:timestamp, now())
)
"""
),
{
"activity_id": activity_id,
"endpoint": endpoint,
"level": level,
"user_id": user_id,
"api_key": api_key,
"query": query,
"stacks": json.dumps(stacks or []),
"timestamp": timestamp,
},
)

View File

@@ -0,0 +1,78 @@
"""Repository for the ``todos`` table.
Covers the operations in ``application/agents/tools/todo_list.py``.
Note: the Mongo schema uses ``todo_id`` (sequential int) and ``status`` (text),
while the Postgres schema uses ``completed`` (boolean) and the UUID ``id`` as PK.
The repository bridges both shapes.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class TodosRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, tool_id: str, title: str) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO todos (user_id, tool_id, title)
VALUES (:user_id, CAST(:tool_id AS uuid), :title)
RETURNING *
"""
),
{"user_id": user_id, "tool_id": tool_id, "title": title},
)
return row_to_dict(result.fetchone())
def get(self, todo_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": todo_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user_tool(self, user_id: str, tool_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM todos WHERE user_id = :user_id "
"AND tool_id = CAST(:tool_id AS uuid) ORDER BY created_at"
),
{"user_id": user_id, "tool_id": tool_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update_title(self, todo_id: str, user_id: str, title: str) -> bool:
result = self._conn.execute(
text(
"UPDATE todos SET title = :title, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": todo_id, "user_id": user_id, "title": title},
)
return result.rowcount > 0
def set_completed(self, todo_id: str, user_id: str, completed: bool = True) -> bool:
result = self._conn.execute(
text(
"UPDATE todos SET completed = :completed, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": todo_id, "user_id": user_id, "completed": completed},
)
return result.rowcount > 0
def delete(self, todo_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": todo_id, "user_id": user_id},
)
return result.rowcount > 0

View File

@@ -0,0 +1,104 @@
"""Repository for the ``token_usage`` table.
Covers every operation the legacy Mongo code performs on
``token_usage_collection`` / ``usage_collection``:
1. ``insert_one`` in usage.py (record per-call token counts)
2. ``aggregate`` in analytics/routes.py (time-bucketed totals)
3. ``aggregate`` in answer/routes/base.py (24h sum for rate limiting)
4. ``count_documents`` in answer/routes/base.py (24h request count)
"""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import Connection, text
class TokenUsageRepository:
"""Postgres-backed replacement for Mongo ``token_usage_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def insert(
self,
*,
user_id: Optional[str] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
prompt_tokens: int = 0,
generated_tokens: int = 0,
timestamp: Optional[datetime] = None,
) -> None:
self._conn.execute(
text(
"""
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
VALUES (
:user_id, :api_key,
CAST(:agent_id AS uuid),
:prompt_tokens, :generated_tokens,
COALESCE(:timestamp, now())
)
"""
),
{
"user_id": user_id,
"api_key": api_key,
"agent_id": agent_id,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"timestamp": timestamp,
},
)
def sum_tokens_in_range(
self,
*,
start: datetime,
end: datetime,
user_id: Optional[str] = None,
api_key: Optional[str] = None,
) -> int:
"""Total (prompt + generated) tokens in the given time range."""
clauses = ["timestamp >= :start", "timestamp <= :end"]
params: dict = {"start": start, "end": end}
if user_id is not None:
clauses.append("user_id = :user_id")
params["user_id"] = user_id
if api_key is not None:
clauses.append("api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
result = self._conn.execute(
text(f"SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) FROM token_usage WHERE {where}"),
params,
)
return result.scalar()
def count_in_range(
self,
*,
start: datetime,
end: datetime,
user_id: Optional[str] = None,
api_key: Optional[str] = None,
) -> int:
"""Count of token_usage rows in the given time range (for request limiting)."""
clauses = ["timestamp >= :start", "timestamp <= :end"]
params: dict = {"start": start, "end": end}
if user_id is not None:
clauses.append("user_id = :user_id")
params["user_id"] = user_id
if api_key is not None:
clauses.append("api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
result = self._conn.execute(
text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
params,
)
return result.scalar()

View File

@@ -0,0 +1,84 @@
"""Repository for the ``user_logs`` table.
Covers every operation the legacy Mongo code performs on
``user_logs_collection``:
1. ``insert_one`` in logging.py (per-request activity log via
``_log_to_mongodb`` — note: the *Mongo* variable is confusingly named
``user_logs_collection`` but points at the ``user_logs`` Mongo
collection, not ``stack_logs``)
2. ``insert_one`` in answer/routes/base.py (per-stream log entry)
3. ``find`` with sort/skip/limit in analytics/routes.py (paginated log list)
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class UserLogsRepository:
"""Postgres-backed replacement for Mongo ``user_logs_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def insert(
self,
*,
user_id: Optional[str] = None,
endpoint: Optional[str] = None,
data: Optional[dict] = None,
timestamp: Optional[datetime] = None,
) -> None:
self._conn.execute(
text(
"""
INSERT INTO user_logs (user_id, endpoint, data, timestamp)
VALUES (:user_id, :endpoint, CAST(:data AS jsonb), COALESCE(:timestamp, now()))
"""
),
{
"user_id": user_id,
"endpoint": endpoint,
"data": json.dumps(data) if data is not None else None,
"timestamp": timestamp,
},
)
def list_paginated(
self,
*,
user_id: Optional[str] = None,
api_key: Optional[str] = None,
page: int = 1,
page_size: int = 10,
) -> tuple[list[dict], bool]:
"""Return ``(rows, has_more)`` for the requested page.
Mirrors the Mongo ``find(query).sort().skip().limit(page_size+1)``
pattern used in analytics/routes.py.
"""
clauses: list[str] = []
params: dict = {"limit": page_size + 1, "offset": (page - 1) * page_size}
if user_id is not None:
clauses.append("user_id = :user_id")
params["user_id"] = user_id
if api_key is not None:
clauses.append("data->>'api_key' = :api_key")
params["api_key"] = api_key
where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
result = self._conn.execute(
text(
f"SELECT * FROM user_logs {where} ORDER BY timestamp DESC LIMIT :limit OFFSET :offset"
),
params,
)
rows = [row_to_dict(r) for r in result.fetchall()]
has_more = len(rows) > page_size
return rows[:page_size], has_more

View File

@@ -0,0 +1,114 @@
"""Repository for the ``user_tools`` table.
Covers every operation the legacy Mongo code performs on
``user_tools_collection``:
1. ``find`` by user in tools/routes.py and base.py (list all / active)
2. ``find_one`` by id in tools/routes.py and sharing.py (get single)
3. ``insert_one`` in tools/routes.py and mcp.py (create)
4. ``update_one`` in tools/routes.py and mcp.py (update fields)
5. ``delete_one`` in tools/routes.py (delete)
6. ``find`` by user+status in stream_processor.py and tool_executor.py (active tools)
7. ``find_one`` by user+name in mcp.py (upsert check)
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class UserToolsRepository:
"""Postgres-backed replacement for Mongo ``user_tools_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, name: str, *, config: Optional[dict] = None,
custom_name: Optional[str] = None, display_name: Optional[str] = None,
extra: Optional[dict] = None) -> dict:
"""Insert a new tool row. ``extra`` is merged into the config JSONB."""
cfg = config or {}
if extra:
cfg.update(extra)
result = self._conn.execute(
text(
"""
INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
RETURNING *
"""
),
{
"user_id": user_id,
"name": name,
"custom_name": custom_name,
"display_name": display_name,
"config": json.dumps(cfg),
},
)
return row_to_dict(result.fetchone())
def get(self, tool_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": tool_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM user_tools WHERE user_id = :user_id ORDER BY created_at"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, tool_id: str, user_id: str, fields: dict) -> None:
"""Update arbitrary fields on a tool row.
``fields`` maps column names to new values. Only ``name``,
``custom_name``, ``display_name``, and ``config`` are allowed.
"""
allowed = {"name", "custom_name", "display_name", "config"}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return
params: dict = {
"id": tool_id,
"user_id": user_id,
"name": filtered.get("name"),
"custom_name": filtered.get("custom_name"),
"display_name": filtered.get("display_name"),
"config": (
json.dumps(filtered["config"])
if "config" in filtered and isinstance(filtered["config"], dict)
else filtered.get("config")
),
}
self._conn.execute(
text(
"""
UPDATE user_tools
SET
name = COALESCE(:name, name),
custom_name = COALESCE(:custom_name, custom_name),
display_name = COALESCE(:display_name, display_name),
config = COALESCE(CAST(:config AS jsonb), config),
updated_at = now()
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
"""
),
params,
)
def delete(self, tool_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": tool_id, "user_id": user_id},
)
return result.rowcount > 0

View File

@@ -0,0 +1,245 @@
"""Repository for the ``users`` table.
Covers every operation the legacy Mongo code performs on
``users_collection``:
1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get)
2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove
on ``agent_preferences.pinned``)
3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/
bulk-remove on ``agent_preferences.shared_with_me``)
4. Cascade delete of an agent id from both arrays at once
All array mutations are implemented as single atomic UPDATE statements
using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``)
so there is no read-modify-write race between concurrent writers on the
same user row.
The repository takes a ``Connection`` and does not manage its own
transactions. Callers are responsible for wrapping writes in
``with engine.begin() as conn:`` (production) or the test fixture's
rollback-per-test connection (tests).
"""
from __future__ import annotations
from typing import Iterable, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}'
class UsersRepository:
"""Postgres-backed replacement for Mongo ``users_collection`` writes/reads."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
# ------------------------------------------------------------------
# Reads
# ------------------------------------------------------------------
def get(self, user_id: str) -> Optional[dict]:
"""Return the user row as a dict, or ``None`` if missing.
Args:
user_id: Auth-provider ``sub`` (opaque string).
"""
result = self._conn.execute(
text("SELECT * FROM users WHERE user_id = :user_id"),
{"user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
# ------------------------------------------------------------------
# Upsert
# ------------------------------------------------------------------
def upsert(self, user_id: str) -> dict:
"""Ensure a row exists for ``user_id`` and return it.
Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True,
return_document=AFTER)`` semantics: if the row exists, preferences
are preserved untouched; if it doesn't, a new row is created with
default preferences.
The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a
deliberate no-op that lets ``RETURNING *`` fire on both the insert
and conflict paths (``DO NOTHING`` would suppress the returning).
"""
result = self._conn.execute(
text(
"""
INSERT INTO users (user_id, agent_preferences)
VALUES (:user_id, CAST(:default_prefs AS jsonb))
ON CONFLICT (user_id) DO UPDATE
SET user_id = EXCLUDED.user_id
RETURNING *
"""
),
{"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES},
)
return row_to_dict(result.fetchone())
# ------------------------------------------------------------------
# Pinned agents
# ------------------------------------------------------------------
def add_pinned(self, user_id: str, agent_id: str) -> None:
"""Idempotently append ``agent_id`` to ``agent_preferences.pinned``.
Uses ``@>`` containment so a duplicate add is a no-op rather than a
silent double-insert. The whole update is a single atomic statement
so concurrent add_pinned calls on the same user cannot interleave
into a read-modify-write race.
"""
self._append_to_jsonb_array(user_id, "pinned", agent_id)
def remove_pinned(self, user_id: str, agent_id: str) -> None:
"""Remove ``agent_id`` from ``agent_preferences.pinned`` if present."""
self._remove_from_jsonb_array(user_id, "pinned", [agent_id])
def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
"""Remove every id in ``agent_ids`` from ``agent_preferences.pinned``.
No-op if the list is empty. Unknown ids are silently ignored so
callers can pass the full "stale" set without pre-filtering.
"""
ids = list(agent_ids)
if not ids:
return
self._remove_from_jsonb_array(user_id, "pinned", ids)
# ------------------------------------------------------------------
# Shared-with-me agents
# ------------------------------------------------------------------
def add_shared(self, user_id: str, agent_id: str) -> None:
"""Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``."""
self._append_to_jsonb_array(user_id, "shared_with_me", agent_id)
def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
"""Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op."""
ids = list(agent_ids)
if not ids:
return
self._remove_from_jsonb_array(user_id, "shared_with_me", ids)
# ------------------------------------------------------------------
# Combined removal — called when an agent is hard-deleted
# ------------------------------------------------------------------
def remove_agent_from_all(self, user_id: str, agent_id: str) -> None:
"""Remove ``agent_id`` from BOTH pinned and shared_with_me atomically.
Mirrors the Mongo ``$pull`` that targets both nested array fields
in one ``update_one`` — see ``application/api/user/agents/routes.py``
around the agent-delete path.
"""
self._conn.execute(
text(
"""
UPDATE users
SET
agent_preferences = jsonb_set(
jsonb_set(
agent_preferences,
'{pinned}',
COALESCE(
(
SELECT jsonb_agg(elem)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
) AS elem
WHERE (elem #>> '{}') != :agent_id
),
'[]'::jsonb
)
),
'{shared_with_me}',
COALESCE(
(
SELECT jsonb_agg(elem)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
) AS elem
WHERE (elem #>> '{}') != :agent_id
),
'[]'::jsonb
)
),
updated_at = now()
WHERE user_id = :user_id
"""
),
{"user_id": user_id, "agent_id": agent_id},
)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None:
"""Idempotent append of ``agent_id`` to ``agent_preferences.<key>``.
The ``key`` argument is NOT user input — it's hard-coded by the
calling method (``pinned`` / ``shared_with_me``). It goes into the
SQL literal because ``jsonb_set`` requires a path literal, not a
bind parameter. This is safe as long as callers never pass
untrusted strings for ``key``.
"""
if key not in ("pinned", "shared_with_me"):
raise ValueError(f"unsupported jsonb key: {key!r}")
self._conn.execute(
text(
f"""
UPDATE users
SET
agent_preferences = jsonb_set(
agent_preferences,
'{{{key}}}',
CASE
WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text))
THEN agent_preferences->'{key}'
ELSE
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
|| to_jsonb(CAST(:agent_id AS text))
END
),
updated_at = now()
WHERE user_id = :user_id
"""
),
{"user_id": user_id, "agent_id": agent_id},
)
def _remove_from_jsonb_array(
self, user_id: str, key: str, agent_ids: list[str]
) -> None:
"""Remove every id in ``agent_ids`` from ``agent_preferences.<key>``."""
if key not in ("pinned", "shared_with_me"):
raise ValueError(f"unsupported jsonb key: {key!r}")
self._conn.execute(
text(
f"""
UPDATE users
SET
agent_preferences = jsonb_set(
agent_preferences,
'{{{key}}}',
COALESCE(
(
SELECT jsonb_agg(elem)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
) AS elem
WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids))
),
'[]'::jsonb
)
),
updated_at = now()
WHERE user_id = :user_id
"""
),
{"user_id": user_id, "agent_ids": agent_ids},
)

View File

@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
)
def _get_full_path(self, path: str) -> str:
"""Get absolute path by combining base_dir and path."""
"""Get absolute path by combining base_dir and path.
Raises:
ValueError: If the resolved path escapes base_dir (path traversal).
"""
if os.path.isabs(path):
return path
return os.path.join(self.base_dir, path)
resolved = os.path.realpath(path)
else:
resolved = os.path.realpath(os.path.join(self.base_dir, path))
base = os.path.realpath(self.base_dir)
if not resolved.startswith(base + os.sep) and resolved != base:
raise ValueError(f"Path traversal detected: {path}")
return resolved
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
"""Save a file to local storage."""

View File

@@ -2,6 +2,7 @@
import io
import os
import posixpath
from typing import BinaryIO, Callable, List
import boto3
@@ -14,6 +15,20 @@ from botocore.exceptions import ClientError
class S3Storage(BaseStorage):
"""AWS S3 storage implementation."""
@staticmethod
def _validate_path(path: str) -> str:
"""Validate and normalize an S3 key to prevent path traversal.
Raises:
ValueError: If the path contains traversal sequences or is absolute.
"""
if "\x00" in path:
raise ValueError(f"Null byte in path: {path}")
normalized = posixpath.normpath(path)
if normalized.startswith("/") or normalized.startswith(".."):
raise ValueError(f"Path traversal detected: {path}")
return normalized
def __init__(self, bucket_name=None):
"""
Initialize S3 storage.
@@ -46,6 +61,7 @@ class S3Storage(BaseStorage):
**kwargs,
) -> dict:
"""Save a file to S3 storage."""
path = self._validate_path(path)
self.s3.upload_fileobj(
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
)
@@ -61,6 +77,7 @@ class S3Storage(BaseStorage):
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
@@ -70,6 +87,7 @@ class S3Storage(BaseStorage):
def delete_file(self, path: str) -> bool:
"""Delete a file from S3 storage."""
path = self._validate_path(path)
try:
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
return True
@@ -78,6 +96,7 @@ class S3Storage(BaseStorage):
def file_exists(self, path: str) -> bool:
"""Check if a file exists in S3 storage."""
path = self._validate_path(path)
try:
self.s3.head_object(Bucket=self.bucket_name, Key=path)
return True
@@ -115,6 +134,7 @@ class S3Storage(BaseStorage):
import logging
import tempfile
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(

View File

@@ -110,6 +110,20 @@ def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
usage_data["agent_id"] = normalized_agent_id
usage_collection.insert_one(usage_data)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.token_usage import TokenUsageRepository
dual_write(
TokenUsageRepository,
lambda repo, d=usage_data: repo.insert(
user_id=d.get("user_id"),
api_key=d.get("api_key"),
agent_id=d.get("agent_id"),
prompt_tokens=d["prompt_tokens"],
generated_tokens=d["generated_tokens"],
),
)
def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):

View File

@@ -11,11 +11,33 @@ from application.storage.storage_creator import StorageCreator
def get_vectorstore(path: str) -> str:
if path:
vectorstore = f"indexes/{path}"
else:
vectorstore = "indexes"
return vectorstore
"""Build a safe local path for a FAISS index.
Args:
path: Source identifier provided by the caller.
Returns:
The validated vectorstore path rooted under ``indexes``.
Raises:
ValueError: If ``path`` escapes the ``indexes`` directory.
"""
base_dir = "indexes"
if not path:
return base_dir
normalized = str(path).strip()
if "\\" in normalized:
raise ValueError("Invalid source_id path")
candidate = os.path.normpath(os.path.join(base_dir, normalized))
base_abs = os.path.abspath(base_dir)
candidate_abs = os.path.abspath(candidate)
if not candidate_abs.startswith(base_abs + os.sep) and candidate_abs != base_abs:
raise ValueError("Invalid source_id path")
return candidate
class FaissStore(BaseVectorStore):

View File

@@ -27,37 +27,42 @@ class PGVectorStore(BaseVectorStore):
self._metadata_column = metadata_column
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
# Use provided connection string or fall back to settings
# Use provided connection string or fall back to settings.
# If PGVECTOR_CONNECTION_STRING is not set but POSTGRES_URI is,
# reuse the same cluster — normalize from SQLAlchemy dialect to libpq form.
self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
if not self._connection_string and getattr(settings, 'POSTGRES_URI', None):
from application.core.db_uri import normalize_pgvector_connection_string
self._connection_string = normalize_pgvector_connection_string(settings.POSTGRES_URI)
if not self._connection_string:
raise ValueError(
"PostgreSQL connection string is required. "
"Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
"Set PGVECTOR_CONNECTION_STRING or POSTGRES_URI in settings, "
"or pass connection_string parameter."
)
try:
import psycopg2
from psycopg2.extras import Json
import pgvector.psycopg2
import psycopg
from pgvector.psycopg import register_vector
except ImportError:
raise ImportError(
"Could not import required packages. "
"Please install with `pip install psycopg2-binary pgvector`."
"Please install with `pip install 'psycopg[binary,pool]' pgvector`."
)
self._psycopg2 = psycopg2
self._Json = Json
self._pgvector = pgvector.psycopg2
self._psycopg = psycopg
self._register_vector = register_vector
self._connection = None
self._ensure_table_exists()
def _get_connection(self):
"""Get or create database connection"""
if self._connection is None or self._connection.closed:
self._connection = self._psycopg2.connect(self._connection_string)
self._connection = self._psycopg.connect(self._connection_string)
# Register pgvector types
self._pgvector.register_vector(self._connection)
self._register_vector(self._connection)
return self._connection
def _ensure_table_exists(self):
@@ -170,7 +175,7 @@ class PGVectorStore(BaseVectorStore):
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, self._Json(metadata), self._source_id)
(text, embedding, metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
@@ -261,7 +266,7 @@ class PGVectorStore(BaseVectorStore):
cursor.execute(
insert_query,
(text, embeddings[0], self._Json(final_metadata), self._source_id)
(text, embeddings[0], final_metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()

View File

@@ -247,7 +247,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
def download_file(url, params, dest_path):
try:
response = requests.get(url, params=params)
response = requests.get(url, params=params, timeout=100)
response.raise_for_status()
with open(dest_path, "wb") as f:
f.write(response.content)
@@ -284,12 +284,14 @@ def upload_index(full_path, file_data):
files=files,
data=file_data,
headers=headers,
timeout=100,
)
else:
response = requests.post(
urljoin(settings.API_URL, "/api/upload_index"),
data=file_data,
headers=headers,
timeout=100,
)
response.raise_for_status()
except (requests.RequestException, FileNotFoundError) as e:
@@ -1171,6 +1173,16 @@ def attachment_worker(self, file_info, user):
}
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.attachments import AttachmentsRepository
dual_write(
AttachmentsRepository,
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type: repo.create(
u, fn, p, mime_type=mt,
),
)
logging.info(
f"Stored attachment with ID: {attachment_id}", extra={"user": user}
)

View File

@@ -7,6 +7,10 @@ export default {
"title": "🔌 Agent API",
"href": "/Agents/api"
},
"openai-compatible": {
"title": "🔄 OpenAI-Compatible API",
"href": "/Agents/openai-compatible"
},
"webhooks": {
"title": "🪝 Agent Webhooks",
"href": "/Agents/webhooks"

View File

@@ -15,6 +15,10 @@ DocsGPT Agents can be accessed programmatically through API endpoints. This page
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
<Callout type="info">
Looking to connect an existing OpenAI-compatible client (opencode, aider, the OpenAI SDKs, etc.) to a DocsGPT Agent? Use the [OpenAI-Compatible Chat Completions API](/Agents/openai-compatible) — it speaks the standard chat completions protocol so no adapter code is required.
</Callout>
## Base URL
<Callout type="info">

View File

@@ -111,6 +111,7 @@ Once an agent is created, you can:
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
* **Use it via API:** Every agent exposes an API key that can be used with the native [Agent API](/Agents/api) or the [OpenAI-Compatible API](/Agents/openai-compatible) so you can drop DocsGPT Agents into any tool that already speaks the chat completions protocol.
## Seeding Premade Agents from YAML

View File

@@ -0,0 +1,93 @@
---
title: OpenAI-Compatible API
description: Connect any OpenAI-compatible client to DocsGPT Agents via /v1/chat/completions.
---
import { Callout, Tabs } from 'nextra/components';
# OpenAI-Compatible API
DocsGPT exposes `/v1/chat/completions` following the standard chat completions protocol. Point any compatible client — **opencode**, **Aider**, **LibreChat** or the OpenAI SDKs — at your DocsGPT Agent by changing only the base URL and API key.
## Quick Start
<Tabs items={['Python', 'cURL']}>
<Tabs.Tab>
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:7091/v1", # or https://gptcloud.arc53.com/v1
api_key="your_agent_api_key",
)
response = client.chat.completions.create(
model="docsgpt-agent",
messages=[{"role": "user", "content": "Summarize our refund policy"}],
)
print(response.choices[0].message.content)
```
</Tabs.Tab>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/v1/chat/completions \
-H "Authorization: Bearer your_agent_api_key" \
-H "Content-Type: application/json" \
-d '{"model":"docsgpt-agent","messages":[{"role":"user","content":"Summarize our refund policy"}]}'
```
</Tabs.Tab>
</Tabs>
The `model` field is accepted but ignored — the agent bound to your API key determines the model. The agent's prompt, sources, tools, and default model are loaded automatically.
## Base URL & Auth
| Environment | Base URL |
| --- | --- |
| Local | `http://localhost:7091/v1` |
| Cloud | `https://gptcloud.arc53.com/v1` |
Authenticate with `Authorization: Bearer <agent_api_key>`.
## Endpoints
| Method | Path | Description |
| --- | --- | --- |
| `POST` | `/v1/chat/completions` | Chat request (streaming or non-streaming) |
| `GET` | `/v1/models` | List agents available to your key |
## Streaming
Set `"stream": true`. You'll receive SSE chunks with `choices[0].delta.content`. DocsGPT-specific events (sources, tool calls) arrive as extra frames with a `docsgpt` key — standard clients ignore them.
```python
stream = client.chat.completions.create(
model="docsgpt-agent",
stream=True,
messages=[{"role": "user", "content": "Explain vector search"}],
)
for chunk in stream:
print(chunk.choices[0].delta.content or "", end="", flush=True)
```
## System Prompt Override
System messages are **dropped by default** — the agent's configured prompt is used. To allow callers to override it, enable **Allow prompt override** in the agent's Advanced settings.
<Callout type="warning">
When an override is active, the agent's prompt template is replaced wholesale — template variables like `{summaries}` are not substituted.
</Callout>
## Conversation Persistence
Conversations are **not persisted by default** (stateless, like most OpenAI clients expect). Opt in per request:
```json
{ "docsgpt": { "save_conversation": true } }
```
The response will include `docsgpt.conversation_id`.
## When to Use Native Endpoints Instead
Use [`/api/answer` or `/stream`](/Agents/api) if you need server-side attachments, `passthrough` template variables, explicit `conversation_id` reuse, or persistence by default.

View File

@@ -0,0 +1,114 @@
---
title: PostgreSQL for User Data
description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace.
---
import { Callout } from 'nextra/components'
# PostgreSQL for User Data
DocsGPT is progressively moving user data (conversations, agents, prompts,
preferences, etc.) from MongoDB to PostgreSQL, one collection at a time.
Each collection is guarded by a feature flag so you can opt in and roll
back instantly. MongoDB stays the source of truth until you cut over
reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …)
are unaffected.
<Callout type="info" emoji="">
Which collections are available today is in the [Status](#status)
table below. That table is the only part of this page that changes
release to release.
</Callout>
## Setup
1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS,
Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and
`citext` extensions, both standard contrib modules available
everywhere.
2. **Create a database and role** (skip if your managed provider gave
you these):
```sql
CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt';
CREATE DATABASE docsgpt OWNER docsgpt;
```
3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works —
DocsGPT normalizes it internally.
```bash
POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
# Append ?sslmode=require for managed providers that enforce SSL.
```
4. **Apply the schema** (idempotent — safe to re-run):
```bash
python scripts/db/init_postgres.py
```
## Migrating data
Two global flags, no per-collection knobs — every collection marked ✅
in the [Status](#status) table is handled automatically.
1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo
remains source of truth. Set the flag in `.env` and restart:
```bash
USE_POSTGRES=true
```
2. **Backfill existing data.** Idempotent — re-run any time to re-sync
drifted rows. Without arguments, backfills every registered table;
pass `--tables` to limit.
```bash
python scripts/db/backfill.py --dry-run # preview everything
python scripts/db/backfill.py # real run, everything
python scripts/db/backfill.py --tables users # only specific tables
```
3. **Cut over reads** once you trust the Postgres state:
```bash
READ_POSTGRES=true
```
Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
keeps Postgres up to date so you can flip back and forth.
<Callout type="warning" emoji="⚠️">
Don't decommission MongoDB until every collection you use is fully
cut over. During the migration window, Mongo is still required.
</Callout>
## Status
_Last updated: 2026-04-10_
| Collection | Status |
|---|---|
| `users` | ✅ Phase 1 |
| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |
Schemas for **every** row above already exist after `init_postgres.py`
runs. What's landing progressively is the application-level dual-write
wiring and the backfill logic for each collection. Once a collection
is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
picks it up automatically — no per-collection config change.
## Troubleshooting
- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
`CREATE DATABASE` statements from step 2 as a Postgres superuser.
- **SSL errors on a managed provider** — append `?sslmode=require` to
`POSTGRES_URI`.
- **Dual-write warnings in the logs** — expected to be non-fatal. Mongo
is source of truth, so the user-facing request succeeds. Re-run the
backfill to re-sync whichever rows drifted.

View File

@@ -19,6 +19,10 @@ export default {
"title": "☁️ Hosting DocsGPT",
"href": "/Deploying/Hosting-the-app"
},
"Postgres-Migration": {
"title": "🐘 PostgreSQL for User Data",
"href": "/Deploying/Postgres-Migration"
},
"Amazon-Lightsail": {
"title": "Hosting DocsGPT on Amazon Lightsail",
"href": "/Deploying/Amazon-Lightsail",

View File

@@ -54,8 +54,8 @@ flowchart LR
* **Technology:** Supports multiple LLM APIs and local engines.
* **Responsibility:** This layer provides an abstraction for interacting with Large Language Models (LLMs).
* **Key Features:**
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatable with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
* Manages API key handling and request formatting and Tool fromatting.
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatible with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
* Manages API key handling and request formatting and Tool formatting.
* Offers caching mechanisms to improve response times and reduce API usage.
* Handles streaming responses for a more interactive user experience.
@@ -120,7 +120,7 @@ sequenceDiagram
## Deployment Architecture
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a qucik overview of a simple k8s deployment.
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a quick overview of a simple k8s deployment.
```mermaid
graph LR

View File

@@ -7,6 +7,10 @@ export default {
"title": "🔗 SharePoint / OneDrive",
"href": "/Guides/Integrations/sharepoint-connector"
},
"confluence-connector": {
"title": "🔗 Confluence",
"href": "/Guides/Integrations/confluence-connector"
},
"mcp-tool-integration": {
"title": "🔗 MCP Tools",
"href": "/Guides/Integrations/mcp-tool-integration"

View File

@@ -0,0 +1,67 @@
---
title: Confluence Connector
description: Connect your Confluence Cloud workspace as an external knowledge base to upload and process pages directly.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# Confluence Connector
Connect your Confluence Cloud workspace to upload and process pages directly as an external knowledge base. Supports page content and attachments (PDFs, Office files, text files, images, and more). Authentication is handled via Atlassian OAuth 2.0 with automatic token refresh.
## Setup
<Steps>
### Step 1: Create an OAuth 2.0 App in Atlassian
1. Go to [developer.atlassian.com/console/myapps](https://developer.atlassian.com/console/myapps/) and click **Create** > **OAuth 2.0 integration**
2. Under **Authorization**, add a callback URL:
- Local: `http://localhost:7091/api/connectors/callback?provider=confluence`
- Production: `https://yourdomain.com/api/connectors/callback?provider=confluence`
### Step 2: Configure Permissions
In your app settings, go to **Permissions** and add the **Confluence API**. Enable these scopes:
- `read:page:confluence`
- `read:space:confluence`
- `read:attachment:confluence`
### Step 3: Get Your Credentials
Go to **Settings** in your app to find the **Client ID** and **Secret**. Copy both.
### Step 4: Configure Environment Variables
Add to your backend `.env` file:
```env
CONFLUENCE_CLIENT_ID=your-atlassian-client-id
CONFLUENCE_CLIENT_SECRET=your-atlassian-client-secret
```
Add to your frontend `.env` file:
```env
VITE_CONFLUENCE_CLIENT_ID=your-atlassian-client-id
```
| Variable | Description | Required |
|----------|-------------|----------|
| `CONFLUENCE_CLIENT_ID` | Client ID from your Atlassian OAuth app | Yes |
| `CONFLUENCE_CLIENT_SECRET` | Client secret from your Atlassian OAuth app | Yes |
| `VITE_CONFLUENCE_CLIENT_ID` | Same Client ID, used by the frontend to show the Confluence option | Yes |
### Step 5: Restart and Use
Restart your application, then go to the upload section in DocsGPT and select **Confluence** as the source. You'll be redirected to Atlassian to sign in, then can browse spaces and select pages to process.
</Steps>
## Troubleshooting
- **Option not appearing** — Verify `VITE_CONFLUENCE_CLIENT_ID` is set in the frontend `.env`, then restart.
- **Authentication failed** — Check that the callback URL matches exactly, including `?provider=confluence`.
- **No accessible sites** — Ensure the authenticating user has access to at least one Confluence Cloud site.
- **Permission denied** — Verify that the Confluence API scopes are enabled in your Atlassian app settings.

Some files were not shown because too many files have changed in this diff Show More