mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Compare commits
1 Commits
compress-c
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37c672b891 |
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -13,11 +13,7 @@ updates:
|
||||
directory: "/frontend" # Location of package manifests
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/extensions/react-widget"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
interval: "daily"
|
||||
|
||||
11
.github/styles/DocsGPT/Spelling.yml
vendored
11
.github/styles/DocsGPT/Spelling.yml
vendored
@@ -1,11 +0,0 @@
|
||||
extends: spelling
|
||||
level: warning
|
||||
message: "Did you really mean '%s'?"
|
||||
ignore:
|
||||
- "**/node_modules/**"
|
||||
- "**/dist/**"
|
||||
- "**/build/**"
|
||||
- "**/coverage/**"
|
||||
- "**/public/**"
|
||||
- "**/static/**"
|
||||
vocab: DocsGPT
|
||||
@@ -1,46 +0,0 @@
|
||||
Ollama
|
||||
Qdrant
|
||||
Milvus
|
||||
Chatwoot
|
||||
Nextra
|
||||
VSCode
|
||||
npm
|
||||
LLMs
|
||||
APIs
|
||||
Groq
|
||||
SGLang
|
||||
LMDeploy
|
||||
OAuth
|
||||
Vite
|
||||
LLM
|
||||
JSONPath
|
||||
UIs
|
||||
configs
|
||||
uncomment
|
||||
qdrant
|
||||
vectorstore
|
||||
docsgpt
|
||||
llm
|
||||
GPUs
|
||||
kubectl
|
||||
Lightsail
|
||||
enqueues
|
||||
chatbot
|
||||
VSCode's
|
||||
Shareability
|
||||
feedbacks
|
||||
automations
|
||||
Premade
|
||||
Signup
|
||||
Repo
|
||||
repo
|
||||
env
|
||||
URl
|
||||
agentic
|
||||
llama_cpp
|
||||
parsable
|
||||
SDKs
|
||||
boolean
|
||||
bool
|
||||
hardcode
|
||||
EOL
|
||||
2
.github/workflows/bandit.yaml
vendored
2
.github/workflows/bandit.yaml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install dependencies
|
||||
|
||||
8
.github/workflows/pytest.yml
vendored
8
.github/workflows/pytest.yml
vendored
@@ -10,21 +10,21 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-cov
|
||||
cd application
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
cd ../tests
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest and generate coverage report
|
||||
run: |
|
||||
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
|
||||
python -m pytest --cov=application --cov-report=xml
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
|
||||
26
.github/workflows/vale.yml
vendored
26
.github/workflows/vale.yml
vendored
@@ -1,26 +0,0 @@
|
||||
name: Vale Documentation Linter
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/**/*.md'
|
||||
- 'docs/**/*.mdx'
|
||||
- '**/*.md'
|
||||
- '.vale.ini'
|
||||
- '.github/styles/**'
|
||||
|
||||
jobs:
|
||||
vale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Vale linter
|
||||
uses: errata-ai/vale-action@v2
|
||||
with:
|
||||
files: docs
|
||||
fail_on_error: false
|
||||
version: 3.0.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,9 +2,7 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
experiments/
|
||||
|
||||
experiments
|
||||
# C extensions
|
||||
*.so
|
||||
*.next
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
MinAlertLevel = warning
|
||||
StylesPath = .github/styles
|
||||
|
||||
[*.{md,mdx}]
|
||||
BasedOnStyles = DocsGPT
|
||||
33
.vscode/launch.json
vendored
33
.vscode/launch.json
vendored
@@ -2,11 +2,15 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Frontend Debug (npm)",
|
||||
"type": "node-terminal",
|
||||
"name": "Docker Debug Frontend",
|
||||
"request": "launch",
|
||||
"command": "npm run dev",
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
"type": "chrome",
|
||||
"preLaunchTask": "docker-compose: debug:frontend",
|
||||
"url": "http://127.0.0.1:5173",
|
||||
"webRoot": "${workspaceFolder}/frontend",
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Flask Debugger",
|
||||
@@ -45,27 +49,6 @@
|
||||
"--pool=solo"
|
||||
],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"name": "Dev Containers (Mongo + Redis)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
"name": "DocsGPT: Full Stack",
|
||||
"configurations": [
|
||||
"Frontend Debug (npm)",
|
||||
"Flask Debugger",
|
||||
"Celery Debugger"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "DocsGPT",
|
||||
"order": 1
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
21
.vscode/tasks.json
vendored
Normal file
21
.vscode/tasks.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "docker-compose",
|
||||
"label": "docker-compose: debug:frontend",
|
||||
"dockerCompose": {
|
||||
"up": {
|
||||
"detached": true,
|
||||
"services": [
|
||||
"frontend"
|
||||
],
|
||||
"build": true
|
||||
},
|
||||
"files": [
|
||||
"${workspaceFolder}/docker-compose.yaml"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -147,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
Thank you for considering contributing to DocsGPT! 🙏
|
||||
|
||||
## Questions/collaboration
|
||||
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
# Thank you so much for considering to contributing DocsGPT!🙏
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
|
||||
|
||||
## 📜 Here's How to Contribute:
|
||||
```text
|
||||
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||
|
||||
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||
They can be a completely separate repos.
|
||||
For example:
|
||||
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||
https://github.com/arc53/DocsGPT-cli
|
||||
|
||||
Non-Code Contributions:
|
||||
|
||||
📚 Wiki: Improve our documentation, create a guide.
|
||||
|
||||
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||
```
|
||||
|
||||
### 📝 Guidelines for Pull Requests:
|
||||
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt design later into the October.
|
||||
26
README.md
26
README.md
@@ -16,26 +16,16 @@
|
||||
<a href="https://github.com/arc53/DocsGPT"></a>
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/vN7YFfdMpj"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
|
||||
<br>
|
||||
<br>
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
@@ -65,11 +55,9 @@
|
||||
- [x] Agent optimisations (May 2025)
|
||||
- [x] Filesystem sources update (July 2025)
|
||||
- [x] Json Responses (August 2025)
|
||||
- [x] MCP support (August 2025)
|
||||
- [x] Google Drive integration (September 2025)
|
||||
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [ ] SharePoint integration (October 2025)
|
||||
- [ ] Deep Agents (October 2025)
|
||||
- [ ] Sharepoint integration (August 2025)
|
||||
- [ ] MCP support (August 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for tools and sources (August 2025)
|
||||
- [ ] Agent scheduling
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
@@ -118,7 +106,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentCreator:
|
||||
@@ -16,5 +13,4 @@ class AgentCreator:
|
||||
agent_class = cls.agents.get(type.lower())
|
||||
if not agent_class:
|
||||
raise ValueError(f"No agent class found for type {type}")
|
||||
|
||||
return agent_class(*args, **kwargs)
|
||||
|
||||
@@ -12,6 +12,7 @@ from application.core.settings import settings
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,29 +22,23 @@ class BaseAgent(ABC):
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
gpt_model: str,
|
||||
api_key: str,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
retrieved_docs: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
limited_token_mode: Optional[bool] = False,
|
||||
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
limited_request_mode: Optional[bool] = False,
|
||||
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
compressed_summary: Optional[str] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.gpt_model = gpt_model
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
self.user: str = decoded_token.get("sub")
|
||||
self.tool_config: Dict = {}
|
||||
self.tools: List[Dict] = []
|
||||
self.tool_calls: List[Dict] = []
|
||||
@@ -53,31 +48,22 @@ class BaseAgent(ABC):
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
)
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = json_schema
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
self.request_limit = request_limit
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
|
||||
) -> Generator[Dict, None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
yield from self._gen_inner(query, retriever, log_context)
|
||||
|
||||
@abstractmethod
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
@@ -154,31 +140,29 @@ class BaseAgent(ABC):
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
|
||||
# Check if parsing failed
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||
logger.error(error_message)
|
||||
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": getattr(call, "name", "unknown"),
|
||||
"action_name": getattr(call, 'name', 'unknown'),
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
# Check if tool_id exists in available tools
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
|
||||
|
||||
# Return error result
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
@@ -189,6 +173,7 @@ class BaseAgent(ABC):
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Tool with ID {tool_id} not found.", call_id
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
@@ -228,26 +213,18 @@ class BaseAgent(ABC):
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
|
||||
# Prepare tool_config and add tool_id for memory tools
|
||||
|
||||
if tool_data["name"] == "api_tool":
|
||||
tool_config = {
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=tool_config,
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
@@ -280,83 +257,20 @@ class BaseAgent(ABC):
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in current context (messages).
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
from application.api.answer.services.compression.token_counter import (
|
||||
TokenCounter,
|
||||
)
|
||||
|
||||
return TokenCounter.count_message_tokens(messages)
|
||||
|
||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
||||
"""
|
||||
Check if we're approaching context limit (80%).
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
|
||||
Returns:
|
||||
True if at or above 80% of context limit
|
||||
"""
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
|
||||
try:
|
||||
# Calculate current tokens
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
|
||||
# Get context limit for model
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
|
||||
# Calculate threshold (80%)
|
||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
|
||||
# Check if we've reached the limit
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
retrieved_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
"""Build messages using pre-rendered system prompt"""
|
||||
# Append compression summary to system prompt if present
|
||||
if self.compressed_summary:
|
||||
compression_context = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{self.compressed_summary}"
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "assistant", "content": i["response"]})
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
@@ -376,17 +290,29 @@ class BaseAgent(ABC):
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
return messages_combine
|
||||
|
||||
def _retriever_search(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
query: str,
|
||||
log_context: Optional[LogContext] = None,
|
||||
) -> List[Dict]:
|
||||
retrieved_data = retriever.search(query)
|
||||
if log_context:
|
||||
data = build_stack_data(retriever, exclude_attributes=["llm"])
|
||||
log_context.stacks.append({"component": "retriever", "data": data})
|
||||
return retrieved_data
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
gen_kwargs = {"model": self.gpt_model, "messages": messages}
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
@@ -394,6 +320,7 @@ class BaseAgent(ABC):
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
@@ -407,6 +334,7 @@ class BaseAgent(ABC):
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
|
||||
@@ -1,20 +1,32 @@
|
||||
import logging
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
"""A simplified agent with clear execution flow"""
|
||||
"""A simplified agent with clear execution flow.
|
||||
|
||||
Usage:
|
||||
1. Processes a query through retrieval
|
||||
2. Sets up available tools
|
||||
3. Generates responses using LLM
|
||||
4. Handles tool interactions if needed
|
||||
5. Returns standardized outputs
|
||||
|
||||
Easy to extend by overriding specific steps.
|
||||
"""
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Core generator function for ClassicAgent execution flow"""
|
||||
# Step 1: Retrieve relevant data
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
# Step 2: Prepare tools
|
||||
tools_dict = (
|
||||
self._get_user_tools(self.user)
|
||||
if not self.user_api_key
|
||||
@@ -22,16 +34,20 @@ class ClassicAgent(BaseAgent):
|
||||
)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
# Step 3: Build and process messages
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
# Step 4: Handle the response
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
# Step 5: Return metadata
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# Log tool calls for debugging
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
@@ -1,238 +1,229 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Generator, List
|
||||
from typing import Dict, Generator, List, Any
|
||||
import logging
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import build_stack_data, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
|
||||
) as f:
|
||||
PLANNING_PROMPT_TEMPLATE = f.read()
|
||||
planning_prompt_template = f.read()
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
|
||||
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
FINAL_PROMPT_TEMPLATE = f.read()
|
||||
|
||||
final_prompt_template = f.read()
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
"""
|
||||
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
|
||||
|
||||
Implements a think-act-observe loop for complex problem-solving:
|
||||
1. Creates a strategic plan based on the query
|
||||
2. Executes tools and gathers observations
|
||||
3. Iteratively refines approach until satisfied
|
||||
4. Synthesizes final answer from all observations
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.plan: str = ""
|
||||
self.observations: List[str] = []
|
||||
|
||||
def _extract_content_from_llm_response(self, resp: Any) -> str:
|
||||
"""
|
||||
Helper to extract string content from various LLM response types.
|
||||
Handles strings, message objects (OpenAI-like), and streams.
|
||||
Adapt stream handling for your specific LLM client if not OpenAI.
|
||||
"""
|
||||
collected_content = []
|
||||
if isinstance(resp, str):
|
||||
collected_content.append(resp)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
collected_content.append(resp.message.content)
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices") and resp.choices and
|
||||
hasattr(resp.choices[0], "message") and
|
||||
hasattr(resp.choices[0].message, "content") and
|
||||
resp.choices[0].message.content is not None
|
||||
):
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
|
||||
hasattr(resp.content[0], "text")
|
||||
):
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
else:
|
||||
# Assume resp is a stream if not a recognized object
|
||||
try:
|
||||
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
content_piece = ""
|
||||
# OpenAI-like stream
|
||||
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
|
||||
hasattr(chunk.choices[0], 'delta') and \
|
||||
hasattr(chunk.choices[0].delta, 'content') and \
|
||||
chunk.choices[0].delta.content is not None:
|
||||
content_piece = chunk.choices[0].delta.content
|
||||
# Anthropic-like stream (ContentBlockDelta)
|
||||
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
|
||||
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
content_piece = chunk
|
||||
|
||||
if content_piece:
|
||||
collected_content.append(content_piece)
|
||||
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
|
||||
|
||||
|
||||
return "".join(collected_content)
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
|
||||
|
||||
self._reset_state()
|
||||
|
||||
tools_dict = (
|
||||
self._get_tools(self.user_api_key)
|
||||
if self.user_api_key
|
||||
else self._get_user_tools(self.user)
|
||||
)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
|
||||
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
|
||||
|
||||
yield from self._planning_phase(query, log_context)
|
||||
|
||||
if not self.plan:
|
||||
logger.warning(
|
||||
f"ReActAgent: No plan generated in iteration {iteration}"
|
||||
)
|
||||
break
|
||||
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
|
||||
|
||||
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
|
||||
|
||||
if satisfied:
|
||||
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
|
||||
break
|
||||
yield from self._synthesis_phase(query, log_context)
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset agent state for new query"""
|
||||
# Reset state for this generation call
|
||||
self.plan = ""
|
||||
self.observations = []
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
def _planning_phase(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Generate strategic plan for query"""
|
||||
logger.info("ReActAgent: Creating plan...")
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
plan_prompt = self._build_planning_prompt(query)
|
||||
messages = [{"role": "user", "content": plan_prompt}]
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
iterating_reasoning = 0
|
||||
while iterating_reasoning < MAX_ITERATIONS_REASONING:
|
||||
iterating_reasoning += 1
|
||||
# 1. Create Plan
|
||||
logger.info("ReActAgent: Creating plan...")
|
||||
plan_stream = self._create_plan(query, docs_together, log_context)
|
||||
current_plan_parts = []
|
||||
yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
|
||||
for line_chunk in plan_stream:
|
||||
current_plan_parts.append(line_chunk)
|
||||
yield {"thought": line_chunk}
|
||||
self.plan = "".join(current_plan_parts)
|
||||
if self.plan:
|
||||
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
|
||||
|
||||
plan_stream = self.llm.gen_stream(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
|
||||
max_obs_len = 20000
|
||||
obs_str = "\n".join(self.observations)
|
||||
if len(obs_str) > max_obs_len:
|
||||
obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
|
||||
execution_prompt_str = (
|
||||
(self.prompt or "")
|
||||
+ f"\n\nFollow this plan:\n{self.plan}"
|
||||
+ f"\n\nObservations:\n{obs_str}"
|
||||
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
|
||||
)
|
||||
|
||||
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
|
||||
|
||||
resp_from_llm_gen = self._llm_gen(messages, log_context)
|
||||
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
|
||||
if initial_llm_thought_content:
|
||||
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
|
||||
else:
|
||||
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
|
||||
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
|
||||
|
||||
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
|
||||
observation_string = (
|
||||
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
|
||||
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
|
||||
)
|
||||
self.observations.append(observation_string)
|
||||
|
||||
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
|
||||
if content_after_handler:
|
||||
self.observations.append(f"Response after tool execution: {content_after_handler}")
|
||||
else:
|
||||
logger.info("ReActAgent: LLM response after handler had no textual content.")
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
|
||||
display_tool_calls = []
|
||||
for tc in self.tool_calls:
|
||||
cleaned_tc = tc.copy()
|
||||
if len(str(cleaned_tc.get("result", ""))) > 50:
|
||||
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
|
||||
display_tool_calls.append(cleaned_tc)
|
||||
if display_tool_calls:
|
||||
yield {"tool_calls": display_tool_calls}
|
||||
|
||||
if "SATISFIED" in content_after_handler:
|
||||
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
|
||||
break
|
||||
|
||||
# 3. Create Final Answer based on all observations
|
||||
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
|
||||
for answer_chunk in final_answer_stream:
|
||||
yield {"answer": answer_chunk}
|
||||
logger.info("ReActAgent: Finished generating final answer.")
|
||||
|
||||
def _create_plan(
|
||||
self, query: str, docs_data: str, log_context: LogContext = None
|
||||
) -> Generator[str, None, None]:
|
||||
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
|
||||
if "{summaries}" in plan_prompt_filled:
|
||||
summaries = docs_data if docs_data else "No documents retrieved."
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
|
||||
|
||||
messages = [{"role": "user", "content": plan_prompt_filled}]
|
||||
|
||||
plan_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
|
||||
)
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "planning_llm", "data": build_stack_data(self.llm)}
|
||||
)
|
||||
plan_parts = []
|
||||
for chunk in plan_stream:
|
||||
content = self._extract_content(chunk)
|
||||
if content:
|
||||
plan_parts.append(content)
|
||||
yield {"thought": content}
|
||||
self.plan = "".join(plan_parts)
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "planning_llm", "data": data})
|
||||
|
||||
def _execution_phase(
|
||||
self, query: str, tools_dict: Dict, log_context: LogContext
|
||||
) -> Generator[bool, None, None]:
|
||||
"""Execute plan with tool calls and observations"""
|
||||
execution_prompt = self._build_execution_prompt(query)
|
||||
messages = self._build_messages(execution_prompt, query)
|
||||
for chunk in plan_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
initial_content = self._extract_content(llm_response)
|
||||
def _create_final_answer(
|
||||
self, query: str, observations: List[str], log_context: LogContext = None
|
||||
) -> Generator[str, None, None]:
|
||||
observation_string = "\n".join(observations)
|
||||
max_obs_len = 10000
|
||||
if len(observation_string) > max_obs_len:
|
||||
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
|
||||
|
||||
if initial_content:
|
||||
self.observations.append(f"Initial response: {initial_content}")
|
||||
processed_response = self._llm_handler(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
final_answer_prompt_filled = final_prompt_template.format(
|
||||
query=query, observations=observation_string
|
||||
)
|
||||
|
||||
for tool_call in self.tool_calls:
|
||||
observation = (
|
||||
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
|
||||
f"with args {tool_call.get('arguments', {})}. "
|
||||
f"Result: {str(tool_call.get('result', ''))[:200]}"
|
||||
)
|
||||
self.observations.append(observation)
|
||||
final_content = self._extract_content(processed_response)
|
||||
if final_content:
|
||||
self.observations.append(f"Response after tools: {final_content}")
|
||||
messages = [{"role": "user", "content": final_answer_prompt_filled}]
|
||||
|
||||
# Final answer should synthesize, not call tools.
|
||||
final_answer_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=None
|
||||
)
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{
|
||||
"component": "agent_tool_calls",
|
||||
"data": {"tool_calls": self.tool_calls.copy()},
|
||||
}
|
||||
)
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "final_answer_llm", "data": data})
|
||||
|
||||
return "SATISFIED" in (final_content or "")
|
||||
|
||||
def _synthesis_phase(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Synthesize final answer from all observations"""
|
||||
logger.info("ReActAgent: Generating final answer...")
|
||||
|
||||
final_prompt = self._build_final_answer_prompt(query)
|
||||
messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
final_stream = self.llm.gen_stream(
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
)
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
|
||||
)
|
||||
for chunk in final_stream:
|
||||
content = self._extract_content(chunk)
|
||||
if content:
|
||||
yield {"answer": content}
|
||||
|
||||
def _build_planning_prompt(self, query: str) -> str:
|
||||
"""Build planning phase prompt"""
|
||||
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
|
||||
prompt = prompt.replace("{prompt}", self.prompt or "")
|
||||
prompt = prompt.replace("{summaries}", "")
|
||||
prompt = prompt.replace("{observations}", "\n".join(self.observations))
|
||||
return prompt
|
||||
|
||||
def _build_execution_prompt(self, query: str) -> str:
|
||||
"""Build execution phase prompt with plan and observations"""
|
||||
observations_str = "\n".join(self.observations)
|
||||
|
||||
if len(observations_str) > 20000:
|
||||
observations_str = observations_str[:20000] + "\n...[truncated]"
|
||||
return (
|
||||
f"{self.prompt or ''}\n\n"
|
||||
f"Follow this plan:\n{self.plan}\n\n"
|
||||
f"Observations:\n{observations_str}\n\n"
|
||||
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
|
||||
f"Otherwise, continue executing the plan."
|
||||
)
|
||||
|
||||
def _build_final_answer_prompt(self, query: str) -> str:
|
||||
"""Build final synthesis prompt"""
|
||||
observations_str = "\n".join(self.observations)
|
||||
|
||||
if len(observations_str) > 10000:
|
||||
observations_str = observations_str[:10000] + "\n...[truncated]"
|
||||
logger.warning("ReActAgent: Observations truncated for final answer")
|
||||
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
|
||||
|
||||
def _extract_content(self, response: Any) -> str:
|
||||
"""Extract text content from various LLM response formats"""
|
||||
if not response:
|
||||
return ""
|
||||
collected = []
|
||||
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||
if response.message.content:
|
||||
return response.message.content
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
if hasattr(response.choices[0], "message"):
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content
|
||||
if hasattr(response, "content") and isinstance(response.content, list):
|
||||
if response.content and hasattr(response.content[0], "text"):
|
||||
return response.content[0].text
|
||||
try:
|
||||
for chunk in response:
|
||||
content_piece = ""
|
||||
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
if hasattr(chunk.choices[0], "delta"):
|
||||
delta_content = chunk.choices[0].delta.content
|
||||
if delta_content:
|
||||
content_piece = delta_content
|
||||
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str):
|
||||
content_piece = chunk
|
||||
if content_piece:
|
||||
collected.append(content_piece)
|
||||
except (TypeError, AttributeError):
|
||||
logger.debug(
|
||||
f"Response not iterable or unexpected format: {type(response)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content: {e}")
|
||||
return "".join(collected)
|
||||
for chunk in final_answer_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
@@ -1,861 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the MCP Tool with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary containing MCP server configuration:
|
||||
- server_url: URL of the remote MCP server
|
||||
- transport_type: Transport type (auto, sse, http, stdio)
|
||||
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
|
||||
- encrypted_credentials: Encrypted credentials (if available)
|
||||
- timeout: Request timeout in seconds (default: 30)
|
||||
- headers: Custom headers for requests
|
||||
- command: Command for STDIO transport
|
||||
- args: Arguments for STDIO transport
|
||||
- oauth_scopes: OAuth scopes for oauth auth type
|
||||
- oauth_client_name: OAuth client name for oauth auth type
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
self.custom_headers = config.get("headers", {})
|
||||
|
||||
self.auth_credentials = {}
|
||||
if config.get("encrypted_credentials") and user_id:
|
||||
self.auth_credentials = decrypt_credentials(
|
||||
config["encrypted_credentials"], user_id
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
self._client = None
|
||||
|
||||
# Only validate and setup if server_url is provided and not OAuth
|
||||
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||
elif self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
auth_key = f"basic:{username}"
|
||||
else:
|
||||
auth_key = "none"
|
||||
return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
||||
|
||||
def _setup_client(self):
|
||||
"""Setup FastMCP client with proper transport and authentication."""
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
cached_data = _mcp_clients_cache[self._cache_key]
|
||||
if time.time() - cached_data["created_at"] < 1800:
|
||||
self._client = cached_data["client"]
|
||||
return
|
||||
else:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
transport = self._create_transport()
|
||||
auth = None
|
||||
|
||||
if self.auth_type == "oauth":
|
||||
redis_client = get_redis_instance()
|
||||
auth = DocsGPTOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
if token:
|
||||
auth = BearerAuth(token)
|
||||
self._client = Client(transport, auth=auth)
|
||||
_mcp_clients_cache[self._cache_key] = {
|
||||
"client": self._client,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
def _create_transport(self):
|
||||
"""Create appropriate transport based on configuration."""
|
||||
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
|
||||
headers.update(self.custom_headers)
|
||||
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
headers[header_name] = api_key
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
credentials = base64.b64encode(
|
||||
f"{username}:{password}".encode()
|
||||
).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
if self.transport_type == "auto":
|
||||
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
|
||||
transport_type = "sse"
|
||||
else:
|
||||
transport_type = "http"
|
||||
else:
|
||||
transport_type = self.transport_type
|
||||
if transport_type == "sse":
|
||||
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
||||
return SSETransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "http":
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "stdio":
|
||||
command = self.config.get("command", "python")
|
||||
args = self.config.get("args", [])
|
||||
env = self.auth_credentials if self.auth_credentials else None
|
||||
return StdioTransport(command=command, args=args, env=env)
|
||||
else:
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
|
||||
def _format_tools(self, tools_response) -> List[Dict]:
|
||||
"""Format tools response to match expected format."""
|
||||
if hasattr(tools_response, "tools"):
|
||||
tools = tools_response.tools
|
||||
elif isinstance(tools_response, list):
|
||||
tools = tools_response
|
||||
else:
|
||||
tools = []
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name"):
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
}
|
||||
if hasattr(tool, "inputSchema"):
|
||||
tool_dict["inputSchema"] = tool.inputSchema
|
||||
tools_dict.append(tool_dict)
|
||||
elif isinstance(tool, dict):
|
||||
tools_dict.append(tool)
|
||||
else:
|
||||
if hasattr(tool, "model_dump"):
|
||||
tools_dict.append(tool.model_dump())
|
||||
else:
|
||||
tools_dict.append({"name": str(tool), "description": ""})
|
||||
return tools_dict
|
||||
|
||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||
"""Execute operation with FastMCP client."""
|
||||
if not self._client:
|
||||
raise Exception("FastMCP client not initialized")
|
||||
async with self._client:
|
||||
if operation == "ping":
|
||||
return await self._client.ping()
|
||||
elif operation == "list_tools":
|
||||
tools_response = await self._client.list_tools()
|
||||
self.available_tools = self._format_tools(tools_response)
|
||||
return self.available_tools
|
||||
elif operation == "call_tool":
|
||||
tool_name = args[0]
|
||||
tool_args = kwargs
|
||||
return await self._client.call_tool(tool_name, tool_args)
|
||||
elif operation == "list_resources":
|
||||
return await self._client.list_resources()
|
||||
elif operation == "list_prompts":
|
||||
return await self._client.list_prompts()
|
||||
else:
|
||||
raise Exception(f"Unknown operation: {operation}")
|
||||
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||
"""Run async operation in sync context."""
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result(timeout=self.timeout)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
print(f"Error occurred while running async operation: {e}")
|
||||
raise
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
Discover available tools from the MCP server using FastMCP.
|
||||
|
||||
Returns:
|
||||
List of tool definitions from the server
|
||||
"""
|
||||
if not self.server_url:
|
||||
return []
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
tools = self._run_async_operation("list_tools")
|
||||
self.available_tools = tools
|
||||
return self.available_tools
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||
"""
|
||||
Execute an action on the remote MCP server using FastMCP.
|
||||
|
||||
Args:
|
||||
action_name: Name of the action to execute
|
||||
**kwargs: Parameters for the action
|
||||
|
||||
Returns:
|
||||
Result from the MCP server
|
||||
"""
|
||||
if not self.server_url:
|
||||
raise Exception("No MCP server configured")
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
cleaned_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if value == "" or value is None:
|
||||
continue
|
||||
cleaned_kwargs[key] = value
|
||||
try:
|
||||
result = self._run_async_operation(
|
||||
"call_tool", action_name, **cleaned_kwargs
|
||||
)
|
||||
return self._format_result(result)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
||||
|
||||
def _format_result(self, result) -> Dict:
|
||||
"""Format FastMCP result to match expected format."""
|
||||
if hasattr(result, "content"):
|
||||
content_list = []
|
||||
for content_item in result.content:
|
||||
if hasattr(content_item, "text"):
|
||||
content_list.append({"type": "text", "text": content_item.text})
|
||||
elif hasattr(content_item, "data"):
|
||||
content_list.append({"type": "data", "data": content_item.data})
|
||||
else:
|
||||
content_list.append(
|
||||
{"type": "unknown", "content": str(content_item)}
|
||||
)
|
||||
return {
|
||||
"content": content_list,
|
||||
"isError": getattr(result, "isError", False),
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
"""
|
||||
Test the connection to the MCP server and validate functionality.
|
||||
|
||||
Returns:
|
||||
Dictionary with connection test results including tool count
|
||||
"""
|
||||
if not self.server_url:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "No MCP server URL configured",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": "ConfigurationError",
|
||||
}
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
return self._test_regular_connection()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def _test_regular_connection(self) -> Dict:
|
||||
"""Test connection for non-OAuth auth types."""
|
||||
try:
|
||||
self._run_async_operation("ping")
|
||||
ping_success = True
|
||||
except Exception:
|
||||
ping_success = False
|
||||
tools = self.discover_tools()
|
||||
|
||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||
if not ping_success:
|
||||
message += " (Ping not supported, but tool discovery worked)"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": ping_success,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools],
|
||||
}
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
"""Test connection for OAuth auth type with proper async handling."""
|
||||
try:
|
||||
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
|
||||
if not task:
|
||||
raise Exception("Failed to start OAuth authentication")
|
||||
return {
|
||||
"success": True,
|
||||
"requires_oauth": True,
|
||||
"task_id": task.id,
|
||||
"status": "pending",
|
||||
"message": "OAuth flow started",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"OAuth connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
|
||||
Returns:
|
||||
List of action metadata dictionaries
|
||||
"""
|
||||
actions = []
|
||||
for tool in self.available_tools:
|
||||
input_schema = (
|
||||
tool.get("inputSchema")
|
||||
or tool.get("input_schema")
|
||||
or tool.get("schema")
|
||||
or tool.get("parameters")
|
||||
)
|
||||
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if input_schema:
|
||||
if isinstance(input_schema, dict):
|
||||
if "properties" in input_schema:
|
||||
parameters_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
"properties": input_schema.get("properties", {}),
|
||||
"required": input_schema.get("required", []),
|
||||
}
|
||||
|
||||
for key in ["additionalProperties", "description"]:
|
||||
if key in input_schema:
|
||||
parameters_schema[key] = input_schema[key]
|
||||
else:
|
||||
parameters_schema["properties"] = input_schema
|
||||
action = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": parameters_schema,
|
||||
}
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
"""Get configuration requirements for the MCP tool."""
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
|
||||
"required": True,
|
||||
},
|
||||
"transport_type": {
|
||||
"type": "string",
|
||||
"description": "Transport type for connection",
|
||||
"enum": ["auto", "sse", "http", "stdio"],
|
||||
"default": "auto",
|
||||
"required": False,
|
||||
"help": {
|
||||
"auto": "Automatically detect best transport",
|
||||
"sse": "Server-Sent Events (for real-time streaming)",
|
||||
"http": "HTTP streaming (recommended for production)",
|
||||
"stdio": "Standard I/O (for local servers)",
|
||||
},
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"description": "Authentication type",
|
||||
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
"help": {
|
||||
"none": "No authentication",
|
||||
"bearer": "Bearer token authentication",
|
||||
"oauth": "OAuth 2.1 authentication (with frontend integration)",
|
||||
"api_key": "API key authentication",
|
||||
"basic": "Basic authentication",
|
||||
},
|
||||
},
|
||||
"auth_credentials": {
|
||||
"type": "object",
|
||||
"description": "Authentication credentials (varies by auth_type)",
|
||||
"required": False,
|
||||
"properties": {
|
||||
"bearer_token": {
|
||||
"type": "string",
|
||||
"description": "Bearer token for bearer auth",
|
||||
},
|
||||
"access_token": {
|
||||
"type": "string",
|
||||
"description": "Access token for OAuth (if pre-obtained)",
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"description": "API key for api_key auth",
|
||||
},
|
||||
"api_key_header": {
|
||||
"type": "string",
|
||||
"description": "Header name for API key (default: X-API-Key)",
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"description": "Username for basic auth",
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"description": "Password for basic auth",
|
||||
},
|
||||
},
|
||||
},
|
||||
"oauth_scopes": {
|
||||
"type": "array",
|
||||
"description": "OAuth scopes to request (for oauth auth_type)",
|
||||
"items": {"type": "string"},
|
||||
"required": False,
|
||||
"default": [],
|
||||
},
|
||||
"oauth_client_name": {
|
||||
"type": "string",
|
||||
"description": "Client name for OAuth registration (for oauth auth_type)",
|
||||
"default": "DocsGPT-MCP",
|
||||
"required": False,
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Custom headers to send with requests",
|
||||
"required": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30,
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"required": False,
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Command to run for STDIO transport (e.g., 'python')",
|
||||
"required": False,
|
||||
},
|
||||
"args": {
|
||||
"type": "array",
|
||||
"description": "Arguments for STDIO command",
|
||||
"items": {"type": "string"},
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DocsGPTOAuth(OAuthClientProvider):
|
||||
"""
|
||||
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_url: str,
|
||||
redirect_uri: str,
|
||||
redis_client: Redis | None = None,
|
||||
redis_prefix: str = "mcp_oauth:",
|
||||
task_id: str = None,
|
||||
scopes: str | list[str] | None = None,
|
||||
client_name: str = "DocsGPT-MCP",
|
||||
user_id=None,
|
||||
db=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize custom OAuth client provider for DocsGPT.
|
||||
|
||||
Args:
|
||||
mcp_url: Full URL to the MCP endpoint
|
||||
redirect_uri: Custom redirect URI for DocsGPT frontend
|
||||
redis_client: Redis client for storing auth state
|
||||
redis_prefix: Prefix for Redis keys
|
||||
task_id: Task ID for tracking auth status
|
||||
scopes: OAuth scopes to request
|
||||
client_name: Name for this client during registration
|
||||
user_id: User ID for token storage
|
||||
db: Database instance for token storage
|
||||
additional_client_metadata: Extra fields for OAuthClientMetadata
|
||||
"""
|
||||
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
self.db = db
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
if isinstance(scopes, list):
|
||||
scopes = " ".join(scopes)
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name=client_name,
|
||||
redirect_uris=[AnyHttpUrl(redirect_uri)],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope=scopes,
|
||||
**(additional_client_metadata or {}),
|
||||
)
|
||||
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
server_url=self.server_base_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=self.redirect_handler,
|
||||
callback_handler=self.callback_handler,
|
||||
)
|
||||
|
||||
self.auth_url = None
|
||||
self.extracted_state = None
|
||||
|
||||
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
|
||||
"""Process authorization URL to extract state"""
|
||||
try:
|
||||
parsed_url = urlparse(authorization_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
state_params = query_params.get("state", [])
|
||||
if state_params:
|
||||
state = state_params[0]
|
||||
else:
|
||||
raise ValueError("No state in auth URL")
|
||||
return authorization_url, state
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to process auth URL: {e}")
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
"""Store auth URL and state in Redis for frontend to use."""
|
||||
auth_url, state = self._process_auth_url(authorization_url)
|
||||
logging.info(
|
||||
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
|
||||
)
|
||||
self.auth_url = auth_url
|
||||
self.extracted_state = state
|
||||
|
||||
if self.redis_client and self.extracted_state:
|
||||
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "OAuth authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
if not self.redis_client or not self.extracted_state:
|
||||
raise Exception("Redis client or state not configured for OAuth")
|
||||
poll_interval = 1
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for OAuth callback...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
if code_data:
|
||||
code = code_data.decode()
|
||||
returned_state = self.extracted_state
|
||||
|
||||
self.redis_client.delete(code_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "OAuth callback received, completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
if error_data:
|
||||
error_msg = error_data.decode()
|
||||
self.redis_client.delete(error_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
raise Exception(f"OAuth error: {error_msg}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
|
||||
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
|
||||
raise Exception("OAuth callback timeout: no code received within 5 minutes")
|
||||
|
||||
|
||||
class DBTokenStorage(TokenStorage):
|
||||
def __init__(self, server_url: str, user_id: str, db_client):
|
||||
self.server_url = server_url
|
||||
self.user_id = user_id
|
||||
self.db_client = db_client
|
||||
self.collection = db_client["connector_sessions"]
|
||||
|
||||
@staticmethod
|
||||
def get_base_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def get_db_key(self) -> dict:
|
||||
return {
|
||||
"server_url": self.get_base_url(self.server_url),
|
||||
"user_id": self.user_id,
|
||||
}
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "tokens" not in doc:
|
||||
return None
|
||||
try:
|
||||
tokens = OAuthToken.model_validate(doc["tokens"])
|
||||
return tokens
|
||||
except ValidationError as e:
|
||||
logging.error(f"Could not load tokens: {e}")
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"tokens": tokens.model_dump()}},
|
||||
True,
|
||||
)
|
||||
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "client_info" not in doc:
|
||||
return None
|
||||
try:
|
||||
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
|
||||
tokens = await self.get_tokens()
|
||||
if tokens is None:
|
||||
logging.debug(
|
||||
"No tokens found, clearing client info to force fresh registration."
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$unset": {"client_info": ""}},
|
||||
)
|
||||
return None
|
||||
return client_info
|
||||
except ValidationError as e:
|
||||
logging.error(f"Could not load client info: {e}")
|
||||
return None
|
||||
|
||||
def _serialize_client_info(self, info: dict) -> dict:
|
||||
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
|
||||
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
|
||||
return info
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
serialized_info = self._serialize_client_info(client_info.model_dump())
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"client_info": serialized_info}},
|
||||
True,
|
||||
)
|
||||
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
|
||||
|
||||
async def clear(self) -> None:
|
||||
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
|
||||
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
|
||||
|
||||
@classmethod
|
||||
async def clear_all(cls, db_client) -> None:
|
||||
collection = db_client["connector_sessions"]
|
||||
await asyncio.to_thread(collection.delete_many, {})
|
||||
logging.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
class MCPOAuthManager:
|
||||
"""Manager for handling MCP OAuth callbacks."""
|
||||
|
||||
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
|
||||
def handle_oauth_callback(
|
||||
self, state: str, code: str, error: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
Args:
|
||||
state: The state parameter from OAuth callback
|
||||
code: The authorization code from OAuth callback
|
||||
error: Error message if OAuth failed
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.redis_client or not state:
|
||||
raise Exception("Redis client or state not provided")
|
||||
if error:
|
||||
error_key = f"{self.redis_prefix}error:{state}"
|
||||
self.redis_client.setex(error_key, 300, error)
|
||||
raise Exception(f"OAuth error received: {error}")
|
||||
code_key = f"{self.redis_prefix}code:{state}"
|
||||
self.redis_client.setex(code_key, 300, code)
|
||||
|
||||
state_key = f"{self.redis_prefix}state:{state}"
|
||||
self.redis_client.setex(state_key, 300, "completed")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Error handling OAuth callback: {e}")
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get current status of OAuth flow using provided task_id."""
|
||||
if not task_id:
|
||||
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
return mcp_oauth_status_task(task_id)
|
||||
@@ -1,546 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
@@ -1,199 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note. Supports viewing, overwriting, string replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated notes
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _get_note(self) -> str:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True, # ✅ create if missing
|
||||
)
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||
@@ -1,321 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
"""Todo List
|
||||
|
||||
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated todos
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["todos"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of list, create, get, update, complete, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
if action_name == "list":
|
||||
return self._list()
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(kwargs.get("title", ""))
|
||||
|
||||
if action_name == "get":
|
||||
return self._get(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "update":
|
||||
return self._update(
|
||||
kwargs.get("todo_id"),
|
||||
kwargs.get("title", "")
|
||||
)
|
||||
|
||||
if action_name == "complete":
|
||||
return self._complete(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("todo_id"))
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "list",
|
||||
"description": "List all todos for the user.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create a new todo item.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title of the todo item."
|
||||
}
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get",
|
||||
"description": "Get a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to retrieve."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "update",
|
||||
"description": "Update a todo's title by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to update."
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The new title for the todo."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id", "title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "complete",
|
||||
"description": "Mark a todo as completed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to mark as completed."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to delete."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||
"""Convert todo identifiers to sequential integers."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, int):
|
||||
return value if value > 0 else None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped.isdigit():
|
||||
numeric_value = int(stripped)
|
||||
return numeric_value if numeric_value > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
def _get_next_todo_id(self) -> int:
|
||||
"""Get the next sequential todo_id for this user and tool.
|
||||
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
# Find all todos for this user/tool and get their IDs
|
||||
todos = list(self.collection.find(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"todo_id": 1}
|
||||
))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
for todo in todos:
|
||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||
if todo_id is not None:
|
||||
max_id = max(max_id, todo_id)
|
||||
|
||||
return max_id + 1
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
todos = list(cursor)
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
|
||||
result_lines = ["Todos:"]
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item."""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
now = datetime.now()
|
||||
todo_id = self._get_next_todo_id()
|
||||
|
||||
doc = {
|
||||
"todo_id": todo_id,
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"title": title,
|
||||
"status": "open",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
self.collection.insert_one(doc)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
"""Get a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
return result
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Mark a todo as completed."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Delete a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if result.deleted_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
@@ -20,24 +20,20 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
@@ -46,23 +42,19 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = call.arguments
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing Google LLM call: {e}")
|
||||
return None, None, None
|
||||
|
||||
@@ -23,23 +23,16 @@ class ToolManager:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
@@ -54,14 +54,6 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -77,31 +69,19 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
processor.initialize()
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs(
|
||||
data.get("question", "")
|
||||
)
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
agent = processor.create_agent(
|
||||
docs_together=docs_together,
|
||||
docs=docs_list,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
|
||||
@@ -3,20 +3,15 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask import Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.utils import check_required_fields, get_gpt_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,9 +25,8 @@ class BaseAnswerResource:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
@@ -46,104 +40,11 @@ class BaseAnswerResource:
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
Args:
|
||||
agent_config: The config dict of agent instance
|
||||
|
||||
Returns:
|
||||
None or Response if either of limits exceeded.
|
||||
|
||||
"""
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
limited_token_mode = (
|
||||
limited_token_mode_raw
|
||||
if isinstance(limited_token_mode_raw, bool)
|
||||
else limited_token_mode_raw == "True"
|
||||
)
|
||||
limited_request_mode = (
|
||||
limited_request_mode_raw
|
||||
if isinstance(limited_request_mode_raw, bool)
|
||||
else limited_request_mode_raw == "True"
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
end_date = datetime.datetime.now()
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
request_exceeded = (
|
||||
limited_request_mode
|
||||
and request_limit > 0
|
||||
and daily_request_usage >= request_limit
|
||||
)
|
||||
|
||||
if token_exceeded or request_exceeded:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later.",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
retriever: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
@@ -154,7 +55,6 @@ class BaseAnswerResource:
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -173,8 +73,6 @@ class BaseAnswerResource:
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
@@ -185,7 +83,7 @@ class BaseAnswerResource:
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
|
||||
for line in agent.gen(query=question):
|
||||
for line in agent.gen(query=question, retriever=retriever):
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
@@ -212,8 +110,6 @@ class BaseAnswerResource:
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
@@ -221,6 +117,7 @@ class BaseAnswerResource:
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -230,22 +127,15 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -257,7 +147,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
self.gpt_model,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -266,32 +156,13 @@ class BaseAnswerResource:
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
@@ -300,6 +171,7 @@ class BaseAnswerResource:
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
@@ -307,71 +179,18 @@ class BaseAnswerResource:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
# End of stream
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata (partial stream): {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
@@ -415,7 +234,7 @@ class BaseAnswerResource:
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"], None
|
||||
return None, None, None, None, event["error"]
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -423,7 +242,8 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
@@ -435,6 +255,7 @@ class BaseAnswerResource:
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -57,17 +57,9 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -81,20 +73,14 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
processor.initialize()
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
|
||||
tools_data = processor.pre_fetch_tools()
|
||||
|
||||
agent = processor.create_agent(
|
||||
docs_together=docs_together, docs=docs_list, tools_data=tools_data
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
@@ -105,7 +91,6 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
agent_id=data.get("agent_id"),
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
Compression module for managing conversation context compression.
|
||||
|
||||
"""
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionResult,
|
||||
CompressionMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CompressionOrchestrator",
|
||||
"CompressionService",
|
||||
"CompressionResult",
|
||||
"CompressionMetadata",
|
||||
]
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""Builds message arrays from compressed context."""
|
||||
|
||||
@staticmethod
|
||||
def build_from_compressed_context(
|
||||
system_prompt: str,
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_tool_calls: bool = False,
|
||||
context_type: str = "pre_request",
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Build messages from compressed context.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Compressed summary (if any)
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
context_type: Type of context ('pre_request' or 'mid_execution')
|
||||
|
||||
Returns:
|
||||
List of message dicts ready for LLM
|
||||
"""
|
||||
# Append compression summary to system prompt if present
|
||||
if compressed_summary:
|
||||
system_prompt = MessageBuilder._append_compression_context(
|
||||
system_prompt, compressed_summary, context_type
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add recent history
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
messages.append({"role": "user", "content": query["prompt"]})
|
||||
messages.append({"role": "assistant", "content": query["response"]})
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _append_compression_context(
|
||||
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
|
||||
) -> str:
|
||||
"""
|
||||
Append compression context to system prompt.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Summary to append
|
||||
context_type: Type of compression context
|
||||
|
||||
Returns:
|
||||
Updated system prompt
|
||||
"""
|
||||
# Remove existing compression context if present
|
||||
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
|
||||
parts = system_prompt.split("\n\n---\n\n")
|
||||
system_prompt = parts[0]
|
||||
|
||||
# Build appropriate context message based on type
|
||||
if context_type == "mid_execution":
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"Context window limit reached during execution. "
|
||||
"Previous conversation has been compressed to fit within limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
else: # pre_request
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
|
||||
return system_prompt + context_message
|
||||
|
||||
@staticmethod
|
||||
def rebuild_messages_after_compression(
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Args:
|
||||
messages: Original message list
|
||||
compressed_summary: Compressed summary
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_current_execution: Whether to preserve current execution messages
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
|
||||
Returns:
|
||||
Rebuilt message list or None if failed
|
||||
"""
|
||||
# Find the system message
|
||||
system_message = next(
|
||||
(msg for msg in messages if msg.get("role") == "system"), None
|
||||
)
|
||||
if not system_message:
|
||||
logger.warning("No system message found in messages list")
|
||||
return None
|
||||
|
||||
# Update system message with compressed summary
|
||||
if compressed_summary:
|
||||
content = system_message.get("content", "")
|
||||
system_message["content"] = MessageBuilder._append_compression_context(
|
||||
content, compressed_summary, "mid_execution"
|
||||
)
|
||||
logger.info(
|
||||
"Appended compression summary to system prompt (truncated): %s",
|
||||
(
|
||||
compressed_summary[:500] + "..."
|
||||
if len(compressed_summary) > 500
|
||||
else compressed_summary
|
||||
),
|
||||
)
|
||||
|
||||
rebuilt_messages = [system_message]
|
||||
|
||||
# Add recent history from compressed context
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": query["response"]}
|
||||
)
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
rebuilt_messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
rebuilt_messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
if include_current_execution:
|
||||
# Preserve any messages that were added during the current execution cycle
|
||||
recent_msg_count = 1 # system message
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
recent_msg_count += 2
|
||||
if "tool_calls" in query:
|
||||
recent_msg_count += len(query["tool_calls"]) * 2
|
||||
|
||||
if len(messages) > recent_msg_count:
|
||||
current_execution_messages = messages[recent_msg_count:]
|
||||
rebuilt_messages.extend(current_execution_messages)
|
||||
logger.info(
|
||||
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
|
||||
f"Ready to continue tool execution."
|
||||
)
|
||||
return rebuilt_messages
|
||||
@@ -1,232 +0,0 @@
|
||||
"""High-level compression orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
from application.api.answer.services.compression.types import CompressionResult
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionOrchestrator:
|
||||
"""
|
||||
Facade for compression operations.
|
||||
|
||||
Coordinates between all compression components and provides
|
||||
a simple interface for callers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation_service: ConversationService,
|
||||
threshold_checker: Optional[CompressionThresholdChecker] = None,
|
||||
):
|
||||
"""
|
||||
Initialize orchestrator.
|
||||
|
||||
Args:
|
||||
conversation_service: Service for DB operations
|
||||
threshold_checker: Custom threshold checker (optional)
|
||||
"""
|
||||
self.conversation_service = conversation_service
|
||||
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
|
||||
|
||||
def compress_if_needed(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_query_tokens: int = 500,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Check if compression is needed and perform it if so.
|
||||
|
||||
This is the main entry point for compression operations.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
model_id: Model being used for conversation
|
||||
decoded_token: User's decoded JWT token
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
|
||||
Returns:
|
||||
CompressionResult with summary and recent queries
|
||||
"""
|
||||
try:
|
||||
# Load conversation
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation {conversation_id} not found for user {user_id}"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Check if compression is needed
|
||||
if not self.threshold_checker.should_compress(
|
||||
conversation, model_id, current_query_tokens
|
||||
):
|
||||
# No compression needed, return full history
|
||||
queries = conversation.get("queries", [])
|
||||
return CompressionResult.success_no_compression(queries)
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in compress_if_needed: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def _perform_compression(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform the actual compression operation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Conversation document
|
||||
model_id: Model ID for conversation
|
||||
decoded_token: User token
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Determine which model to use for compression
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else model_id
|
||||
)
|
||||
|
||||
# Get provider and API key for compression model
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
|
||||
# Create compression LLM
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key=api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
)
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=self.conversation_service,
|
||||
)
|
||||
|
||||
# Compress all queries up to the latest
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0:
|
||||
logger.warning("No queries to compress")
|
||||
return CompressionResult.success_no_compression([])
|
||||
|
||||
logger.info(
|
||||
f"Initiating compression for conversation {conversation_id}: "
|
||||
f"compressing all {queries_count} queries (0-{compress_up_to})"
|
||||
)
|
||||
|
||||
# Perform compression and save to DB
|
||||
metadata = compression_service.compress_and_save(
|
||||
conversation_id, conversation, compress_up_to
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Reload conversation with updated metadata
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id=decoded_token.get("sub")
|
||||
)
|
||||
|
||||
# Get compressed context
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
return CompressionResult.success_with_compression(
|
||||
compressed_summary, recent_queries, metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def compress_mid_execution(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_conversation: Optional[Dict[str, Any]] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform compression during tool execution.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
model_id: Model ID
|
||||
decoded_token: User token
|
||||
current_conversation: Pre-loaded conversation (optional)
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Load conversation if not provided
|
||||
if current_conversation:
|
||||
conversation = current_conversation
|
||||
else:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Could not load conversation {conversation_id} for mid-execution compression"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Compression prompt building logic."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionPromptBuilder:
|
||||
"""Builds prompts for LLM compression calls."""
|
||||
|
||||
def __init__(self, version: str = "v1.0"):
|
||||
"""
|
||||
Initialize prompt builder.
|
||||
|
||||
Args:
|
||||
version: Prompt template version to use
|
||||
"""
|
||||
self.version = version
|
||||
self.system_prompt = self._load_prompt(version)
|
||||
|
||||
def _load_prompt(self, version: str) -> str:
|
||||
"""
|
||||
Load prompt template from file.
|
||||
|
||||
Args:
|
||||
version: Version string (e.g., 'v1.0')
|
||||
|
||||
Returns:
|
||||
Prompt template content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If prompt template file doesn't exist
|
||||
"""
|
||||
current_dir = Path(__file__).resolve().parents[4]
|
||||
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
|
||||
|
||||
try:
|
||||
with open(prompt_path, "r") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Compression prompt template not found: {prompt_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Compression prompt template '{version}' not found at {prompt_path}. "
|
||||
f"Please ensure the template file exists."
|
||||
)
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
queries: List[Dict[str, Any]],
|
||||
existing_compressions: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Build messages for compression LLM call.
|
||||
|
||||
Args:
|
||||
queries: List of query objects to compress
|
||||
existing_compressions: List of previous compression points
|
||||
|
||||
Returns:
|
||||
List of message dicts for LLM
|
||||
"""
|
||||
# Build conversation text
|
||||
conversation_text = self._format_conversation(queries)
|
||||
|
||||
# Add existing compression context if present
|
||||
existing_compression_context = ""
|
||||
if existing_compressions and len(existing_compressions) > 0:
|
||||
existing_compression_context = (
|
||||
"\n\nIMPORTANT: This conversation has been compressed before. "
|
||||
"Previous compression summaries:\n\n"
|
||||
)
|
||||
for i, comp in enumerate(existing_compressions):
|
||||
existing_compression_context += (
|
||||
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
|
||||
f"{comp.get('compressed_summary', '')}\n\n"
|
||||
)
|
||||
existing_compression_context += (
|
||||
"Your task is to create a NEW summary that incorporates the context from "
|
||||
"previous compressions AND the new messages below. The final summary should "
|
||||
"be comprehensive and include all important information from both previous "
|
||||
"compressions and new messages.\n\n"
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"{existing_compression_context}"
|
||||
f"Here is the conversation to summarize:\n\n"
|
||||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format conversation queries into readable text for compression.
|
||||
|
||||
Args:
|
||||
queries: List of query objects
|
||||
|
||||
Returns:
|
||||
Formatted conversation text
|
||||
"""
|
||||
conversation_lines = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
conversation_lines.append(f"--- Message {i + 1} ---")
|
||||
conversation_lines.append(f"User: {query.get('prompt', '')}")
|
||||
|
||||
# Add tool calls if present
|
||||
tool_calls = query.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
conversation_lines.append("\nTool Calls:")
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
arguments = tc.get("arguments", {})
|
||||
result = tc.get("result", "")
|
||||
if result is None:
|
||||
result = ""
|
||||
status = tc.get("status", "unknown")
|
||||
|
||||
# Include full tool result for complete compression context
|
||||
conversation_lines.append(
|
||||
f" - {tool_name}.{action_name}({arguments}) "
|
||||
f"[{status}] → {result}"
|
||||
)
|
||||
|
||||
# Add agent thought if present
|
||||
thought = query.get("thought", "")
|
||||
if thought:
|
||||
conversation_lines.append(f"\nAgent Thought: {thought}")
|
||||
|
||||
# Add assistant response
|
||||
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
|
||||
|
||||
# Add sources if present
|
||||
sources = query.get("sources", [])
|
||||
if sources:
|
||||
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
|
||||
|
||||
conversation_lines.append("") # Empty line between messages
|
||||
|
||||
return "\n".join(conversation_lines)
|
||||
@@ -1,306 +0,0 @@
|
||||
"""Core compression service with simplified responsibilities."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.api.answer.services.compression.prompt_builder import (
|
||||
CompressionPromptBuilder,
|
||||
)
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionService:
|
||||
"""
|
||||
Service for compressing conversation history.
|
||||
|
||||
Handles DB updates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm,
|
||||
model_id: str,
|
||||
conversation_service=None,
|
||||
prompt_builder: Optional[CompressionPromptBuilder] = None,
|
||||
):
|
||||
"""
|
||||
Initialize compression service.
|
||||
|
||||
Args:
|
||||
llm: LLM instance to use for compression
|
||||
model_id: Model ID for compression
|
||||
conversation_service: Service for DB operations (optional, for DB updates)
|
||||
prompt_builder: Custom prompt builder (optional)
|
||||
"""
|
||||
self.llm = llm
|
||||
self.model_id = model_id
|
||||
self.conversation_service = conversation_service
|
||||
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
|
||||
version=settings.COMPRESSION_PROMPT_VERSION
|
||||
)
|
||||
|
||||
def compress_conversation(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation history up to specified index.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include in compression
|
||||
|
||||
Returns:
|
||||
CompressionMetadata with compression details
|
||||
|
||||
Raises:
|
||||
ValueError: If compress_up_to_index is invalid
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
|
||||
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
|
||||
raise ValueError(
|
||||
f"Invalid compress_up_to_index: {compress_up_to_index} "
|
||||
f"(conversation has {len(queries)} queries)"
|
||||
)
|
||||
|
||||
# Get queries to compress
|
||||
queries_to_compress = queries[: compress_up_to_index + 1]
|
||||
|
||||
# Check if there are existing compressions
|
||||
existing_compressions = conversation.get("compression_metadata", {}).get(
|
||||
"compression_points", []
|
||||
)
|
||||
|
||||
if existing_compressions:
|
||||
logger.info(
|
||||
f"Found {len(existing_compressions)} previous compression(s) - "
|
||||
f"will incorporate into new summary"
|
||||
)
|
||||
|
||||
# Calculate original token count
|
||||
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
|
||||
|
||||
# Log tool call stats
|
||||
self._log_tool_call_stats(queries_to_compress)
|
||||
|
||||
# Build compression prompt
|
||||
messages = self.prompt_builder.build_prompt(
|
||||
queries_to_compress, existing_compressions
|
||||
)
|
||||
|
||||
# Call LLM to generate compression
|
||||
logger.info(
|
||||
f"Starting compression: {len(queries_to_compress)} queries "
|
||||
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
|
||||
f"using model {self.model_id}"
|
||||
)
|
||||
|
||||
response = self.llm.gen(
|
||||
model=self.model_id, messages=messages, max_tokens=4000
|
||||
)
|
||||
|
||||
# Extract summary from response
|
||||
compressed_summary = self._extract_summary(response)
|
||||
|
||||
# Calculate compressed token count
|
||||
compressed_tokens = TokenCounter.count_message_tokens(
|
||||
[{"content": compressed_summary}]
|
||||
)
|
||||
|
||||
# Calculate compression ratio
|
||||
compression_ratio = (
|
||||
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
|
||||
f"({compression_ratio:.1f}x compression)"
|
||||
)
|
||||
|
||||
# Build compression metadata
|
||||
compression_metadata = CompressionMetadata(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
query_index=compress_up_to_index,
|
||||
compressed_summary=compressed_summary,
|
||||
original_token_count=original_tokens,
|
||||
compressed_token_count=compressed_tokens,
|
||||
compression_ratio=compression_ratio,
|
||||
model_used=self.model_id,
|
||||
compression_prompt_version=self.prompt_builder.version,
|
||||
)
|
||||
|
||||
return compression_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def compress_and_save(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation and save to database.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include
|
||||
|
||||
Returns:
|
||||
CompressionMetadata
|
||||
|
||||
Raises:
|
||||
ValueError: If conversation_service not provided or invalid index
|
||||
"""
|
||||
if not self.conversation_service:
|
||||
raise ValueError(
|
||||
"conversation_service required for compress_and_save operation"
|
||||
)
|
||||
|
||||
# Perform compression
|
||||
metadata = self.compress_conversation(conversation, compress_up_to_index)
|
||||
|
||||
# Save to database
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, metadata.to_dict()
|
||||
)
|
||||
|
||||
logger.info(f"Compression metadata saved to database for {conversation_id}")
|
||||
|
||||
return metadata
|
||||
|
||||
def get_compressed_context(
|
||||
self, conversation: Dict[str, Any]
|
||||
) -> tuple[Optional[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get compressed summary + recent uncompressed messages.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
|
||||
Returns:
|
||||
(compressed_summary, recent_messages)
|
||||
"""
|
||||
try:
|
||||
compression_metadata = conversation.get("compression_metadata", {})
|
||||
|
||||
if not compression_metadata.get("is_compressed"):
|
||||
logger.debug("No compression metadata found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
compression_points = compression_metadata.get("compression_points", [])
|
||||
|
||||
if not compression_points:
|
||||
logger.debug("No compression points found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
# Get the most recent compression point
|
||||
latest_compression = compression_points[-1]
|
||||
compressed_summary = latest_compression.get("compressed_summary")
|
||||
last_compressed_index = latest_compression.get("query_index")
|
||||
compressed_tokens = latest_compression.get("compressed_token_count", 0)
|
||||
original_tokens = latest_compression.get("original_token_count", 0)
|
||||
|
||||
# Get only messages after compression point
|
||||
queries = conversation.get("queries", [])
|
||||
total_queries = len(queries)
|
||||
recent_queries = queries[last_compressed_index + 1 :]
|
||||
|
||||
logger.info(
|
||||
f"Using compressed context: summary ({compressed_tokens} tokens, "
|
||||
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
|
||||
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
|
||||
)
|
||||
|
||||
return compressed_summary, recent_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compressed context: {str(e)}", exc_info=True
|
||||
)
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
def _extract_summary(self, llm_response: str) -> str:
|
||||
"""
|
||||
Extract clean summary from LLM response.
|
||||
|
||||
Args:
|
||||
llm_response: Raw LLM response
|
||||
|
||||
Returns:
|
||||
Cleaned summary text
|
||||
"""
|
||||
try:
|
||||
# Try to extract content within <summary> tags
|
||||
summary_match = re.search(
|
||||
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
|
||||
)
|
||||
|
||||
if summary_match:
|
||||
summary = summary_match.group(1).strip()
|
||||
else:
|
||||
# If no summary tags, remove analysis tags and use the rest
|
||||
summary = re.sub(
|
||||
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
|
||||
).strip()
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting summary: {str(e)}, using full response")
|
||||
return llm_response
|
||||
|
||||
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
|
||||
"""Log statistics about tool calls in queries."""
|
||||
total_tool_calls = 0
|
||||
total_tool_result_chars = 0
|
||||
tool_call_breakdown = {}
|
||||
|
||||
for q in queries:
|
||||
for tc in q.get("tool_calls", []):
|
||||
total_tool_calls += 1
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
key = f"{tool_name}.{action_name}"
|
||||
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
|
||||
|
||||
# Track total tool result size
|
||||
result = tc.get("result", "")
|
||||
if result:
|
||||
total_tool_result_chars += len(str(result))
|
||||
|
||||
if total_tool_calls > 0:
|
||||
tool_breakdown_str = ", ".join(
|
||||
f"{tool}({count})"
|
||||
for tool, count in sorted(tool_call_breakdown.items())
|
||||
)
|
||||
tool_result_kb = total_tool_result_chars / 1024
|
||||
logger.info(
|
||||
f"Tool call breakdown: {tool_breakdown_str} "
|
||||
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
|
||||
)
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Compression threshold checking logic."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionThresholdChecker:
|
||||
"""Determines if compression is needed based on token thresholds."""
|
||||
|
||||
def __init__(self, threshold_percentage: float = None):
|
||||
"""
|
||||
Initialize threshold checker.
|
||||
|
||||
Args:
|
||||
threshold_percentage: Percentage of context to use as threshold
|
||||
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
"""
|
||||
self.threshold_percentage = (
|
||||
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
current_query_tokens: int = 500,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if compression is needed.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
model_id: Target model for this request
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
|
||||
Returns:
|
||||
True if tokens >= threshold% of context window
|
||||
"""
|
||||
try:
|
||||
# Calculate total tokens in conversation
|
||||
total_tokens = TokenCounter.count_conversation_tokens(conversation)
|
||||
total_tokens += current_query_tokens
|
||||
|
||||
# Get context window limit for model
|
||||
context_limit = get_token_limit(model_id)
|
||||
|
||||
# Calculate threshold
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
compression_needed = total_tokens >= threshold
|
||||
percentage_used = (total_tokens / context_limit) * 100
|
||||
|
||||
if compression_needed:
|
||||
logger.warning(
|
||||
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Compression check: {total_tokens}/{context_limit} tokens "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
|
||||
)
|
||||
|
||||
return compression_needed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def check_message_tokens(self, messages: list, model_id: str) -> bool:
|
||||
"""
|
||||
Check if message list exceeds threshold.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
model_id: Target model
|
||||
|
||||
Returns:
|
||||
True if at or above threshold
|
||||
"""
|
||||
try:
|
||||
current_tokens = TokenCounter.count_message_tokens(messages)
|
||||
context_limit = get_token_limit(model_id)
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
|
||||
return False
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Token counting utilities for compression."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'content' field
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens across multiple query objects.
|
||||
|
||||
Args:
|
||||
queries: List of query objects from conversation
|
||||
include_tool_calls: Whether to count tool call tokens
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
|
||||
for query in queries:
|
||||
# Count prompt and response tokens
|
||||
if "prompt" in query:
|
||||
total_tokens += num_tokens_from_string(query["prompt"])
|
||||
if "response" in query:
|
||||
total_tokens += num_tokens_from_string(query["response"])
|
||||
if "thought" in query:
|
||||
total_tokens += num_tokens_from_string(query.get("thought", ""))
|
||||
|
||||
# Count tool call tokens
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
tool_call_string = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
total_tokens += num_tokens_from_string(tool_call_string)
|
||||
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_conversation_tokens(
|
||||
conversation: Dict[str, Any], include_system_prompt: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
Calculate total tokens in a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation document
|
||||
include_system_prompt: Whether to include system prompt in count
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
total_tokens = TokenCounter.count_query_tokens(queries)
|
||||
|
||||
# Add system prompt tokens if requested
|
||||
if include_system_prompt:
|
||||
# Rough estimate for system prompt
|
||||
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
|
||||
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating conversation tokens: {str(e)}")
|
||||
return 0
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Type definitions for compression module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionMetadata:
|
||||
"""Metadata about a compression operation."""
|
||||
|
||||
timestamp: datetime
|
||||
query_index: int
|
||||
compressed_summary: str
|
||||
original_token_count: int
|
||||
compressed_token_count: int
|
||||
compression_ratio: float
|
||||
model_used: str
|
||||
compression_prompt_version: str
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for DB storage."""
|
||||
return {
|
||||
"timestamp": self.timestamp,
|
||||
"query_index": self.query_index,
|
||||
"compressed_summary": self.compressed_summary,
|
||||
"original_token_count": self.original_token_count,
|
||||
"compressed_token_count": self.compressed_token_count,
|
||||
"compression_ratio": self.compression_ratio,
|
||||
"model_used": self.model_used,
|
||||
"compression_prompt_version": self.compression_prompt_version,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionResult:
|
||||
"""Result of a compression operation."""
|
||||
|
||||
success: bool
|
||||
compressed_summary: Optional[str] = None
|
||||
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
|
||||
metadata: Optional[CompressionMetadata] = None
|
||||
error: Optional[str] = None
|
||||
compression_performed: bool = False
|
||||
|
||||
@classmethod
|
||||
def success_with_compression(
|
||||
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
|
||||
) -> "CompressionResult":
|
||||
"""Create a successful result with compression."""
|
||||
return cls(
|
||||
success=True,
|
||||
compressed_summary=summary,
|
||||
recent_queries=queries,
|
||||
metadata=metadata,
|
||||
compression_performed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
|
||||
"""Create a successful result without compression needed."""
|
||||
return cls(
|
||||
success=True,
|
||||
recent_queries=queries,
|
||||
compression_performed=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def failure(cls, error: str) -> "CompressionResult":
|
||||
"""Create a failure result."""
|
||||
return cls(success=False, error=error, compression_performed=False)
|
||||
|
||||
def as_history(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Convert recent queries to history format.
|
||||
|
||||
Returns:
|
||||
List of prompt/response dicts
|
||||
"""
|
||||
return [
|
||||
{"prompt": q["prompt"], "response": q["response"]}
|
||||
for q in self.recent_queries
|
||||
]
|
||||
@@ -52,7 +52,7 @@ class ConversationService:
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
model_id: str,
|
||||
gpt_model: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
@@ -66,7 +66,7 @@ class ConversationService:
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
@@ -90,7 +90,6 @@ class ConversationService:
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -121,7 +120,6 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -135,9 +133,10 @@ class ConversationService:
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the user query",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -148,7 +147,7 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=30
|
||||
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
conversation_data = {
|
||||
@@ -164,7 +163,6 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -180,103 +178,3 @@ class ConversationService:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update conversation with compression metadata.
|
||||
|
||||
Uses $push with $slice to keep only the most recent compression points,
|
||||
preventing unbounded array growth. Since each compression incorporates
|
||||
previous compressions, older points become redundant.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
compression_metadata: Compression point data
|
||||
"""
|
||||
try:
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$set": {
|
||||
"compression_metadata.is_compressed": True,
|
||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
||||
"timestamp"
|
||||
),
|
||||
},
|
||||
"$push": {
|
||||
"compression_metadata.compression_points": {
|
||||
"$each": [compression_metadata],
|
||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def append_compression_message(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Append a synthetic compression summary entry into the conversation history.
|
||||
This makes the summary visible in the DB alongside normal queries.
|
||||
"""
|
||||
try:
|
||||
summary = compression_metadata.get("compressed_summary", "")
|
||||
if not summary:
|
||||
return
|
||||
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
|
||||
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"timestamp": timestamp,
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
def get_compression_metadata(
|
||||
self, conversation_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get compression metadata for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
|
||||
Returns:
|
||||
Compression metadata dict or None
|
||||
"""
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
|
||||
)
|
||||
return conversation.get("compression_metadata") if conversation else None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptRenderer:
|
||||
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||
|
||||
def __init__(self):
|
||||
self.template_engine = TemplateEngine()
|
||||
self.namespace_manager = NamespaceManager()
|
||||
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_content: str,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[list] = None,
|
||||
docs_together: Optional[str] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Render prompt with full context from all namespaces.
|
||||
|
||||
Args:
|
||||
prompt_content: Raw prompt template string
|
||||
user_id: Current user identifier
|
||||
request_id: Unique request identifier
|
||||
passthrough_data: Parameters from web request
|
||||
docs: RAG retrieved documents
|
||||
docs_together: Concatenated document content
|
||||
tools_data: Pre-fetched tool results organized by tool name
|
||||
**kwargs: Additional parameters for namespace builders
|
||||
|
||||
Returns:
|
||||
Rendered prompt string with all variables substituted
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template rendering fails
|
||||
"""
|
||||
if not prompt_content:
|
||||
return ""
|
||||
|
||||
uses_template = self._uses_template_syntax(prompt_content)
|
||||
|
||||
if not uses_template:
|
||||
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||
|
||||
try:
|
||||
context = self.namespace_manager.build_context(
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.template_engine.render(prompt_content, context)
|
||||
except TemplateRenderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||
"""Check if prompt uses Jinja2 template syntax"""
|
||||
return "{{" in prompt_content and "}}" in prompt_content
|
||||
|
||||
def _apply_legacy_substitutions(
|
||||
self, prompt_content: str, docs_together: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Apply backward-compatible substitutions for old prompt format.
|
||||
|
||||
Handles legacy {summaries} and {query} placeholders during transition period.
|
||||
"""
|
||||
if docs_together:
|
||||
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||
return prompt_content
|
||||
|
||||
def validate_template(self, prompt_content: str) -> bool:
|
||||
"""Validate prompt template syntax"""
|
||||
return self.template_engine.validate_template(prompt_content)
|
||||
|
||||
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||
"""Extract all variable names from prompt template"""
|
||||
return self.template_engine.extract_variables(prompt_content)
|
||||
@@ -3,30 +3,18 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
limit_chat_history,
|
||||
)
|
||||
from application.utils import get_gpt_model, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,31 +69,22 @@ class StreamProcessor:
|
||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||
)
|
||||
self.conversation_id = self.data.get("conversation_id")
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
self.source = (
|
||||
{"active_docs": self.data["active_docs"]}
|
||||
if "active_docs" in self.data
|
||||
else {}
|
||||
)
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.retrieved_docs = []
|
||||
self.agent_config = {}
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.model_id: Optional[str] = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
)
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
self.compressed_summary: Optional[str] = None
|
||||
self.compressed_summary_tokens: int = 0
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._validate_and_set_model()
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
@@ -119,71 +98,14 @@ class StreamProcessor:
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Check if compression is enabled and needed
|
||||
if settings.ENABLE_CONVERSATION_COMPRESSION:
|
||||
self._handle_compression(conversation)
|
||||
else:
|
||||
# Original behavior - load all history
|
||||
self.history = [
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _handle_compression(self, conversation: Dict[str, Any]):
|
||||
"""
|
||||
Handle conversation compression logic using orchestrator.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
"""
|
||||
try:
|
||||
# Use orchestrator to handle all compression logic
|
||||
result = self.compression_orchestrator.compress_if_needed(
|
||||
conversation_id=self.conversation_id,
|
||||
user_id=self.initial_user_id,
|
||||
model_id=self.model_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
logger.error(
|
||||
f"Compression failed: {result.error}, using full history"
|
||||
)
|
||||
self.history = [
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
return
|
||||
|
||||
# Set compressed summary if compression was performed
|
||||
if result.compression_performed and result.compressed_summary:
|
||||
self.compressed_summary = result.compressed_summary
|
||||
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
|
||||
[{"content": result.compressed_summary}]
|
||||
)
|
||||
logger.info(
|
||||
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
|
||||
f"+ {len(result.recent_queries)} recent messages"
|
||||
)
|
||||
|
||||
# Build history from recent queries
|
||||
self.history = result.as_history()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error handling compression, falling back to standard history: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Fallback to original behavior
|
||||
self.history = [
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
"""Process any attachments in the request"""
|
||||
@@ -213,25 +135,6 @@ class StreamProcessor:
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
@@ -268,77 +171,13 @@ class StreamProcessor:
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
# Handle multiple sources
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
"id": agent_data["source"],
|
||||
"retriever": agent_data.get("retriever", "classic"),
|
||||
}
|
||||
]
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _configure_agent(self):
|
||||
"""Configure the agent based on request data"""
|
||||
agent_id = self.data.get("agent_id")
|
||||
@@ -364,13 +203,7 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
self.agent_config.update(
|
||||
@@ -391,13 +224,7 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
else:
|
||||
self.agent_config.update(
|
||||
{
|
||||
@@ -409,336 +236,42 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
model_id=self.model_id, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
"""Configure the retriever based on request data"""
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"history_token_limit": history_token_limit,
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
}
|
||||
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
"""Create and return the configured agent"""
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chat_history=self.history,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
def create_retriever(self):
|
||||
"""Create and return the configured retriever"""
|
||||
return RetrieverCreator.create_retriever(
|
||||
self.retriever_config["retriever_name"],
|
||||
source=self.source,
|
||||
chat_history=self.history,
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
model_id=self.model_id,
|
||||
token_limit=self.retriever_config["token_limit"],
|
||||
gpt_model=self.gpt_model,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
|
||||
"""Pre-fetch documents for template rendering before agent creation"""
|
||||
if self.data.get("isNoneDoc", False):
|
||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||
return None, None
|
||||
try:
|
||||
retriever = self.create_retriever()
|
||||
logger.info(
|
||||
f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
|
||||
)
|
||||
docs = retriever.search(question)
|
||||
logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
|
||||
|
||||
if not docs:
|
||||
logger.info("Pre-fetch: No documents returned from search")
|
||||
return None, None
|
||||
self.retrieved_docs = docs
|
||||
|
||||
docs_with_filenames = []
|
||||
for doc in docs:
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if filename:
|
||||
chunk_header = str(filename)
|
||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||
else:
|
||||
docs_with_filenames.append(doc["text"])
|
||||
docs_together = "\n\n".join(docs_with_filenames)
|
||||
|
||||
logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
|
||||
|
||||
return docs_together, docs
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
|
||||
"""Pre-fetch tool data for template rendering before agent creation
|
||||
|
||||
Can be controlled via:
|
||||
1. Global setting: ENABLE_TOOL_PREFETCH in .env
|
||||
2. Per-request: disable_tool_prefetch in request data
|
||||
"""
|
||||
if not settings.ENABLE_TOOL_PREFETCH:
|
||||
logger.info(
|
||||
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.data.get("disable_tool_prefetch", False):
|
||||
logger.info("Tool pre-fetching disabled for this request")
|
||||
return None
|
||||
|
||||
required_tool_actions = self._get_required_tool_actions()
|
||||
filtering_enabled = required_tool_actions is not None
|
||||
|
||||
try:
|
||||
user_tools_collection = self.db["user_tools"]
|
||||
user_id = self.initial_user_id or "local"
|
||||
|
||||
user_tools = list(
|
||||
user_tools_collection.find({"user": user_id, "status": True})
|
||||
)
|
||||
|
||||
if not user_tools:
|
||||
return None
|
||||
|
||||
tools_data = {}
|
||||
|
||||
for tool_doc in user_tools:
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_id = str(tool_doc.get("_id"))
|
||||
|
||||
if filtering_enabled:
|
||||
required_actions_by_name = required_tool_actions.get(
|
||||
tool_name, set()
|
||||
)
|
||||
required_actions_by_id = required_tool_actions.get(tool_id, set())
|
||||
|
||||
required_actions = required_actions_by_name | required_actions_by_id
|
||||
|
||||
if not required_actions:
|
||||
continue
|
||||
else:
|
||||
required_actions = None
|
||||
|
||||
tool_data = self._fetch_tool_data(tool_doc, required_actions)
|
||||
if tool_data:
|
||||
tools_data[tool_name] = tool_data
|
||||
tools_data[tool_id] = tool_data
|
||||
|
||||
return tools_data if tools_data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
|
||||
return None
|
||||
|
||||
def _fetch_tool_data(
|
||||
self,
|
||||
tool_doc: Dict[str, Any],
|
||||
required_actions: Optional[Set[Optional[str]]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch and execute tool actions with saved parameters"""
|
||||
try:
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_config = tool_doc.get("config", {}).copy()
|
||||
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||
|
||||
tool_manager = ToolManager(config={tool_name: tool_config})
|
||||
user_id = self.initial_user_id or "local"
|
||||
tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
|
||||
|
||||
if not tool:
|
||||
logger.debug(f"Tool '{tool_name}' failed to load")
|
||||
return None
|
||||
|
||||
tool_actions = tool.get_actions_metadata()
|
||||
if not tool_actions:
|
||||
logger.debug(f"Tool '{tool_name}' has no actions")
|
||||
return None
|
||||
|
||||
saved_actions = tool_doc.get("actions", [])
|
||||
|
||||
include_all_actions = required_actions is None or (
|
||||
required_actions and None in required_actions
|
||||
)
|
||||
allowed_actions: Set[str] = (
|
||||
{action for action in required_actions if isinstance(action, str)}
|
||||
if required_actions
|
||||
else set()
|
||||
)
|
||||
|
||||
action_results = {}
|
||||
for action_meta in tool_actions:
|
||||
action_name = action_meta.get("name")
|
||||
if action_name is None:
|
||||
continue
|
||||
if (
|
||||
not include_all_actions
|
||||
and allowed_actions
|
||||
and action_name not in allowed_actions
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
saved_action = None
|
||||
for sa in saved_actions:
|
||||
if sa.get("name") == action_name:
|
||||
saved_action = sa
|
||||
break
|
||||
|
||||
action_params = action_meta.get("parameters", {})
|
||||
properties = action_params.get("properties", {})
|
||||
|
||||
kwargs = {}
|
||||
for param_name, param_spec in properties.items():
|
||||
if saved_action:
|
||||
saved_props = saved_action.get("parameters", {}).get(
|
||||
"properties", {}
|
||||
)
|
||||
if param_name in saved_props:
|
||||
param_value = saved_props[param_name].get("value")
|
||||
if param_value is not None:
|
||||
kwargs[param_name] = param_value
|
||||
continue
|
||||
|
||||
if param_name in tool_config:
|
||||
kwargs[param_name] = tool_config[param_name]
|
||||
elif "default" in param_spec:
|
||||
kwargs[param_name] = param_spec["default"]
|
||||
|
||||
result = tool.execute_action(action_name, **kwargs)
|
||||
action_results[action_name] = result
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Action '{action_name}' execution failed: {type(e).__name__}"
|
||||
)
|
||||
continue
|
||||
|
||||
return action_results if action_results else None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
|
||||
return None
|
||||
|
||||
def _get_prompt_content(self) -> Optional[str]:
|
||||
"""Retrieve and cache the raw prompt content for the current agent configuration."""
|
||||
if self._prompt_content is not None:
|
||||
return self._prompt_content
|
||||
prompt_id = (
|
||||
self.agent_config.get("prompt_id")
|
||||
if isinstance(self.agent_config, dict)
|
||||
else None
|
||||
)
|
||||
if not prompt_id:
|
||||
return None
|
||||
try:
|
||||
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
|
||||
self._prompt_content = None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
|
||||
self._prompt_content = None
|
||||
return self._prompt_content
|
||||
|
||||
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
|
||||
"""Determine which tool actions are referenced in the prompt template"""
|
||||
if self._required_tool_actions is not None:
|
||||
return self._required_tool_actions
|
||||
|
||||
prompt_content = self._get_prompt_content()
|
||||
if prompt_content is None:
|
||||
return None
|
||||
|
||||
if "{{" not in prompt_content or "}}" not in prompt_content:
|
||||
self._required_tool_actions = {}
|
||||
return self._required_tool_actions
|
||||
|
||||
try:
|
||||
from application.templates.template_engine import TemplateEngine
|
||||
|
||||
template_engine = TemplateEngine()
|
||||
usages = template_engine.extract_tool_usages(prompt_content)
|
||||
self._required_tool_actions = usages
|
||||
return self._required_tool_actions
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
|
||||
self._required_tool_actions = {}
|
||||
return self._required_tool_actions
|
||||
|
||||
def _fetch_memory_tool_data(
|
||||
self, tool_doc: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch memory tool data for pre-injection into prompt"""
|
||||
try:
|
||||
tool_config = tool_doc.get("config", {}).copy()
|
||||
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
memory_tool = MemoryTool(tool_config, self.initial_user_id)
|
||||
|
||||
root_view = memory_tool.execute_action("view", path="/")
|
||||
|
||||
if "Error:" in root_view or not root_view.strip():
|
||||
return None
|
||||
|
||||
return {"root": root_view, "available": True}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||
return None
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
docs_together: Optional[str] = None,
|
||||
docs: Optional[list] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Create and return the configured agent with rendered prompt"""
|
||||
raw_prompt = self._get_prompt_content()
|
||||
if raw_prompt is None:
|
||||
raw_prompt = get_prompt(
|
||||
self.agent_config["prompt_id"], self.prompts_collection
|
||||
)
|
||||
self._prompt_content = raw_prompt
|
||||
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
if self.model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=self.model_id,
|
||||
api_key=system_api_key,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
retrieved_docs=self.retrieved_docs,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
compressed_summary=self.compressed_summary,
|
||||
)
|
||||
|
||||
agent.conversation_id = self.conversation_id
|
||||
agent.initial_user_id = self.initial_user_id
|
||||
|
||||
return agent
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
@@ -15,6 +14,8 @@ from flask import (
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
@@ -23,9 +24,15 @@ from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.utils import (
|
||||
check_required_fields
|
||||
)
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
@@ -37,6 +44,185 @@ api.add_namespace(connectors_ns)
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/upload")
|
||||
class UploadConnector(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source type (google_drive, github, etc.)"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Configuration data"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads connector source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
sync_frequency = config.get("sync_frequency", "never")
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration"
|
||||
}), 400)
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
task = ingest_connector_task.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/task_status")
|
||||
class ConnectorTaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"ConnectorTaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get connector task status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sources")
|
||||
class ConnectorSources(Resource):
|
||||
@api.doc(description="Get connector sources")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
sources = sources_collection.find({"user": user, "type": "connector"}).sort("date", -1)
|
||||
connector_sources = []
|
||||
for source in sources:
|
||||
connector_sources.append({
|
||||
"id": str(source["_id"]),
|
||||
"name": source.get("name"),
|
||||
"date": source.get("date"),
|
||||
"type": source.get("type"),
|
||||
"source": source.get("source"),
|
||||
"tokens": source.get("tokens", ""),
|
||||
"retriever": source.get("retriever", "classic"),
|
||||
"syncFrequency": source.get("sync_frequency", ""),
|
||||
})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(connector_sources), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/delete")
|
||||
class DeleteConnectorSource(Resource):
|
||||
@api.doc(
|
||||
description="Delete a connector source",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "source_id is required"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
class ConnectorAuth(Resource):
|
||||
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||
@@ -49,24 +235,8 @@ class ConnectorAuth(Resource):
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
import uuid
|
||||
state = str(uuid.uuid4())
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
@@ -87,30 +257,25 @@ class ConnectorsCallback(Resource):
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
import uuid
|
||||
|
||||
provider = request.args.get('provider', 'google_drive')
|
||||
authorization_code = request.args.get('code')
|
||||
state = request.args.get('state')
|
||||
_ = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict["provider"]
|
||||
state_object_id = state_dict["object_id"]
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
|
||||
try:
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
@@ -125,44 +290,57 @@ class ConnectorsCallback(Resource):
|
||||
"access_token": token_info.get("access_token"),
|
||||
"refresh_token": token_info.get("refresh_token"),
|
||||
"token_uri": token_info.get("token_uri"),
|
||||
"expiry": token_info.get("expiry")
|
||||
"expiry": token_info.get("expiry"),
|
||||
"scopes": token_info.get("scopes")
|
||||
}
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
|
||||
sessions_collection.insert_one({
|
||||
"session_token": session_token,
|
||||
"user": user_id,
|
||||
"token_info": sanitized_token_info,
|
||||
"created_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
"user_email": user_email,
|
||||
"provider": provider
|
||||
})
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/refresh")
|
||||
class ConnectorRefresh(Resource):
|
||||
@api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
|
||||
@api.doc(description="Refresh connector access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
refresh_token = data.get('refresh_token')
|
||||
|
||||
if not provider or not refresh_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.refresh_access_token(refresh_token)
|
||||
return make_response(jsonify({"success": True, "token_info": token_info}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error refreshing token for connector: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False)
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
@api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
@@ -171,11 +349,10 @@ class ConnectorFiles(Resource):
|
||||
folder_id = data.get('folder_id')
|
||||
limit = data.get('limit', 10)
|
||||
page_token = data.get('page_token')
|
||||
search_query = data.get('search_query')
|
||||
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
@@ -185,17 +362,13 @@ class ConnectorFiles(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
input_config = {
|
||||
documents = loader.load_data({
|
||||
'limit': limit,
|
||||
'list_only': True,
|
||||
'session_token': session_token,
|
||||
'folder_id': folder_id,
|
||||
'page_token': page_token
|
||||
}
|
||||
if search_query:
|
||||
input_config['search_query'] = search_query
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
})
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
@@ -213,20 +386,13 @@ class ConnectorFiles(Resource):
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time,
|
||||
'isFolder': metadata.get('is_folder', False)
|
||||
'modifiedTime': formatted_time
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"next_page_token": next_token,
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}")
|
||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||
@@ -235,7 +401,7 @@ class ConnectorFiles(Resource):
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info and access token")
|
||||
@api.doc(description="Validate connector session token and return user info")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
@@ -244,6 +410,7 @@ class ConnectorValidateSession(Resource):
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
@@ -257,36 +424,10 @@ class ConnectorValidateSession(Resource):
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = {
|
||||
"access_token": refreshed_token_info.get("access_token"),
|
||||
"refresh_token": refreshed_token_info.get("refresh_token"),
|
||||
"token_uri": refreshed_token_info.get("token_uri"),
|
||||
"expiry": refreshed_token_info.get("expiry")
|
||||
}
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||
|
||||
if is_expired:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"expired": True,
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token')
|
||||
"expired": is_expired,
|
||||
"user_email": session.get('user_email', 'Connected User')
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}")
|
||||
@@ -446,23 +587,20 @@ class ConnectorCallbackStatus(Resource):
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = "{status}";
|
||||
const sessionToken = "{session_token}";
|
||||
const userEmail = "{user_email}";
|
||||
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: '{provider}_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
@@ -475,7 +613,7 @@ class ConnectorCallbackStatus(Resource):
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""User API module - provides all user-related API endpoints"""
|
||||
|
||||
from .routes import user
|
||||
|
||||
__all__ = ["user"]
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Agents module."""
|
||||
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,263 +0,0 @@
|
||||
"""Agent management sharing functionality."""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={
|
||||
"token": "Shared token of the agent",
|
||||
},
|
||||
description="Get a shared agent by token or ID",
|
||||
)
|
||||
def get(self):
|
||||
shared_token = request.args.get("token")
|
||||
|
||||
if not shared_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
||||
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
||||
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user")
|
||||
|
||||
if user_id != owner_id:
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
|
||||
class SharedAgents(Resource):
|
||||
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
|
||||
class ShareAgent(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ShareAgentModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="ID of the agent"),
|
||||
"shared": fields.Boolean(
|
||||
required=True, description="Share or unshare the agent"
|
||||
),
|
||||
"username": fields.String(
|
||||
required=False, description="Name of the user"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Share or unshare an agent")
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||
)
|
||||
agent_id = data.get("id")
|
||||
shared = data.get("shared")
|
||||
username = data.get("username", "")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
if shared is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Shared parameter is required and must be true or false",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
@@ -1,119 +0,0 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the agent"},
|
||||
description="Generate webhook URL for the agent",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error generating webhook URL: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||
400,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
if payload is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid or missing JSON data in request body",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Analytics module."""
|
||||
|
||||
from .routes import analytics_ns
|
||||
|
||||
__all__ = ["analytics_ns"]
|
||||
@@ -1,540 +0,0 @@
|
||||
"""Analytics and reporting routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "messages": daily_messages}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting feedback analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
{
|
||||
"page": fields.Integer(
|
||||
required=False,
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
),
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"page_size": fields.Integer(
|
||||
required=False,
|
||||
description="Number of logs per page",
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Attachments module."""
|
||||
|
||||
from .routes import attachments_ns
|
||||
|
||||
__all__ = ["attachments_ns"]
|
||||
@@ -1,198 +0,0 @@
|
||||
"""File attachments and media routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
attachments_ns = Namespace(
|
||||
"attachments", description="File attachments and media operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/store_attachment")
|
||||
class StoreAttachment(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AttachmentModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key (optional)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||
|
||||
files = request.files.getlist("file")
|
||||
if not files:
|
||||
single_file = request.files.get("file")
|
||||
if single_file:
|
||||
files = [single_file]
|
||||
|
||||
if not files or all(f.filename == "" for f in files):
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "Missing file(s)"}),
|
||||
400,
|
||||
)
|
||||
|
||||
user = None
|
||||
if decoded_token:
|
||||
user = safe_filename(decoded_token.get("sub"))
|
||||
elif api_key:
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
user = safe_filename(agent.get("user"))
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
|
||||
try:
|
||||
tasks = []
|
||||
errors = []
|
||||
original_file_count = len(files)
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
tasks.append({
|
||||
"task_id": task.id,
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
})
|
||||
except Exception as file_err:
|
||||
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
|
||||
errors.append({
|
||||
"filename": file.filename,
|
||||
"error": str(file_err)
|
||||
})
|
||||
|
||||
if not tasks:
|
||||
error_msg = "No valid files to upload"
|
||||
if errors:
|
||||
error_msg += f". Errors: {errors}"
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": error_msg, "errors": errors}),
|
||||
400,
|
||||
)
|
||||
|
||||
if original_file_count == 1 and len(tasks) == 1:
|
||||
current_app.logger.info("Returning single task_id response")
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
response_data = {
|
||||
"success": True,
|
||||
"tasks": tasks,
|
||||
"message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
|
||||
}
|
||||
if errors:
|
||||
response_data["errors"] = errors
|
||||
response_data["message"] += f" {len(errors)} file(s) failed."
|
||||
|
||||
return make_response(
|
||||
jsonify(response_data),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@attachments_ns.route("/images/<path:image_path>")
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
try:
|
||||
file_obj = storage.get_file(image_path)
|
||||
extension = image_path.split(".")[-1].lower()
|
||||
content_type = f"image/{extension}"
|
||||
if extension == "jpg":
|
||||
content_type = "image/jpeg"
|
||||
response = make_response(file_obj.read())
|
||||
response.headers.set("Content-Type", content_type)
|
||||
response.headers.set("Cache-Control", "max-age=86400")
|
||||
|
||||
return response
|
||||
except FileNotFoundError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error retrieving image"}), 500
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/tts")
|
||||
class TextToSpeech(Resource):
|
||||
tts_model = api.model(
|
||||
"TextToSpeechModel",
|
||||
{
|
||||
"text": fields.String(
|
||||
required=True, description="Text to be synthesized as audio"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(tts_model)
|
||||
@api.doc(description="Synthesize audio speech from text")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"audio_base64": audio_base64,
|
||||
"lang": detected_language,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1,222 +0,0 @@
|
||||
"""
|
||||
Shared utilities, database connections, and helper functions for user API routes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
|
||||
"""Generate a dictionary with minute-level time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_hourly_range(start_date, end_date):
|
||||
"""Generate a dictionary with hourly time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_date_range(start_date, end_date):
|
||||
"""Generate a dictionary with daily date ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
|
||||
for i in range((end_date - start_date).days + 1)
|
||||
}
|
||||
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document
|
||||
"""
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
return user_doc
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their details.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs
|
||||
|
||||
Returns:
|
||||
List of tool details with id, name, and display_name
|
||||
"""
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("displayName", tool.get("name", "")),
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store for a given source ID.
|
||||
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
|
||||
Returns:
|
||||
Vector store instance
|
||||
"""
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
def handle_image_upload(
|
||||
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
|
||||
) -> Tuple[str, Optional[Response]]:
|
||||
"""
|
||||
Handle image file upload from request.
|
||||
|
||||
Args:
|
||||
request: Flask request object
|
||||
existing_url: Existing image URL (fallback)
|
||||
user: User ID
|
||||
storage: Storage instance
|
||||
base_path: Base path for upload
|
||||
|
||||
Returns:
|
||||
Tuple of (image_url, error_response)
|
||||
"""
|
||||
image_url = existing_url
|
||||
|
||||
if "image" in request.files:
|
||||
file = request.files["image"]
|
||||
if file.filename != "":
|
||||
filename = secure_filename(file.filename)
|
||||
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
|
||||
try:
|
||||
storage.save_file(file, upload_path, storage_class="STANDARD")
|
||||
image_url = upload_path
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error uploading image: {e}")
|
||||
return None, make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}),
|
||||
400,
|
||||
)
|
||||
return image_url, None
|
||||
|
||||
|
||||
def require_agent(func):
|
||||
"""
|
||||
Decorator to require valid agent webhook token.
|
||||
|
||||
Args:
|
||||
func: Function to decorate
|
||||
|
||||
Returns:
|
||||
Wrapped function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Conversation management module."""
|
||||
|
||||
from .routes import conversations_ns
|
||||
|
||||
__all__ = ["conversations_ns"]
|
||||
@@ -1,280 +0,0 @@
|
||||
"""Conversation management routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
"conversations", description="Conversation management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a conversation by ID",
|
||||
params={"id": "The ID of the conversation to delete"},
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_all_conversations")
|
||||
class DeleteAllConversations(Resource):
|
||||
@api.doc(
|
||||
description="Deletes all conversations for a specific user",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_conversations")
|
||||
class GetConversations(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a single conversation by ID",
|
||||
params={"id": "The conversation ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/update_conversation_name")
|
||||
class UpdateConversationName(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateConversationModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Conversation ID"),
|
||||
"name": fields.String(
|
||||
required=True, description="New name of the conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates the name of a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index": fields.Integer(
|
||||
required=True,
|
||||
description="The question number in that particular conversation",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="id of the particular conversation"
|
||||
),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Submit feedback for a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .routes import models_ns
|
||||
|
||||
__all__ = ["models_ns"]
|
||||
@@ -1,25 +0,0 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
@models_ns.route("/models")
|
||||
class ModelsListResource(Resource):
|
||||
def get(self):
|
||||
"""Get list of available models with their capabilities."""
|
||||
try:
|
||||
registry = ModelRegistry.get_instance()
|
||||
models = registry.get_enabled_models()
|
||||
|
||||
response = {
|
||||
"models": [model.to_dict() for model in models],
|
||||
"default_model_id": registry.default_model_id,
|
||||
"count": len(models),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
return make_response(jsonify(response), 200)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Prompts module."""
|
||||
|
||||
from .routes import prompts_ns
|
||||
|
||||
__all__ = ["prompts_ns"]
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Prompt management routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
"prompts", description="Prompt management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@prompts_ns.route("/create_prompt")
|
||||
class CreatePrompt(Resource):
|
||||
create_prompt_model = api.model(
|
||||
"CreatePromptModel",
|
||||
{
|
||||
"content": fields.String(
|
||||
required=True, description="Content of the prompt"
|
||||
),
|
||||
"name": fields.String(required=True, description="Name of the prompt"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_prompt_model)
|
||||
@api.doc(description="Create a new prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_prompts")
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["_id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_prompts), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_single_prompt")
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
prompt_id = request.args.get("id")
|
||||
if not prompt_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
if prompt_id == "default":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_combine_template = f.read()
|
||||
return make_response(jsonify({"content": chat_combine_template}), 200)
|
||||
elif prompt_id == "creative":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_reduce_creative = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_creative}), 200)
|
||||
elif prompt_id == "strict":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/delete_prompt")
|
||||
class DeletePrompt(Resource):
|
||||
delete_prompt_model = api.model(
|
||||
"DeletePromptModel",
|
||||
{"id": fields.String(required=True, description="Prompt ID to delete")},
|
||||
)
|
||||
|
||||
@api.expect(delete_prompt_model)
|
||||
@api.doc(description="Delete a prompt by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/update_prompt")
|
||||
class UpdatePrompt(Resource):
|
||||
update_prompt_model = api.model(
|
||||
"UpdatePromptModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Prompt ID to update"),
|
||||
"name": fields.String(required=True, description="New name of the prompt"),
|
||||
"content": fields.String(
|
||||
required=True, description="New content of the prompt"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_prompt_model)
|
||||
@api.doc(description="Update an existing prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "content"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +0,0 @@
|
||||
"""Sharing module."""
|
||||
|
||||
from .routes import sharing_ns
|
||||
|
||||
__all__ = ["sharing_ns"]
|
||||
@@ -1,289 +0,0 @@
|
||||
"""Conversation sharing routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
"ShareConversationModel",
|
||||
{
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="Conversation ID"
|
||||
),
|
||||
"user": fields.String(description="User ID (optional)"),
|
||||
"prompt_id": fields.String(description="Prompt ID (optional)"),
|
||||
"chunks": fields.Integer(description="Chunks count (optional)"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(share_conversation_model)
|
||||
@api.doc(description="Share a conversation")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
|
||||
if is_promptable is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "isPromptable is required"}), 400
|
||||
)
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
current_n_queries = len(conversation["queries"])
|
||||
explicit_binary = Binary.from_uuid(
|
||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||
)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = data.get("chunks", "2")
|
||||
|
||||
name = conversation["name"] + "(shared)"
|
||||
new_api_key_data = {
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||
if pre_existing_api_document:
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
new_api_key_data["key"] = api_uuid
|
||||
new_api_key_data["name"] = name
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
agents_collection.insert_one(new_api_key_data)
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error sharing conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sharing_ns.route("/shared_conversation/<string:identifier>")
|
||||
class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
try:
|
||||
query_uuid = Binary.from_uuid(
|
||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||
)
|
||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||
conversation_queries = []
|
||||
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
):
|
||||
# conversation_id is now stored as an ObjectId, not a DBRef
|
||||
conversation_id = shared["conversation_id"]
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": conversation_id}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation_queries = conversation["queries"][
|
||||
: (shared["first_n_queries"])
|
||||
]
|
||||
|
||||
for query in conversation_queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
date = conversation["_id"].generation_time.isoformat()
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation["name"],
|
||||
"timestamp": date,
|
||||
}
|
||||
if shared["isPromptable"] and "api_key" in shared:
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting shared conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Sources module."""
|
||||
|
||||
from .chunks import sources_chunks_ns
|
||||
from .routes import sources_ns
|
||||
from .upload import sources_upload_ns
|
||||
|
||||
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]
|
||||
@@ -1,278 +0,0 @@
|
||||
"""Source document management chunk management."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import get_vector_store, sources_collection
|
||||
from application.utils import check_required_fields, num_tokens_from_string
|
||||
|
||||
sources_chunks_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
description="Retrieves chunks from a document, optionally filtered by file path and search term",
|
||||
params={
|
||||
"id": "The document ID",
|
||||
"page": "Page number for pagination",
|
||||
"per_page": "Number of chunks per page",
|
||||
"path": "Optional: Filter chunks by relative file path",
|
||||
"search": "Optional: Search term to filter chunks by title or content",
|
||||
},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = int(request.args.get("per_page", 10))
|
||||
path = request.args.get("path")
|
||||
search_term = request.args.get("search", "").strip().lower()
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
# Check if the chunk's source matches the requested path
|
||||
|
||||
if not chunk_source or not chunk_source.endswith(path):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
|
||||
if not (text_match or title_match):
|
||||
continue
|
||||
filtered_chunks.append(chunk)
|
||||
chunks = filtered_chunks
|
||||
|
||||
total_chunks = len(chunks)
|
||||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
paginated_chunks = chunks[start:end]
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"total": total_chunks,
|
||||
"chunks": paginated_chunks,
|
||||
"path": path if path else None,
|
||||
"search": search_term if search_term else None,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/add_chunk")
|
||||
class AddChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AddChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"text": fields.String(required=True, description="Text of the chunk"),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Adds a new chunk to the document",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "text"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata", {})
|
||||
token_count = num_tokens_from_string(text)
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||
201,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/delete_chunk")
|
||||
class DeleteChunk(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a specific chunk from the document.",
|
||||
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if deleted:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk deleted successfully"}), 200
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk not found or could not be deleted"}),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/update_chunk")
|
||||
class UpdateChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"chunk_id": fields.String(
|
||||
required=True, description="Chunk ID to update"
|
||||
),
|
||||
"text": fields.String(
|
||||
required=False, description="New text of the chunk"
|
||||
),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Updated metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates an existing chunk in the document.",
|
||||
)
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "chunk_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
chunk_id = data.get("chunk_id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata")
|
||||
|
||||
if text is not None:
|
||||
token_count = num_tokens_from_string(text)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
|
||||
chunks = store.get_chunks()
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
if not existing_chunk:
|
||||
return make_response(jsonify({"error": "Chunk not found"}), 404)
|
||||
new_text = text if text is not None else existing_chunk["text"]
|
||||
|
||||
if metadata is not None:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
new_metadata.update(metadata)
|
||||
else:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
if text is not None:
|
||||
new_metadata["token_count"] = num_tokens_from_string(new_text)
|
||||
try:
|
||||
new_chunk_id = store.add_chunk(new_text, new_metadata)
|
||||
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if not deleted:
|
||||
current_app.logger.warning(
|
||||
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"message": "Chunk updated successfully",
|
||||
"chunk_id": new_chunk_id,
|
||||
"original_chunk_id": chunk_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as add_error:
|
||||
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
|
||||
return make_response(
|
||||
jsonify({"error": "Failed to update chunk - addition failed"}), 500
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
@@ -1,323 +0,0 @@
|
||||
"""Source document management routes."""
|
||||
|
||||
import json
|
||||
import math
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
sources_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_ns.route("/sources")
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = [
|
||||
{
|
||||
"name": "Default",
|
||||
"date": "default",
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "remote",
|
||||
"tokens": "",
|
||||
"retriever": "classic",
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["_id"]),
|
||||
"name": index.get("name"),
|
||||
"date": index.get("date"),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": index.get("tokens", ""),
|
||||
"retriever": index.get("retriever", "classic"),
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/paginated")
|
||||
class PaginatedSources(Resource):
|
||||
@api.doc(description="Get document with pagination, sorting and filtering")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare query for filtering
|
||||
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(
|
||||
max(1, page), total_pages
|
||||
) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
|
||||
try:
|
||||
documents = (
|
||||
sources_collection.find(query)
|
||||
.sort(sort_field, sort_order)
|
||||
.skip(skip)
|
||||
.limit(rows_per_page)
|
||||
)
|
||||
|
||||
paginated_docs = []
|
||||
for doc in documents:
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
response = {
|
||||
"total": total_documents,
|
||||
"totalPages": total_pages,
|
||||
"currentPage": page,
|
||||
"paginated": paginated_docs,
|
||||
}
|
||||
return make_response(jsonify(response), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving paginated sources: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
params={"path": "Comma-separated list of IDs"},
|
||||
)
|
||||
def get(self):
|
||||
ids = request.args.get("path")
|
||||
if not ids:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_index(ids=ids)
|
||||
if result:
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
description="Deletes old indexes and associated files",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
storage.delete_file(f"{index_path}/index.faiss")
|
||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||
storage.delete_file(f"{index_path}/index.pkl")
|
||||
else:
|
||||
vectorstore = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
file_path = doc["file_path"]
|
||||
if storage.is_directory(file_path):
|
||||
files = storage.list_files(file_path)
|
||||
for f in files:
|
||||
storage.delete_file(f)
|
||||
else:
|
||||
storage.delete_file(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/combine")
|
||||
class RedirectToSources(Resource):
|
||||
@api.doc(
|
||||
description="Redirects /api/combine to /api/sources for backward compatibility"
|
||||
)
|
||||
def get(self):
|
||||
return redirect("/api/sources", code=301)
|
||||
|
||||
|
||||
@sources_ns.route("/manage_sync")
|
||||
class ManageSync(Resource):
|
||||
manage_sync_model = api.model(
|
||||
"ManageSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID"),
|
||||
"sync_frequency": fields.String(
|
||||
required=True,
|
||||
description="Sync frequency (never, daily, weekly, monthly)",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(manage_sync_model)
|
||||
@api.doc(description="Manage sync frequency for sources")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["source_id", "sync_frequency"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
sync_frequency = data["sync_frequency"]
|
||||
|
||||
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||
)
|
||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||
try:
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating sync frequency: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
description="Get the directory structure for a document",
|
||||
params={"id": "The document ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
try:
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
directory_structure = doc.get("directory_structure", {})
|
||||
base_path = doc.get("file_path", "")
|
||||
|
||||
provider = None
|
||||
remote_data = doc.get("remote_data")
|
||||
try:
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"directory_structure": directory_structure,
|
||||
"base_path": base_path,
|
||||
"provider": provider,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
@@ -1,583 +0,0 @@
|
||||
"""Source document management upload functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields, safe_filename
|
||||
|
||||
|
||||
sources_upload_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/upload")
|
||||
class UploadFile(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Missing required fields or files",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
|
||||
safe_user = safe_filename(user)
|
||||
dir_name = safe_filename(job_name)
|
||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
for file in files:
|
||||
original_filename = file.filename
|
||||
safe_file = safe_filename(original_filename)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||
file.save(temp_file_path)
|
||||
|
||||
if zipfile.is_zipfile(temp_file_path):
|
||||
try:
|
||||
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(path=temp_dir)
|
||||
|
||||
# Walk through extracted files and upload them
|
||||
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for extracted_file in files:
|
||||
if (
|
||||
os.path.join(root, extracted_file)
|
||||
== temp_file_path
|
||||
):
|
||||
continue
|
||||
rel_path = os.path.relpath(
|
||||
os.path.join(root, extracted_file), temp_dir
|
||||
)
|
||||
storage_path = f"{base_path}/{rel_path}"
|
||||
|
||||
with open(
|
||||
os.path.join(root, extracted_file), "rb"
|
||||
) as f:
|
||||
storage.save_file(f, storage_path)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error extracting zip: {e}", exc_info=True
|
||||
)
|
||||
# If zip extraction fails, save the original zip file
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
else:
|
||||
# For non-zip files, save directly
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
[
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
class UploadRemote(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RemoteUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source of the data"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Data to process"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Process file_ids
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
# Process folder_ids
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [
|
||||
id.strip() for id in folder_ids.split(",") if id.strip()
|
||||
]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
class ManageSourceFiles(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ManageSourceFilesModel",
|
||||
{
|
||||
"source_id": fields.String(
|
||||
required=True, description="Source ID to modify"
|
||||
),
|
||||
"operation": fields.String(
|
||||
required=True,
|
||||
description="Operation: 'add', 'remove', or 'remove_directory'",
|
||||
),
|
||||
"file_paths": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="File paths to remove (for remove operation)",
|
||||
),
|
||||
"directory_path": fields.String(
|
||||
required=False,
|
||||
description="Directory path to remove (for remove_directory operation)",
|
||||
),
|
||||
"file": fields.Raw(
|
||||
required=False, description="Files to add (for add operation)"
|
||||
),
|
||||
"parent_dir": fields.String(
|
||||
required=False,
|
||||
description="Parent directory path relative to source root",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Add files, remove files, or remove directories from an existing source",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
if not source_id or not operation:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "source_id and operation are required",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation not in ["add", "remove", "remove_directory"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "operation must be 'add', 'remove', or 'remove_directory'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Source not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
|
||||
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid parent directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation == "add":
|
||||
files = request.files.getlist("file")
|
||||
if not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "No files provided for add operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
added_files = []
|
||||
|
||||
target_dir = source_file_path
|
||||
if parent_dir:
|
||||
target_dir = f"{source_file_path}/{parent_dir}"
|
||||
for file in files:
|
||||
if file.filename:
|
||||
safe_filename_str = safe_filename(file.filename)
|
||||
file_path = f"{target_dir}/{safe_filename_str}"
|
||||
|
||||
# Save file to storage
|
||||
|
||||
storage.save_file(file, file_path)
|
||||
added_files.append(safe_filename_str)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Added {len(added_files)} files",
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths")
|
||||
if not file_paths_str:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "file_paths required for remove operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
file_paths = (
|
||||
json.loads(file_paths_str)
|
||||
if isinstance(file_paths_str, str)
|
||||
else file_paths_str
|
||||
)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid file_paths format"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
|
||||
if storage.file_exists(full_path):
|
||||
storage.delete_file(full_path)
|
||||
removed_files.append(file_path)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path")
|
||||
if not directory_path:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "directory_path required for remove_directory operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate directory path (prevent path traversal)
|
||||
|
||||
if directory_path.startswith("/") or ".." in directory_path:
|
||||
current_app.logger.warning(
|
||||
f"Invalid directory path attempted for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
full_directory_path = (
|
||||
f"{source_file_path}/{directory_path}"
|
||||
if directory_path
|
||||
else source_file_path
|
||||
)
|
||||
|
||||
if not storage.is_directory(full_directory_path):
|
||||
current_app.logger.warning(
|
||||
f"Directory not found or is not a directory for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Directory not found or is not a directory",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
current_app.logger.error(
|
||||
f"Failed to remove directory from storage. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Successfully removed directory. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
error_context += f", directory_path={directory_path}"
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths", "")
|
||||
error_context += f", file_paths={file_paths_str}"
|
||||
elif operation == "add":
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
error_context += f", parent_dir={parent_dir}"
|
||||
current_app.logger.error(
|
||||
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Operation failed"}), 500
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/task_status")
|
||||
class TaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"TaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get celery job status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
|
||||
if task.status == "PENDING":
|
||||
inspect = celery.control.inspect()
|
||||
active_workers = inspect.ping()
|
||||
if not active_workers:
|
||||
raise ConnectionError("Service unavailable")
|
||||
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta) # Convert to a string representation
|
||||
except ConnectionError as err:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": str(err)}), 503
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
@@ -5,8 +5,6 @@ from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync_worker,
|
||||
)
|
||||
@@ -27,7 +25,6 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
@@ -63,10 +60,9 @@ def ingest_connector_task(
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
sync_frequency="never"
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
@@ -79,7 +75,7 @@ def ingest_connector_task(
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -98,15 +94,3 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_task(self, config, user):
|
||||
resp = mcp_oauth(self, config, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Tools module."""
|
||||
|
||||
from .mcp import tools_mcp_ns
|
||||
from .routes import tools_ns
|
||||
|
||||
__all__ = ["tools_ns", "tools_mcp_ns"]
|
||||
@@ -1,333 +0,0 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from email.quoprimime import unquote
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.cache import get_redis_instance
|
||||
from application.security.encryption import encrypt_credentials
|
||||
from application.utils import check_required_fields
|
||||
|
||||
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerTestModel",
|
||||
{
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration to test"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Test MCP server connection with provided configuration")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key" and "api_key" in config:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer" and "bearer_token" in config:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config:
|
||||
auth_credentials["password"] = config["password"]
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/save")
|
||||
class MCPServerSave(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerSaveModel",
|
||||
{
|
||||
"id": fields.String(
|
||||
required=False, description="Tool ID for updates (optional)"
|
||||
),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the MCP server"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=False, default=True, description="Tool status"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["displayName", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "oauth":
|
||||
if not config.get("oauth_task_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
actions_metadata = result.get("tools", [])
|
||||
elif auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid credentials provided for the selected authentication type"
|
||||
)
|
||||
storage_config = config.copy()
|
||||
if auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"customName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
tool_id = data.get("id")
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/callback")
|
||||
class MCPOAuthCallback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerCallbackModel",
|
||||
{
|
||||
"code": fields.String(required=True, description="Authorization code"),
|
||||
"state": fields.String(required=True, description="State parameter"),
|
||||
"error": fields.String(
|
||||
required=False, description="Error message (if any)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Handle OAuth callback by providing the authorization code and state"
|
||||
)
|
||||
def get(self):
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
|
||||
if error:
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
if not code or not state:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
if not redis_client:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
|
||||
)
|
||||
code = unquote(code)
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
success = manager.handle_oauth_callback(state, code, error)
|
||||
if success:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
|
||||
)
|
||||
else:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
|
||||
)
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
"""
|
||||
Get current status of OAuth flow.
|
||||
Frontend should poll this endpoint periodically.
|
||||
"""
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Task not found or expired",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
|
||||
)
|
||||
@@ -1,416 +0,0 @@
|
||||
"""Tool management routes."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_ns.route("/available_tools")
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
doc = tool_instance.__doc__.strip()
|
||||
lines = doc.split("\n", 1)
|
||||
name = lines[0].strip()
|
||||
description = lines[1].strip() if len(lines) > 1 else ""
|
||||
tools_metadata.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"displayName": name,
|
||||
"description": description,
|
||||
"configRequirements": tool_instance.get_config_requirements(),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting available tools: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/get_tools")
|
||||
class GetTools(Resource):
|
||||
@api.doc(description="Get tools created by a user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
tool_copy = {**tool}
|
||||
tool_copy["id"] = str(tool["_id"])
|
||||
tool_copy.pop("_id", None)
|
||||
user_tools.append(tool_copy)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/create_tool")
|
||||
class CreateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateToolModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the tool"),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the tool"
|
||||
),
|
||||
"description": fields.String(
|
||||
required=True, description="Tool description"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
"customName": fields.String(
|
||||
required=False, description="Custom name for the tool"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create a new tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = [
|
||||
"name",
|
||||
"displayName",
|
||||
"description",
|
||||
"config",
|
||||
"status",
|
||||
]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
actions_metadata = tool_instance.get_actions_metadata()
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
try:
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": data["config"],
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool")
|
||||
class UpdateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"name": fields.String(description="Name of the tool"),
|
||||
"displayName": fields.String(description="Display name for the tool"),
|
||||
"customName": fields.String(description="Custom name for the tool"),
|
||||
"description": fields.String(description="Tool description"),
|
||||
"config": fields.Raw(description="Configuration of the tool"),
|
||||
"actions": fields.List(
|
||||
fields.Raw, description="Actions the tool can perform"
|
||||
),
|
||||
"status": fields.Boolean(description="Status of the tool"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "customName" in data:
|
||||
update_data["customName"] = data["customName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
if "config" in data:
|
||||
if "actions" in data["config"]:
|
||||
for action_name in list(data["config"]["actions"].keys()):
|
||||
if not validate_function_name(action_name):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
||||
"param": "tools[].function.name",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
||||
config = data["config"]
|
||||
existing_config = tool_doc.get("config", {})
|
||||
storage_config = existing_config.copy()
|
||||
|
||||
storage_config.update(config)
|
||||
existing_credentials = {}
|
||||
if "encrypted_credentials" in existing_config:
|
||||
existing_credentials = decrypt_credentials(
|
||||
existing_config["encrypted_credentials"], user
|
||||
)
|
||||
auth_credentials = existing_credentials.copy()
|
||||
auth_type = storage_config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config[
|
||||
"api_key_header"
|
||||
]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif "encrypted_token" in config and config["encrypted_token"]:
|
||||
auth_credentials["bearer_token"] = config["encrypted_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
if auth_type != "none" and auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = (
|
||||
encrypted_credentials_string
|
||||
)
|
||||
elif auth_type == "none":
|
||||
storage_config.pop("encrypted_credentials", None)
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"encrypted_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
update_data["config"] = storage_config
|
||||
else:
|
||||
update_data["config"] = data["config"]
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_config")
|
||||
class UpdateToolConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolConfigModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the configuration of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": data["config"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool config: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_actions")
|
||||
class UpdateToolActions(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolActionsModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the actions of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "actions"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_status")
|
||||
class UpdateToolStatus(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolStatusModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the status of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "status"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool status: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/delete_tool")
|
||||
class DeleteTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DeleteToolModel",
|
||||
{"id": fields.String(required=True, description="Tool ID")},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Delete a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "Tool not found"}, 404
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
return {"success": True}, 200
|
||||
@@ -1,189 +0,0 @@
|
||||
"""
|
||||
Model configurations for all supported LLM providers.
|
||||
"""
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
OPENAI_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
GOOGLE_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-5.1",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5.1",
|
||||
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=400000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-5-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5 Mini",
|
||||
description="Faster, cost-effective variant of GPT-5.1",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=400000,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet-20241022",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet (Latest)",
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet",
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-opus",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Opus",
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-haiku",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Haiku",
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GOOGLE_MODELS = [
|
||||
AvailableModel(
|
||||
id="gemini-flash-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash (Latest)",
|
||||
description="Latest experimental Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-flash-lite-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash Lite (Latest)",
|
||||
description="Fast with huge context window",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-3-pro-preview",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 3 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=20000, # Set low for testing compression
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROQ_MODELS = [
|
||||
AvailableModel(
|
||||
id="llama-3.3-70b-versatile",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.3 70B",
|
||||
description="Latest Llama model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="llama-3.1-8b-instant",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.1 8B",
|
||||
description="Ultra-fast inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="mixtral-8x7b-32768",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Mixtral 8x7B",
|
||||
description="High-speed inference with tools",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=32768,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
provider=ModelProvider.AZURE_OPENAI,
|
||||
display_name="Azure OpenAI GPT-4",
|
||||
description="Azure-hosted GPT model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -1,236 +0,0 @@
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
GROQ = "groq"
|
||||
GOOGLE = "google"
|
||||
HUGGINGFACE = "huggingface"
|
||||
LLAMA_CPP = "llama.cpp"
|
||||
DOCSGPT = "docsgpt"
|
||||
PREMAI = "premai"
|
||||
SAGEMAKER = "sagemaker"
|
||||
NOVITA = "novita"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
supports_tools: bool = False
|
||||
supports_structured_output: bool = False
|
||||
supports_streaming: bool = True
|
||||
supported_attachment_types: List[str] = field(default_factory=list)
|
||||
context_window: int = 128000
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AvailableModel:
|
||||
id: str
|
||||
provider: ModelProvider
|
||||
display_name: str
|
||||
description: str = ""
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
"supports_tools": self.capabilities.supports_tools,
|
||||
"supports_structured_output": self.capabilities.supports_structured_output,
|
||||
"supports_streaming": self.capabilities.supports_streaming,
|
||||
"context_window": self.capabilities.context_window,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
if self.base_url:
|
||||
result["base_url"] = self.base_url
|
||||
return result
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
self._add_docsgpt_models(settings)
|
||||
if settings.OPENAI_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openai" and settings.API_KEY
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
|
||||
if settings.LLM_NAME and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
elif settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
else:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
|
||||
for model in OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
@@ -1,91 +0,0 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider"""
|
||||
from application.core.settings import settings
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all available models with metadata for API response"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
|
||||
|
||||
|
||||
def validate_model_id(model_id: str) -> bool:
|
||||
"""Check if a model ID exists in registry"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.model_exists(model_id)
|
||||
|
||||
|
||||
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get capabilities for a specific model"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return {
|
||||
"supported_attachment_types": model.capabilities.supported_attachment_types,
|
||||
"supports_tools": model.capabilities.supports_tools,
|
||||
"supports_structured_output": model.capabilities.supports_structured_output,
|
||||
"context_window": model.capabilities.context_window,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def get_default_model_id() -> str:
|
||||
"""Get the system default model ID"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.default_model_id
|
||||
|
||||
|
||||
def get_provider_from_model_id(model_id: str) -> Optional[str]:
|
||||
"""Get the provider name for a given model_id"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.provider.value
|
||||
return None
|
||||
|
||||
|
||||
def get_token_limit(model_id: str) -> int:
|
||||
"""
|
||||
Get context window (token limit) for a model.
|
||||
Returns model's context_window or default 128000 if model not found.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.capabilities.context_window
|
||||
return settings.DEFAULT_LLM_TOKEN_LIMIT
|
||||
|
||||
|
||||
def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the custom base_url for a specific model if configured.
|
||||
Returns None if no custom base_url is set.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
@@ -22,15 +22,11 @@ class Settings(BaseSettings):
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||
RESERVED_TOKENS: dict = {
|
||||
"system_prompt": 500,
|
||||
"current_query": 500,
|
||||
"safety_buffer": 1000,
|
||||
}
|
||||
DEFAULT_AGENT_LIMITS: dict = {
|
||||
"token_limit": 50000,
|
||||
"request_limit": 500,
|
||||
LLM_TOKEN_LIMITS: dict = {
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"claude-2": 1e5,
|
||||
"gemini-2.0-flash-exp": 1e6,
|
||||
}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
@@ -45,33 +41,18 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client ID
|
||||
)
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client secret
|
||||
)
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
|
||||
##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||
|
||||
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
|
||||
|
||||
# Provider-specific API keys (for multi-model support)
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
|
||||
API_KEY: Optional[str] = None # LLM api key
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
@@ -115,7 +96,7 @@ class Settings(BaseSettings):
|
||||
QDRANT_HOST: Optional[str] = None
|
||||
QDRANT_PATH: Optional[str] = None
|
||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||
|
||||
|
||||
# PGVector vectorstore config
|
||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||
# Milvus vectorstore config
|
||||
@@ -135,22 +116,6 @@ class Settings(BaseSettings):
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
# Tool pre-fetch settings
|
||||
ENABLE_TOOL_PREFETCH: bool = True
|
||||
|
||||
# Conversation Compression Settings
|
||||
ENABLE_CONVERSATION_COMPRESSION: bool = True
|
||||
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
|
||||
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -1,41 +1,30 @@
|
||||
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
|
||||
self.api_key = (
|
||||
api_key or settings.ANTHROPIC_API_KEY
|
||||
) # If not provided, use a default from settings
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
# Use custom base_url if provided
|
||||
if base_url:
|
||||
self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
|
||||
else:
|
||||
self.anthropic = Anthropic(api_key=self.api_key)
|
||||
|
||||
self.anthropic = Anthropic(api_key=self.api_key)
|
||||
self.HUMAN_PROMPT = HUMAN_PROMPT
|
||||
self.AI_PROMPT = AI_PROMPT
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
tools=None,
|
||||
max_tokens=300,
|
||||
**kwargs,
|
||||
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||
if stream:
|
||||
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
|
||||
|
||||
completion = self.anthropic.completions.create(
|
||||
model=model,
|
||||
max_tokens_to_sample=max_tokens,
|
||||
@@ -45,14 +34,7 @@ class AnthropicLLM(BaseLLM):
|
||||
return completion.completion
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
tools=None,
|
||||
max_tokens=300,
|
||||
**kwargs,
|
||||
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
@@ -64,9 +46,5 @@ class AnthropicLLM(BaseLLM):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
try:
|
||||
for completion in stream_response:
|
||||
yield completion.completion
|
||||
finally:
|
||||
if hasattr(stream_response, "close"):
|
||||
stream_response.close()
|
||||
for completion in stream_response:
|
||||
yield completion.completion
|
||||
|
||||
@@ -13,32 +13,30 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
|
||||
self.fallback_model_name = settings.FALLBACK_LLM_NAME
|
||||
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
|
||||
self._fallback_llm = None
|
||||
self._fallback_sequence_index = 0
|
||||
|
||||
@property
|
||||
def fallback_llm(self):
|
||||
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
|
||||
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
|
||||
"""Lazy-loaded fallback LLM instance."""
|
||||
if (
|
||||
self._fallback_llm is None
|
||||
and self.fallback_provider
|
||||
and self.fallback_model_name
|
||||
):
|
||||
try:
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
self.fallback_provider,
|
||||
self.fallback_llm_api_key,
|
||||
None,
|
||||
self.decoded_token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -46,17 +44,11 @@ class BaseLLM(ABC):
|
||||
)
|
||||
return self._fallback_llm
|
||||
|
||||
@staticmethod
|
||||
def _remove_null_values(args_dict):
|
||||
if not isinstance(args_dict, dict):
|
||||
return args_dict
|
||||
return {k: v for k, v in args_dict.items() if v is not None}
|
||||
|
||||
def _execute_with_fallback(
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Execute method with fallback support.
|
||||
Unified method execution with fallback support.
|
||||
|
||||
Args:
|
||||
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
|
||||
@@ -75,10 +67,10 @@ class BaseLLM(ABC):
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
|
||||
raise
|
||||
logger.warning(
|
||||
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
|
||||
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
@@ -9,11 +7,12 @@ from application.llm.base import BaseLLM
|
||||
class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = "sk-docsgpt-public"
|
||||
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
|
||||
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
|
||||
self.user_api_key = user_api_key
|
||||
self.api_key = api_key
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
cleaned_messages = []
|
||||
@@ -23,6 +22,7 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
@@ -33,15 +33,14 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
tool_call = {
|
||||
"id": item["function_call"]["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item["function_call"]["name"],
|
||||
"arguments": json.dumps(cleaned_args),
|
||||
"arguments": json.dumps(
|
||||
item["function_call"]["args"]
|
||||
),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
@@ -69,6 +68,7 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
@@ -120,19 +120,12 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
try:
|
||||
for line in response:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
for line in response:
|
||||
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
return True
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
@@ -10,13 +11,10 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
@@ -34,12 +32,6 @@ class GoogleLLM(BaseLLM):
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -55,19 +47,21 @@ class GoogleLLM(BaseLLM):
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
# Find the user message to attach files to the last one
|
||||
|
||||
user_message_index = None
|
||||
for i in range(len(prepared_messages) - 1, -1, -1):
|
||||
if prepared_messages[i].get("role") == "user":
|
||||
user_message_index = i
|
||||
break
|
||||
|
||||
if user_message_index is None:
|
||||
user_message = {"role": "user", "content": []}
|
||||
prepared_messages.append(user_message)
|
||||
user_message_index = len(prepared_messages) - 1
|
||||
|
||||
if isinstance(prepared_messages[user_message_index].get("content"), str):
|
||||
text_content = prepared_messages[user_message_index]["content"]
|
||||
prepared_messages[user_message_index]["content"] = [
|
||||
@@ -75,6 +69,7 @@ class GoogleLLM(BaseLLM):
|
||||
]
|
||||
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
@@ -97,9 +92,11 @@ class GoogleLLM(BaseLLM):
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _upload_file_to_google(self, attachment):
|
||||
@@ -114,11 +111,14 @@ class GoogleLLM(BaseLLM):
|
||||
"""
|
||||
if "google_file_uri" in attachment:
|
||||
return attachment["google_file_uri"]
|
||||
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
file_uri = self.storage.process_file(
|
||||
file_path,
|
||||
@@ -136,48 +136,21 @@ class GoogleLLM(BaseLLM):
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||
)
|
||||
|
||||
return file_uri
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
|
||||
Returns:
|
||||
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
|
||||
combined system instruction.
|
||||
"""
|
||||
cleaned_messages = []
|
||||
system_instructions = []
|
||||
|
||||
def _extract_system_text(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item and item["text"] is not None:
|
||||
parts.append(item["text"])
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
# Gemini only accepts user/model in the contents list.
|
||||
if role == "system":
|
||||
sys_text = _extract_system_text(content)
|
||||
if sys_text:
|
||||
system_instructions.append(sys_text)
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
@@ -187,32 +160,12 @@ class GoogleLLM(BaseLLM):
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(text=item["text"]))
|
||||
elif "function_call" in item:
|
||||
# Remove null values from args to avoid API errors
|
||||
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=item["function_call"]["args"],
|
||||
)
|
||||
)
|
||||
# Create function call part with thought_signature if present
|
||||
# For Gemini 3 models, we need to include thought_signature
|
||||
if "thought_signature" in item:
|
||||
# Use Part constructor with functionCall and thoughtSignature
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use helper method when no thought_signature
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
@@ -234,62 +187,12 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
|
||||
return cleaned_messages, system_instruction
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
"""
|
||||
Recursively remove unsupported fields from schema objects
|
||||
and validate required properties.
|
||||
"""
|
||||
if not isinstance(schema_obj, dict):
|
||||
return schema_obj
|
||||
allowed_fields = {
|
||||
"type",
|
||||
"description",
|
||||
"items",
|
||||
"properties",
|
||||
"required",
|
||||
"enum",
|
||||
"pattern",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"nullable",
|
||||
"default",
|
||||
}
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
|
||||
cleaned = {}
|
||||
for key, value in schema_obj.items():
|
||||
if key not in allowed_fields:
|
||||
continue
|
||||
elif key == "type" and isinstance(value, str):
|
||||
cleaned[key] = value.upper()
|
||||
elif isinstance(value, dict):
|
||||
cleaned[key] = self._clean_schema(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
# Validate that required properties actually exist in properties
|
||||
|
||||
if "required" in cleaned and "properties" in cleaned:
|
||||
valid_required = []
|
||||
properties_keys = set(cleaned["properties"].keys())
|
||||
for required_prop in cleaned["required"]:
|
||||
if required_prop in properties_keys:
|
||||
valid_required.append(required_prop)
|
||||
if valid_required:
|
||||
cleaned["required"] = valid_required
|
||||
else:
|
||||
cleaned.pop("required", None)
|
||||
elif "required" in cleaned and "properties" not in cleaned:
|
||||
cleaned.pop("required", None)
|
||||
return cleaned
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_tools_format(self, tools_list):
|
||||
"""Convert OpenAI format tools to Google AI format."""
|
||||
genai_tools = []
|
||||
for tool_data in tools_list:
|
||||
if tool_data["type"] == "function":
|
||||
@@ -298,15 +201,18 @@ class GoogleLLM(BaseLLM):
|
||||
properties = parameters.get("properties", {})
|
||||
|
||||
if properties:
|
||||
cleaned_properties = {}
|
||||
for k, v in properties.items():
|
||||
cleaned_properties[k] = self._clean_schema(v)
|
||||
genai_function = dict(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
parameters={
|
||||
"type": "OBJECT",
|
||||
"properties": cleaned_properties,
|
||||
"properties": {
|
||||
k: {
|
||||
**v,
|
||||
"type": v["type"].upper() if v["type"] else None,
|
||||
}
|
||||
for k, v in properties.items()
|
||||
},
|
||||
"required": (
|
||||
parameters["required"]
|
||||
if "required" in parameters
|
||||
@@ -319,65 +225,12 @@ class GoogleLLM(BaseLLM):
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
)
|
||||
|
||||
genai_tool = types.Tool(function_declarations=[genai_function])
|
||||
genai_tools.append(genai_tool)
|
||||
|
||||
return genai_tools
|
||||
|
||||
def _extract_preview_from_message(self, message):
|
||||
"""Get a short, human-readable preview from the last message."""
|
||||
try:
|
||||
if hasattr(message, "parts"):
|
||||
for part in reversed(message.parts):
|
||||
if getattr(part, "text", None):
|
||||
return part.text
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call:
|
||||
name = getattr(function_call, "name", "") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
function_response = getattr(part, "function_response", None)
|
||||
if function_response:
|
||||
name = getattr(function_response, "name", "") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for item in reversed(content):
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, dict):
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
if item.get("function_call"):
|
||||
fn = item["function_call"]
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
return "function_call"
|
||||
if item.get("function_response"):
|
||||
resp = item["function_response"]
|
||||
if isinstance(resp, dict):
|
||||
name = resp.get("name") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
return "function_response"
|
||||
if "text" in message and isinstance(message["text"], str):
|
||||
return message["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return str(message)
|
||||
|
||||
def _summarize_messages_for_log(self, messages, preview_chars=20):
|
||||
"""Return a compact summary for logging to avoid huge payloads."""
|
||||
message_count = len(messages) if messages else 0
|
||||
last_preview = ""
|
||||
if messages:
|
||||
last_preview = self._extract_preview_from_message(messages[-1]) or ""
|
||||
last_preview = str(last_preview).replace("\n", " ")
|
||||
if len(last_preview) > preview_chars:
|
||||
last_preview = f"{last_preview[:preview_chars]}..."
|
||||
return f"count={message_count}, last='{last_preview}'"
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -389,22 +242,23 @@ class GoogleLLM(BaseLLM):
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API without streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
@@ -427,24 +281,24 @@ class GoogleLLM(BaseLLM):
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API with streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
# Check if we have both tools and file attachments
|
||||
|
||||
# Check if we have both tools and file attachments
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
@@ -453,12 +307,9 @@ class GoogleLLM(BaseLLM):
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
|
||||
logging.info(
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
@@ -467,34 +318,28 @@ class GoogleLLM(BaseLLM):
|
||||
config=config,
|
||||
)
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
yield part
|
||||
elif part.text:
|
||||
yield part.text
|
||||
elif hasattr(chunk, "text"):
|
||||
yield chunk.text
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
yield part
|
||||
elif part.text:
|
||||
yield part.text
|
||||
elif hasattr(chunk, "text"):
|
||||
yield chunk.text
|
||||
|
||||
def _supports_tools(self):
|
||||
"""Return whether this LLM supports function calling."""
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
"""Return whether this LLM supports structured JSON output."""
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
"""Convert JSON schema to Google AI structured output format."""
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
type_map = {
|
||||
"object": "OBJECT",
|
||||
"array": "ARRAY",
|
||||
@@ -507,10 +352,12 @@ class GoogleLLM(BaseLLM):
|
||||
def convert(schema):
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
result = {}
|
||||
schema_type = schema.get("type")
|
||||
if schema_type:
|
||||
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
|
||||
|
||||
for key in [
|
||||
"description",
|
||||
"nullable",
|
||||
@@ -522,6 +369,7 @@ class GoogleLLM(BaseLLM):
|
||||
]:
|
||||
if key in schema:
|
||||
result[key] = schema[key]
|
||||
|
||||
if "format" in schema:
|
||||
format_value = schema["format"]
|
||||
if schema_type == "string":
|
||||
@@ -531,17 +379,21 @@ class GoogleLLM(BaseLLM):
|
||||
result["format"] = format_value
|
||||
else:
|
||||
result["format"] = format_value
|
||||
|
||||
if "properties" in schema:
|
||||
result["properties"] = {
|
||||
k: convert(v) for k, v in schema["properties"].items()
|
||||
}
|
||||
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
|
||||
result["propertyOrdering"] = list(result["properties"].keys())
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert(schema["items"])
|
||||
|
||||
for field in ["anyOf", "oneOf", "allOf"]:
|
||||
if field in schema:
|
||||
result[field] = [convert(s) for s in schema[field]]
|
||||
|
||||
return result
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class GroqLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
|
||||
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
|
||||
)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
@@ -17,7 +16,6 @@ class ToolCall:
|
||||
name: str
|
||||
arguments: Union[str, Dict]
|
||||
index: Optional[int] = None
|
||||
thought_signature: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "ToolCall":
|
||||
@@ -180,406 +178,6 @@ class LLMHandler(ABC):
|
||||
system_msg["content"] += f"\n\n{combined_text}"
|
||||
return prepared_messages
|
||||
|
||||
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Build a minimal context: system prompt + latest user message only.
|
||||
Drops all tool/function messages to shrink context aggressively.
|
||||
"""
|
||||
system_message = next((m for m in messages if m.get("role") == "system"), None)
|
||||
if not system_message:
|
||||
logger.warning("Cannot prune messages minimally: missing system message.")
|
||||
return None
|
||||
last_non_system = None
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
last_non_system = m
|
||||
break
|
||||
if not last_non_system and m.get("role") not in ("system", None):
|
||||
last_non_system = m
|
||||
if not last_non_system:
|
||||
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
|
||||
return None
|
||||
logger.info("Pruning context to system + latest user/assistant message to proceed.")
|
||||
return [system_message, last_non_system]
|
||||
|
||||
def _extract_text_from_content(self, content: Any) -> str:
|
||||
"""
|
||||
Convert message content (str or list of parts) to plain text for compression.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts_text = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if "text" in item and item["text"] is not None:
|
||||
parts_text.append(str(item["text"]))
|
||||
elif "function_call" in item or "function_response" in item:
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
parts_text.append(str(item))
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
|
||||
"""
|
||||
Build a conversation-like dict from current messages so we can compress
|
||||
even when the conversation isn't persisted yet. Includes tool calls/results.
|
||||
"""
|
||||
queries = []
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
def _commit_query(response_text: str):
|
||||
nonlocal current_prompt, current_tool_calls
|
||||
if current_prompt is None and not response_text:
|
||||
return
|
||||
tool_calls_list = list(current_tool_calls.values())
|
||||
queries.append(
|
||||
{
|
||||
"prompt": current_prompt or "",
|
||||
"response": response_text,
|
||||
"tool_calls": tool_calls_list,
|
||||
}
|
||||
)
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "user":
|
||||
current_prompt = self._extract_text_from_content(content)
|
||||
|
||||
elif role in {"assistant", "model"}:
|
||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
fc = item["function_call"]
|
||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fc.get("name"),
|
||||
"arguments": fc.get("args"),
|
||||
"result": None,
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
elif "function_response" in item:
|
||||
fr = item["function_response"]
|
||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fr.get("name"),
|
||||
"arguments": None,
|
||||
"result": fr.get("response", {}).get("result"),
|
||||
"status": "completed",
|
||||
"call_id": call_id,
|
||||
}
|
||||
# No direct assistant text here; continue to next message
|
||||
continue
|
||||
|
||||
response_text = self._extract_text_from_content(content)
|
||||
_commit_query(response_text)
|
||||
|
||||
elif role == "tool":
|
||||
# Attach tool outputs to the latest pending tool call if possible
|
||||
tool_text = self._extract_text_from_content(content)
|
||||
# Attempt to parse function_response style
|
||||
call_id = None
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item and item["function_response"].get("call_id"):
|
||||
call_id = item["function_response"]["call_id"]
|
||||
break
|
||||
if call_id and call_id in current_tool_calls:
|
||||
current_tool_calls[call_id]["result"] = tool_text
|
||||
current_tool_calls[call_id]["status"] = "completed"
|
||||
elif queries:
|
||||
queries[-1].setdefault("tool_calls", []).append(
|
||||
{
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": "unknown_action",
|
||||
"arguments": {},
|
||||
"result": tool_text,
|
||||
"status": "completed",
|
||||
}
|
||||
)
|
||||
|
||||
# If there's an unfinished prompt with tool_calls but no response yet, commit it
|
||||
if current_prompt is not None or current_tool_calls:
|
||||
_commit_query(response_text="")
|
||||
|
||||
if not queries:
|
||||
return None
|
||||
|
||||
return {
|
||||
"queries": queries,
|
||||
"compression_metadata": {
|
||||
"is_compressed": False,
|
||||
"compression_points": [],
|
||||
},
|
||||
}
|
||||
|
||||
def _rebuild_messages_after_compression(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Delegates to MessageBuilder for the actual reconstruction.
|
||||
"""
|
||||
from application.api.answer.services.compression.message_builder import (
|
||||
MessageBuilder,
|
||||
)
|
||||
|
||||
return MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary=compressed_summary,
|
||||
recent_queries=recent_queries,
|
||||
include_current_execution=include_current_execution,
|
||||
include_tool_calls=include_tool_calls,
|
||||
)
|
||||
|
||||
def _perform_mid_execution_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Perform compression during tool execution and rebuild messages.
|
||||
|
||||
Uses the new orchestrator for simplified compression.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current conversation messages
|
||||
|
||||
Returns:
|
||||
(success: bool, rebuilt_messages: Optional[List[Dict]])
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
conversation_service = ConversationService()
|
||||
orchestrator = CompressionOrchestrator(conversation_service)
|
||||
|
||||
# Get conversation from database (may be None for new sessions)
|
||||
conversation = conversation_service.get_conversation(
|
||||
agent.conversation_id, agent.initial_user_id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
# Merge current in-flight messages (including tool calls)
|
||||
conversation_from_msgs = self._build_conversation_from_messages(messages)
|
||||
if conversation_from_msgs:
|
||||
conversation = conversation_from_msgs
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not load conversation for compression; attempting in-memory compression"
|
||||
)
|
||||
return self._perform_in_memory_compression(agent, messages)
|
||||
|
||||
# Use orchestrator to perform compression
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id=agent.conversation_id,
|
||||
user_id=agent.initial_user_id,
|
||||
model_id=agent.model_id,
|
||||
decoded_token=getattr(agent, "decoded_token", {}),
|
||||
current_conversation=conversation,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
logger.warning(f"Mid-execution compression failed: {result.error}")
|
||||
# Try minimal pruning as fallback
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
if not result.compression_performed:
|
||||
logger.warning("Compression not performed")
|
||||
return False, None
|
||||
|
||||
# Check if compression actually reduced tokens
|
||||
if result.metadata:
|
||||
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
|
||||
logger.warning(
|
||||
"Compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
|
||||
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Also store the compression summary as a visible message
|
||||
if result.metadata:
|
||||
conversation_service.append_compression_message(
|
||||
agent.conversation_id, result.metadata.to_dict()
|
||||
)
|
||||
|
||||
# Update agent's compressed summary for downstream persistence
|
||||
agent.compressed_summary = result.compressed_summary
|
||||
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
|
||||
agent.compression_saved = False
|
||||
|
||||
# Reset the context limit flag so tools can continue
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
# Rebuild messages
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
result.compressed_summary,
|
||||
result.recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _perform_in_memory_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Fallback compression path when the conversation is not yet persisted.
|
||||
|
||||
Uses CompressionService directly without DB persistence.
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression.service import (
|
||||
CompressionService,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
conversation = self._build_conversation_from_messages(messages)
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
"Cannot perform in-memory compression: no user/assistant turns found"
|
||||
)
|
||||
return False, None
|
||||
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else agent.model_id
|
||||
)
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key,
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=None, # No DB updates for in-memory
|
||||
)
|
||||
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0 or queries_count == 0:
|
||||
logger.warning("Not enough queries to compress in-memory context")
|
||||
return False, None
|
||||
|
||||
metadata = compression_service.compress_conversation(
|
||||
conversation,
|
||||
compress_up_to_index=compress_up_to,
|
||||
)
|
||||
|
||||
# If compression doesn't reduce tokens, fall back to minimal pruning
|
||||
if (
|
||||
metadata.compressed_token_count
|
||||
>= metadata.original_token_count
|
||||
):
|
||||
logger.warning(
|
||||
"In-memory compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
# Attach metadata to synthetic conversation
|
||||
conversation["compression_metadata"] = {
|
||||
"is_compressed": True,
|
||||
"compression_points": [metadata.to_dict()],
|
||||
}
|
||||
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
agent.compressed_summary = compressed_summary
|
||||
agent.compression_metadata = metadata.to_dict()
|
||||
agent.compression_saved = False
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
compressed_summary,
|
||||
recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing in-memory compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
@@ -597,110 +195,7 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
|
||||
# Context limit reached - attempt mid-execution compression
|
||||
compression_attempted = False
|
||||
compression_successful = False
|
||||
|
||||
try:
|
||||
from application.core.settings import settings
|
||||
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
|
||||
except Exception:
|
||||
compression_enabled = False
|
||||
|
||||
if compression_enabled:
|
||||
compression_attempted = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
|
||||
f"Attempting mid-execution compression..."
|
||||
)
|
||||
|
||||
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
|
||||
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
|
||||
agent, updated_messages
|
||||
)
|
||||
|
||||
if compression_successful and rebuilt_messages is not None:
|
||||
# Update the messages list with rebuilt compressed version
|
||||
updated_messages = rebuilt_messages
|
||||
|
||||
# Yield compression success message
|
||||
yield {
|
||||
"type": "info",
|
||||
"data": {
|
||||
"message": "Context window limit reached. Compressed conversation history to continue processing."
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
|
||||
)
|
||||
# Proceed to execute the current tool call with the reduced context
|
||||
else:
|
||||
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
|
||||
compression_attempted = True
|
||||
compression_successful = False
|
||||
|
||||
# If compression wasn't attempted or failed, skip remaining tools
|
||||
if not compression_successful:
|
||||
if i == 0:
|
||||
# Special case: limit reached before executing any tools
|
||||
# This can happen when previous tool responses pushed context over limit
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
else:
|
||||
# Normal case: executed some tools, now stopping
|
||||
tool_word = "tool call" if i == 1 else "tool calls"
|
||||
remaining = len(tool_calls) - i
|
||||
remaining_word = "tool call" if remaining == 1 else "tool calls"
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping remaining {remaining} {remaining_word}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Skipping remaining {remaining} {remaining_word}. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
|
||||
# Mark remaining tools as skipped
|
||||
for remaining_call in tool_calls[i:]:
|
||||
skip_message = {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": "system",
|
||||
"call_id": remaining_call.id,
|
||||
"action_name": remaining_call.name,
|
||||
"arguments": {},
|
||||
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
|
||||
"status": "skipped"
|
||||
}
|
||||
}
|
||||
yield skip_message
|
||||
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -710,57 +205,34 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
# Include thought_signature for Google Gemini 3 models
|
||||
# It should be at the same level as function_call, not inside it
|
||||
if call.thought_signature:
|
||||
function_call_content["thought_signature"] = call.thought_signature
|
||||
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_content],
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
error_call = ToolCall(
|
||||
id=call.id, name=call.name, arguments=call.arguments
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
)
|
||||
error_response = f"Error executing tool: {str(e)}"
|
||||
error_message = self.create_tool_message(error_call, error_response)
|
||||
updated_messages.append(error_message)
|
||||
|
||||
call_parts = call.name.split("_")
|
||||
if len(call_parts) >= 2:
|
||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||
action_name = "_".join(call_parts[:-1])
|
||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||
full_action_name = f"{action_name}_{tool_id}"
|
||||
else:
|
||||
tool_name = "unknown_tool"
|
||||
action_name = call.name
|
||||
full_action_name = call.name
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": tool_name,
|
||||
"call_id": call.id,
|
||||
"action_name": full_action_name,
|
||||
"arguments": call.arguments,
|
||||
"error": error_response,
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages
|
||||
|
||||
def handle_non_streaming(
|
||||
@@ -791,11 +263,13 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
@@ -834,9 +308,6 @@ class LLMHandler(ABC):
|
||||
existing.name = call.name
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
if call.thought_signature:
|
||||
existing.thought_signature = call.thought_signature
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
@@ -849,21 +320,8 @@ class LLMHandler(ABC):
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": (
|
||||
"WARNING: Context window limit has been reached. "
|
||||
"Please provide a final response to the user without making additional tool calls. "
|
||||
"Summarize the work completed so far."
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
@@ -17,22 +17,18 @@ class GoogleLLMHandler(LLMHandler):
|
||||
finish_reason="stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
if hasattr(response, "candidates"):
|
||||
parts = response.candidates[0].content.parts if response.candidates else []
|
||||
tool_calls = []
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
index=idx,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
)
|
||||
for part in parts
|
||||
if hasattr(part, "function_call") and part.function_call is not None
|
||||
]
|
||||
|
||||
content = " ".join(
|
||||
part.text
|
||||
@@ -45,18 +41,15 @@ class GoogleLLMHandler(LLMHandler):
|
||||
finish_reason="tool_calls" if tool_calls else "stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
else:
|
||||
# This branch handles individual Part objects from streaming responses
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
if hasattr(response, "function_call"):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=response.function_call.name,
|
||||
arguments=response.function_call.args,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
return LLMResponse(
|
||||
@@ -68,16 +61,14 @@ class GoogleLLMHandler(LLMHandler):
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create Google-style tool message."""
|
||||
from google.genai import types
|
||||
|
||||
return {
|
||||
"role": "model",
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
types.Part.from_function_response(
|
||||
name=tool_call.name, response={"result": result}
|
||||
).to_json_dict()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import logging
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.openai import OpenAILLM, AzureOpenAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.novita import NovitaLLM
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
@@ -30,26 +26,10 @@ class LLMCreator:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
|
||||
# Extract base_url from model configuration if model_id is provided
|
||||
base_url = None
|
||||
if model_id:
|
||||
base_url = get_base_url_for_model(model_id)
|
||||
|
||||
return llm_class(
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
base_url=base_url,
|
||||
*args,
|
||||
**kwargs,
|
||||
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -11,25 +9,20 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
# Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
|
||||
effective_base_url = None
|
||||
if base_url and isinstance(base_url, str) and base_url.strip():
|
||||
effective_base_url = base_url
|
||||
elif (
|
||||
if (
|
||||
isinstance(settings.OPENAI_BASE_URL, str)
|
||||
and settings.OPENAI_BASE_URL.strip()
|
||||
):
|
||||
effective_base_url = settings.OPENAI_BASE_URL
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
else:
|
||||
effective_base_url = "https://api.openai.com/v1"
|
||||
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
@@ -40,6 +33,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
@@ -50,15 +44,14 @@ class OpenAILLM(BaseLLM):
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
tool_call = {
|
||||
"id": item["function_call"]["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item["function_call"]["name"],
|
||||
"arguments": json.dumps(cleaned_args),
|
||||
"arguments": json.dumps(
|
||||
item["function_call"]["args"]
|
||||
),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
@@ -113,6 +106,7 @@ class OpenAILLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
@@ -128,10 +122,6 @@ class OpenAILLM(BaseLLM):
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -141,8 +131,10 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
if tools:
|
||||
@@ -163,10 +155,6 @@ class OpenAILLM(BaseLLM):
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -176,23 +164,21 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
try:
|
||||
for line in response:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
for line in response:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
@@ -203,6 +189,7 @@ class OpenAILLM(BaseLLM):
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
try:
|
||||
|
||||
def add_additional_properties_false(schema_obj):
|
||||
@@ -212,11 +199,11 @@ class OpenAILLM(BaseLLM):
|
||||
if schema_copy.get("type") == "object":
|
||||
schema_copy["additionalProperties"] = False
|
||||
# Ensure 'required' includes all properties for OpenAI strict mode
|
||||
|
||||
if "properties" in schema_copy:
|
||||
schema_copy["required"] = list(
|
||||
schema_copy["properties"].keys()
|
||||
)
|
||||
|
||||
for key, value in schema_copy.items():
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
schema_copy[key] = {
|
||||
@@ -232,6 +219,7 @@ class OpenAILLM(BaseLLM):
|
||||
add_additional_properties_false(sub_schema)
|
||||
for sub_schema in value
|
||||
]
|
||||
|
||||
return schema_copy
|
||||
return schema_obj
|
||||
|
||||
@@ -250,6 +238,7 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error preparing structured output format: {e}")
|
||||
return None
|
||||
@@ -283,19 +272,21 @@ class OpenAILLM(BaseLLM):
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
# Find the user message to attach file_id to the last one
|
||||
|
||||
user_message_index = None
|
||||
for i in range(len(prepared_messages) - 1, -1, -1):
|
||||
if prepared_messages[i].get("role") == "user":
|
||||
user_message_index = i
|
||||
break
|
||||
|
||||
if user_message_index is None:
|
||||
user_message = {"role": "user", "content": []}
|
||||
prepared_messages.append(user_message)
|
||||
user_message_index = len(prepared_messages) - 1
|
||||
|
||||
if isinstance(prepared_messages[user_message_index].get("content"), str):
|
||||
text_content = prepared_messages[user_message_index]["content"]
|
||||
prepared_messages[user_message_index]["content"] = [
|
||||
@@ -303,6 +294,7 @@ class OpenAILLM(BaseLLM):
|
||||
]
|
||||
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
@@ -329,7 +321,6 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
)
|
||||
# Handle PDFs using the file API
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
try:
|
||||
file_id = self._upload_file_to_openai(attachment)
|
||||
@@ -345,6 +336,7 @@ class OpenAILLM(BaseLLM):
|
||||
"text": f"File content:\n\n{attachment['content']}",
|
||||
}
|
||||
)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _get_base64_image(self, attachment):
|
||||
@@ -360,6 +352,7 @@ class OpenAILLM(BaseLLM):
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
try:
|
||||
with self.storage.get_file(file_path) as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -383,10 +376,12 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if "openai_file_id" in attachment:
|
||||
return attachment["openai_file_id"]
|
||||
|
||||
file_path = attachment.get("path")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
file_id = self.storage.process_file(
|
||||
file_path,
|
||||
@@ -404,6 +399,7 @@ class OpenAILLM(BaseLLM):
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||
)
|
||||
|
||||
return file_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
|
||||
|
||||
@@ -17,13 +17,14 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
||||
"""
|
||||
|
||||
SCOPES = [
|
||||
'https://www.googleapis.com/auth/drive.file'
|
||||
'https://www.googleapis.com/auth/drive.readonly',
|
||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||
@@ -49,7 +50,7 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type='offline',
|
||||
prompt='consent',
|
||||
include_granted_scopes='false',
|
||||
include_granted_scopes='true',
|
||||
state=state
|
||||
)
|
||||
|
||||
|
||||
@@ -32,10 +32,6 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
'text/plain': '.txt',
|
||||
'text/csv': '.csv',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
'text/x-rst': '.rst',
|
||||
'application/json': '.json',
|
||||
'application/epub+zip': '.epub',
|
||||
'application/rtf': '.rtf',
|
||||
'image/jpeg': '.jpg',
|
||||
'image/jpg': '.jpg',
|
||||
@@ -124,7 +120,6 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
list_only = inputs.get('list_only', False)
|
||||
load_content = not list_only
|
||||
page_token = inputs.get('page_token')
|
||||
search_query = inputs.get('search_query')
|
||||
self.next_page_token = None
|
||||
|
||||
if file_ids:
|
||||
@@ -133,18 +128,12 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
try:
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc:
|
||||
if not search_query or (
|
||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||
):
|
||||
documents.append(doc)
|
||||
documents.append(doc)
|
||||
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
||||
self._credential_refreshed = False
|
||||
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc and (
|
||||
not search_query or
|
||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||
):
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
@@ -152,13 +141,7 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
else:
|
||||
# Browsing mode: list immediate children of provided folder or root
|
||||
parent_id = folder_id if folder_id else 'root'
|
||||
documents = self._list_items_in_parent(
|
||||
parent_id,
|
||||
limit=limit,
|
||||
load_content=load_content,
|
||||
page_token=page_token,
|
||||
search_query=search_query
|
||||
)
|
||||
documents = self._list_items_in_parent(parent_id, limit=limit, load_content=load_content, page_token=page_token)
|
||||
|
||||
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
||||
return documents
|
||||
@@ -201,18 +184,13 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
return None
|
||||
|
||||
|
||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
|
||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None) -> List[Document]:
|
||||
self._ensure_service()
|
||||
|
||||
documents: List[Document] = []
|
||||
|
||||
try:
|
||||
query = f"'{parent_id}' in parents and trashed=false"
|
||||
|
||||
if search_query:
|
||||
safe_search = search_query.replace("'", "\\'")
|
||||
query += f" and name contains '{safe_search}'"
|
||||
|
||||
next_token_out: Optional[str] = None
|
||||
|
||||
while True:
|
||||
@@ -227,8 +205,7 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
||||
q=query,
|
||||
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
||||
pageToken=page_token,
|
||||
pageSize=page_size,
|
||||
orderBy='name'
|
||||
pageSize=page_size
|
||||
).execute()
|
||||
|
||||
items = results.get('files', [])
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Any
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
@@ -23,16 +22,13 @@ def sanitize_content(content: str) -> str:
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
"""Add a document's text and metadata to the vector store with retry logic.
|
||||
|
||||
def add_text_to_store_with_retry(store, doc, source_id):
|
||||
"""
|
||||
Add a document's text and metadata to the vector store with retry logic.
|
||||
Args:
|
||||
store: The vector store object.
|
||||
doc: The document to be added.
|
||||
source_id: Unique identifier for the source.
|
||||
|
||||
Raises:
|
||||
Exception: If document addition fails after all retry attempts.
|
||||
"""
|
||||
try:
|
||||
# Sanitize content to remove NUL characters that cause ingestion failures
|
||||
@@ -45,21 +41,18 @@ def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
||||
"""
|
||||
Embeds documents and stores them in a vector store.
|
||||
|
||||
Args:
|
||||
docs: List of documents to be embedded and stored.
|
||||
folder_name: Directory to save the vector store.
|
||||
source_id: Unique identifier for the source.
|
||||
docs (list): List of documents to be embedded and stored.
|
||||
folder_name (str): Directory to save the vector store.
|
||||
source_id (str): Unique identifier for the source.
|
||||
task_status: Task state manager for progress updates.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
OSError: If unable to create folder or save vector store.
|
||||
Exception: If vector store creation or document embedding fails.
|
||||
"""
|
||||
# Ensure the folder exists
|
||||
if not os.path.exists(folder_name):
|
||||
@@ -102,21 +95,10 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
except Exception as e:
|
||||
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
|
||||
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
||||
try:
|
||||
store.save_local(folder_name)
|
||||
logging.info("Progress saved successfully")
|
||||
except Exception as save_error:
|
||||
logging.error(f"CRITICAL: Failed to save progress: {save_error}", exc_info=True)
|
||||
# Continue without breaking to attempt final save
|
||||
store.save_local(folder_name)
|
||||
break
|
||||
|
||||
# Save the vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
try:
|
||||
store.save_local(folder_name)
|
||||
logging.info("Vector store saved successfully.")
|
||||
except Exception as e:
|
||||
logging.error(f"CRITICAL: Failed to save final vector store: {e}", exc_info=True)
|
||||
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
|
||||
else:
|
||||
logging.info("Vector store saved successfully.")
|
||||
store.save_local(folder_name)
|
||||
logging.info("Vector store saved successfully.")
|
||||
|
||||
@@ -1,135 +1,44 @@
|
||||
import base64
|
||||
import requests
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.schema.base import Document
|
||||
from langchain_core.documents import Document
|
||||
import mimetypes
|
||||
from application.core.settings import settings
|
||||
|
||||
class GitHubLoader(BaseRemote):
|
||||
def __init__(self):
|
||||
self.access_token = settings.GITHUB_ACCESS_TOKEN
|
||||
self.access_token = None
|
||||
self.headers = {
|
||||
"Authorization": f"token {self.access_token}",
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
} if self.access_token else {
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
}
|
||||
"Authorization": f"token {self.access_token}"
|
||||
} if self.access_token else {}
|
||||
return
|
||||
|
||||
def is_text_file(self, file_path: str) -> bool:
|
||||
"""Determine if a file is a text file based on extension."""
|
||||
# Common text file extensions
|
||||
text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.json', '.xml', '.yaml', '.yml',
|
||||
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
||||
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala',
|
||||
'.html', '.css', '.scss', '.sass', '.less',
|
||||
'.sh', '.bash', '.zsh', '.fish',
|
||||
'.sql', '.r', '.m', '.mat',
|
||||
'.ini', '.cfg', '.conf', '.config', '.env',
|
||||
'.gitignore', '.dockerignore', '.editorconfig',
|
||||
'.log', '.csv', '.tsv'
|
||||
}
|
||||
|
||||
# Get file extension
|
||||
file_lower = file_path.lower()
|
||||
for ext in text_extensions:
|
||||
if file_lower.endswith(ext):
|
||||
return True
|
||||
|
||||
# Also check MIME type
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
if mime_type and (mime_type.startswith("text") or mime_type in ["application/json", "application/xml"]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> Optional[str]:
|
||||
"""Fetch file content. Returns None if file should be skipped (binary files or empty files)."""
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> str:
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
|
||||
response = self._make_request(url)
|
||||
response = requests.get(url, headers=self.headers)
|
||||
|
||||
content = response.json()
|
||||
if response.status_code == 200:
|
||||
content = response.json()
|
||||
mime_type, _ = mimetypes.guess_type(file_path) # Guess the MIME type based on the file extension
|
||||
|
||||
if content.get("encoding") == "base64":
|
||||
if self.is_text_file(file_path): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8").strip()
|
||||
# Skip empty files
|
||||
if not decoded_content:
|
||||
return None
|
||||
return decoded_content
|
||||
except Exception:
|
||||
# If decoding fails, it's probably a binary file
|
||||
return None
|
||||
if content.get("encoding") == "base64":
|
||||
if mime_type and mime_type.startswith("text"): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8")
|
||||
return f"Filename: {file_path}\n\n{decoded_content}"
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return f"Filename: {file_path} is a binary file and was skipped."
|
||||
else:
|
||||
# Skip binary files by returning None
|
||||
return None
|
||||
return f"Filename: {file_path}\n\n{content['content']}"
|
||||
else:
|
||||
file_content = content['content'].strip()
|
||||
# Skip empty files
|
||||
if not file_content:
|
||||
return None
|
||||
return file_content
|
||||
|
||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
for attempt in range(max_retries):
|
||||
response = requests.get(url, headers=self.headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
elif response.status_code == 403:
|
||||
# Check if it's a rate limit issue
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("message", "")
|
||||
|
||||
# Check rate limit headers
|
||||
remaining = response.headers.get("X-RateLimit-Remaining", "unknown")
|
||||
reset_time = response.headers.get("X-RateLimit-Reset", "unknown")
|
||||
|
||||
print(f"GitHub API 403 Error: {error_msg}")
|
||||
print(f"Rate limit remaining: {remaining}, Reset time: {reset_time}")
|
||||
|
||||
if "rate limit" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt # Exponential backoff
|
||||
print(f"Rate limit hit, waiting {wait_time} seconds before retry...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# Provide helpful error message
|
||||
if remaining == "0":
|
||||
raise Exception(f"GitHub API rate limit exceeded. Please set GITHUB_ACCESS_TOKEN environment variable. Reset time: {reset_time}")
|
||||
else:
|
||||
raise Exception(f"GitHub API error: {error_msg}. This may require authentication - set GITHUB_ACCESS_TOKEN environment variable.")
|
||||
except Exception as e:
|
||||
if isinstance(e, Exception) and "GitHub API" in str(e):
|
||||
raise
|
||||
# If we can't parse the response, raise the original error
|
||||
response.raise_for_status()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
response.raise_for_status()
|
||||
|
||||
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
|
||||
response = self._make_request(url)
|
||||
|
||||
response = requests.get(url, headers={**self.headers, "Accept": "application/vnd.github.v3.raw"})
|
||||
contents = response.json()
|
||||
|
||||
# Handle error responses from GitHub API
|
||||
if isinstance(contents, dict) and "message" in contents:
|
||||
raise Exception(f"GitHub API error: {contents.get('message')}")
|
||||
|
||||
# Ensure contents is a list
|
||||
if not isinstance(contents, list):
|
||||
raise TypeError(f"Expected list from GitHub API, got {type(contents).__name__}: {contents}")
|
||||
|
||||
files = []
|
||||
for item in contents:
|
||||
if item["type"] == "file":
|
||||
@@ -144,15 +53,6 @@ class GitHubLoader(BaseRemote):
|
||||
documents = []
|
||||
for file_path in files:
|
||||
content = self.fetch_file_content(repo_name, file_path)
|
||||
# Skip binary files (content is None)
|
||||
if content is None:
|
||||
continue
|
||||
documents.append(Document(
|
||||
text=content,
|
||||
doc_id=file_path,
|
||||
extra_info={
|
||||
"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"
|
||||
}
|
||||
))
|
||||
documents.append(Document(page_content=content, metadata={"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"}))
|
||||
return documents
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
|
||||
|
||||
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing work without losing context.
|
||||
|
||||
Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:
|
||||
|
||||
1. Chronologically analyze each message, tool call and section of the conversation. For each section thoroughly identify:
|
||||
- The user's explicit requests and intents
|
||||
- Your approach to addressing the user's requests
|
||||
- Key decisions, concepts and patterns
|
||||
- Specific details like if applicable:
|
||||
- file names
|
||||
- full code snippets
|
||||
- function signatures
|
||||
- file edits
|
||||
- Errors that you ran into and how you fixed them
|
||||
- Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
|
||||
|
||||
2. Double-check for accuracy and completeness, addressing each required element thoroughly.
|
||||
|
||||
Your summary should include the following sections:
|
||||
|
||||
1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
|
||||
2. Key Concepts: List all important concepts discussed.
|
||||
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
|
||||
4. Errors and fixes: List all errors that you ran into, and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
|
||||
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
|
||||
6. All user messages: List ALL user messages that are not tool results. These are critical for understanding the users' feedback and changing intent.
|
||||
7. Tool Calls: List ALL tool calls made, including their inputs relevant parts of the outputs.
|
||||
8. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
|
||||
9. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
|
||||
10. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first.
|
||||
If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.
|
||||
|
||||
Please provide your summary based on the conversation and tools used so far, following this structure and ensuring precision and thoroughness in your response.
|
||||
@@ -2,7 +2,6 @@ anthropic==0.49.0
|
||||
boto3==1.38.18
|
||||
beautifulsoup4==4.13.4
|
||||
celery==5.4.0
|
||||
cryptography==42.0.8
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==7.5.2
|
||||
@@ -10,12 +9,10 @@ ebooklib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
elevenlabs==2.17.0
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.49.0
|
||||
google-genai==1.3.0
|
||||
google-api-python-client==2.179.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-auth-oauthlib==1.2.2
|
||||
@@ -58,13 +55,13 @@ prompt-toolkit==3.0.51
|
||||
protobuf==5.29.3
|
||||
psycopg2-binary==2.9.10
|
||||
py==1.11.0
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
pydantic-settings==2.7.1
|
||||
pymongo==4.11.3
|
||||
pypdf==5.5.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv
|
||||
python-dotenv==1.0.1
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
redis==5.2.1
|
||||
@@ -84,7 +81,7 @@ tzdata==2024.2
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug>=3.1.0,<3.1.2
|
||||
werkzeug==3.1.3
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
|
||||
@@ -5,6 +5,14 @@ class BaseRetriever(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_params(self):
|
||||
pass
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
@@ -15,33 +13,28 @@ class ClassicRAG(BaseRetriever):
|
||||
chat_history=None,
|
||||
prompt="",
|
||||
chunks=2,
|
||||
doc_token_limit=50000,
|
||||
model_id="docsgpt-local",
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = source.get("question", "")
|
||||
self.original_question = ""
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
self.prompt = prompt
|
||||
if isinstance(chunks, str):
|
||||
try:
|
||||
self.chunks = int(chunks)
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f"Invalid chunks value '{chunks}', using default value 2"
|
||||
)
|
||||
self.chunks = 2
|
||||
else:
|
||||
self.chunks = chunks
|
||||
user_identifier = user_api_key if user_api_key else "default"
|
||||
logging.info(
|
||||
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
|
||||
f"sources={'active_docs' in source and source['active_docs'] is not None}"
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
token_limit
|
||||
if token_limit
|
||||
< settings.LLM_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
else settings.LLM_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
)
|
||||
self.model_id = model_id
|
||||
self.doc_token_limit = doc_token_limit
|
||||
self.user_api_key = user_api_key
|
||||
self.llm_name = llm_name
|
||||
self.api_key = api_key
|
||||
@@ -51,48 +44,26 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
self.vectorstores = source["active_docs"]
|
||||
else:
|
||||
self.vectorstores = [source["active_docs"]]
|
||||
else:
|
||||
self.vectorstores = []
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
self.question = self._rephrase_query()
|
||||
self.decoded_token = decoded_token
|
||||
self._validate_vectorstore_config()
|
||||
|
||||
def _validate_vectorstore_config(self):
|
||||
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
||||
if not self.vectorstores:
|
||||
logging.warning("No vectorstores configured for retrieval")
|
||||
return
|
||||
invalid_ids = [
|
||||
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
||||
]
|
||||
if invalid_ids:
|
||||
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
||||
self.vectorstores = [
|
||||
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
||||
]
|
||||
|
||||
def _rephrase_query(self):
|
||||
"""Rephrase user query with chat history context for better retrieval"""
|
||||
if (
|
||||
not self.original_question
|
||||
or not self.chat_history
|
||||
or self.chat_history == []
|
||||
or self.chunks == 0
|
||||
or not self.vectorstores
|
||||
or self.vectorstore is None
|
||||
):
|
||||
return self.original_question
|
||||
prompt = (
|
||||
"Given the following conversation history:\n"
|
||||
f"{self.chat_history}\n\n"
|
||||
"Rephrase the following user question to be a standalone search query "
|
||||
"that captures all relevant context from the conversation:\n"
|
||||
)
|
||||
|
||||
prompt = f"""Given the following conversation history:
|
||||
{self.chat_history}
|
||||
|
||||
Rephrase the following user question to be a standalone search query
|
||||
that captures all relevant context from the conversation:
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": prompt},
|
||||
@@ -100,7 +71,7 @@ class ClassicRAG(BaseRetriever):
|
||||
]
|
||||
|
||||
try:
|
||||
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
|
||||
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
|
||||
print(f"Rephrased query: {rephrased_query}")
|
||||
return rephrased_query if rephrased_query else self.original_question
|
||||
except Exception as e:
|
||||
@@ -108,93 +79,46 @@ class ClassicRAG(BaseRetriever):
|
||||
return self.original_question
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0 or not self.vectorstores:
|
||||
logging.info(
|
||||
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
|
||||
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
|
||||
if self.chunks == 0 or self.vectorstore is None:
|
||||
docs = []
|
||||
else:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
return []
|
||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||
docs = [
|
||||
{
|
||||
"title": i.metadata.get(
|
||||
"title", i.metadata.get("post_title", i.page_content)
|
||||
).split("/")[-1],
|
||||
"text": i.page_content,
|
||||
"source": (
|
||||
i.metadata.get("source")
|
||||
if i.metadata.get("source")
|
||||
else "local"
|
||||
),
|
||||
}
|
||||
for i in docs_temp
|
||||
]
|
||||
|
||||
all_docs = []
|
||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||
token_budget = max(int(self.doc_token_limit * 0.9), 100)
|
||||
cumulative_tokens = 0
|
||||
return docs
|
||||
|
||||
for vectorstore_id in self.vectorstores:
|
||||
if vectorstore_id:
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs_temp = docsearch.search(
|
||||
self.question, k=max(chunks_per_source * 2, 20)
|
||||
)
|
||||
|
||||
for doc in docs_temp:
|
||||
if cumulative_tokens >= token_budget:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", page_content)
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title)
|
||||
title = title.split("/")[-1]
|
||||
|
||||
filename = (
|
||||
metadata.get("filename")
|
||||
or metadata.get("file_name")
|
||||
or metadata.get("source")
|
||||
)
|
||||
if isinstance(filename, str):
|
||||
filename = os.path.basename(filename) or filename
|
||||
else:
|
||||
filename = title
|
||||
if not filename:
|
||||
filename = title
|
||||
source_path = metadata.get("source") or vectorstore_id
|
||||
|
||||
doc_text_with_header = f"{filename}\n{page_content}"
|
||||
doc_tokens = num_tokens_from_string(doc_text_with_header)
|
||||
|
||||
if cumulative_tokens + doc_tokens < token_budget:
|
||||
all_docs.append(
|
||||
{
|
||||
"title": title,
|
||||
"text": page_content,
|
||||
"source": source_path,
|
||||
"filename": filename,
|
||||
}
|
||||
)
|
||||
cumulative_tokens += doc_tokens
|
||||
|
||||
if cumulative_tokens >= token_budget:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error searching vectorstore {vectorstore_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
logging.info(
|
||||
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
|
||||
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source}, "
|
||||
f"cumulative_tokens={cumulative_tokens}/{token_budget})"
|
||||
)
|
||||
return all_docs
|
||||
def gen():
|
||||
pass
|
||||
|
||||
def search(self, query: str = ""):
|
||||
"""Search for documents using optional query override"""
|
||||
if query:
|
||||
self.original_question = query
|
||||
self.question = self._rephrase_query()
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.original_question,
|
||||
"rephrased_question": self.question,
|
||||
"source": self.vectorstore,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||
|
||||
password = f"{app_secret}#{user_id}".encode()
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
return kdf.derive(password)
|
||||
|
||||
|
||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||
if not credentials:
|
||||
return ""
|
||||
try:
|
||||
salt = os.urandom(16)
|
||||
iv = os.urandom(16)
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
json_str = json.dumps(credentials)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
padded_data = _pad_data(json_str.encode())
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
result = salt + iv + encrypted_data
|
||||
return base64.b64encode(result).decode()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||
if not encrypted_data:
|
||||
return {}
|
||||
try:
|
||||
data = base64.b64decode(encrypted_data.encode())
|
||||
|
||||
salt = data[:16]
|
||||
iv = data[16:32]
|
||||
encrypted_content = data[32:]
|
||||
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||
decrypted_data = _unpad_data(decrypted_padded)
|
||||
|
||||
return json.loads(decrypted_data.decode())
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _pad_data(data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_len = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_len]) * padding_len
|
||||
return data + padding
|
||||
|
||||
|
||||
def _unpad_data(data: bytes) -> bytes:
|
||||
padding_len = data[-1]
|
||||
return data[:-padding_len]
|
||||
@@ -1,26 +0,0 @@
|
||||
import click
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.seed.seeder import DatabaseSeeder
|
||||
|
||||
|
||||
@click.group()
|
||||
def seed():
|
||||
"""Database seeding commands"""
|
||||
pass
|
||||
|
||||
|
||||
@seed.command()
|
||||
@click.option("--force", is_flag=True, help="Force reseeding even if data exists")
|
||||
def init(force):
|
||||
"""Initialize database with seed data"""
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
seeder = DatabaseSeeder(db)
|
||||
seeder.seed_initial_data(force=force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed()
|
||||
@@ -1,36 +0,0 @@
|
||||
# Configuration for Premade Agents
|
||||
# This file contains template agents that will be seeded into the database
|
||||
|
||||
agents:
|
||||
# Basic Agent Template
|
||||
- name: "Agent Name" # Required: Unique name for the agent
|
||||
description: "What this agent does" # Required: Brief description of the agent's purpose
|
||||
image: "URL_TO_IMAGE" # Optional: URL to agent's avatar/image
|
||||
agent_type: "classic" # Required: Type of agent (e.g., classic, react, etc.)
|
||||
prompt_id: "default" # Optional: Reference to prompt template
|
||||
prompt: # Optional: Define new prompt
|
||||
name: "New Prompt"
|
||||
content: "You are new agent with cool new prompt."
|
||||
chunks: "0" # Optional: Chunking strategy for documents
|
||||
retriever: "" # Optional: Retriever type for document search
|
||||
|
||||
# Source Configuration (where the agent gets its knowledge)
|
||||
source: # Optional: Select a source to link with agent
|
||||
name: "Source Display Name" # Human-readable name for the source
|
||||
url: "https://example.com/data-source" # URL or path to knowledge source
|
||||
loader: "url" # Type of loader (url, pdf, txt, etc.)
|
||||
|
||||
# Tools Configuration (what capabilities the agent has)
|
||||
tools: # Optional: Remove if agent doesn't need tools
|
||||
- name: "tool_name" # Must match a supported tool name
|
||||
display_name: "Tool Display Name" # Optional: Human-readable name for the tool
|
||||
config:
|
||||
# Tool-specific configuration
|
||||
# Example for DuckDuckGo:
|
||||
# token: "${DDG_API_KEY}" # ${} denotes environment variable
|
||||
|
||||
# Add more tools as needed
|
||||
# - name: "another_tool"
|
||||
# config:
|
||||
# param1: "value1"
|
||||
# param2: "${ENV_VAR}"
|
||||
@@ -1,94 +0,0 @@
|
||||
# Configuration for Premade Agents
|
||||
|
||||
agents:
|
||||
- name: "Assistant"
|
||||
description: "Your general-purpose AI assistant. Ready to help with a wide range of tasks."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-logo.svg"
|
||||
agent_type: "classic"
|
||||
prompt_id: "default"
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Researcher"
|
||||
description: "A specialized research agent that performs deep dives into subjects."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-researcher.svg"
|
||||
agent_type: "react"
|
||||
prompt:
|
||||
name: "Researcher-Agent"
|
||||
content: |
|
||||
You are a specialized AI research assistant, DocsGPT. Your primary function is to conduct in-depth research on a given subject or question. You are methodical, thorough, and analytical. You should perform multiple iterations of thinking to gather and synthesize information before providing a final, comprehensive answer.
|
||||
|
||||
You have access to the 'Read Webpage' tool. Use this tool to explore sources, gather data, and deepen your understanding. Be proactive in using the tool to fill in knowledge gaps and validate information.
|
||||
|
||||
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
|
||||
If appropriate, your answers can include code examples, formatted as follows:
|
||||
```(language)
|
||||
(code)
|
||||
```
|
||||
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses. Try to respond with mermaid charts if visualization helps with users queries. You effectively utilize chat history, ensuring relevant and tailored responses. Try to use additional provided context if it's available, otherwise use your knowledge and tool capabilities.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Search Widget"
|
||||
description: "A powerful search widget agent. Ask it anything about DocsGPT"
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-search.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Search-Agent"
|
||||
content: |
|
||||
You are a website search assistant, DocsGPT. Your sole purpose is to help users find information within the provided context of the DocsGPT documentation. Act as a specialized search engine.
|
||||
|
||||
Your answers must be based *only* on the provided context. Do not use any external knowledge. If the answer is not in the context, inform the user that you could not find the information within the documentation.
|
||||
|
||||
Keep your responses concise and directly related to the user's query, pointing them to the most relevant information.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
|
||||
- name: "Support Widget"
|
||||
description: "A friendly support widget agent to help you with any questions."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-support.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Support-Agent"
|
||||
content: |
|
||||
You are a helpful AI support widget agent, DocsGPT. Your goal is to assist users by answering their questions about our website, product and its features. Provide friendly, clear, and direct support.
|
||||
|
||||
Your knowledge is strictly limited to the provided context from the DocsGPT documentation. You must not answer questions outside of this scope. If a user asks something you cannot answer from the context, politely state that you can only help with questions about this website.
|
||||
|
||||
Effectively utilize chat history to understand the user's issue fully. Guide users to the information they need in a helpful and conversational manner.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
@@ -1,277 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from bson import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api.user.tasks import ingest_remote
|
||||
|
||||
load_dotenv()
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
class DatabaseSeeder:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.tools_collection = self.db["user_tools"]
|
||||
self.sources_collection = self.db["sources"]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
self.system_user_id = "system"
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def seed_initial_data(self, config_path: str = None, force=False):
|
||||
"""Main entry point for seeding all initial data"""
|
||||
if not force and self._is_already_seeded():
|
||||
self.logger.info("Database already seeded. Use force=True to reseed.")
|
||||
return
|
||||
config_path = config_path or os.path.join(
|
||||
os.path.dirname(__file__), "config", "premade_agents.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
self._seed_from_config(config)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load seeding config: {str(e)}")
|
||||
raise
|
||||
|
||||
def _seed_from_config(self, config: Dict):
|
||||
"""Seed all data from configuration"""
|
||||
self.logger.info("🌱 Starting seeding...")
|
||||
|
||||
if not config.get("agents"):
|
||||
self.logger.warning("No agents found in config")
|
||||
return
|
||||
used_tool_ids = set()
|
||||
|
||||
for agent_config in config["agents"]:
|
||||
try:
|
||||
self.logger.info(f"Processing agent: {agent_config['name']}")
|
||||
|
||||
# 1. Handle Source
|
||||
|
||||
source_result = self._handle_source(agent_config)
|
||||
if source_result is False:
|
||||
self.logger.error(
|
||||
f"Skipping agent {agent_config['name']} due to source ingestion failure"
|
||||
)
|
||||
continue
|
||||
source_id = source_result
|
||||
# 2. Handle Tools
|
||||
|
||||
tool_ids = self._handle_tools(agent_config)
|
||||
if len(tool_ids) == 0:
|
||||
self.logger.warning(
|
||||
f"No valid tools for agent {agent_config['name']}"
|
||||
)
|
||||
used_tool_ids.update(tool_ids)
|
||||
|
||||
# 3. Handle Prompt
|
||||
|
||||
prompt_id = self._handle_prompt(agent_config)
|
||||
|
||||
# 4. Create Agent
|
||||
|
||||
agent_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": agent_config["name"],
|
||||
"description": agent_config["description"],
|
||||
"image": agent_config.get("image", ""),
|
||||
"source": (
|
||||
DBRef("sources", ObjectId(source_id)) if source_id else ""
|
||||
),
|
||||
"tools": [str(tid) for tid in tool_ids],
|
||||
"agent_type": agent_config["agent_type"],
|
||||
"prompt_id": prompt_id or agent_config.get("prompt_id", "default"),
|
||||
"chunks": agent_config.get("chunks", "0"),
|
||||
"retriever": agent_config.get("retriever", ""),
|
||||
"status": "template",
|
||||
"createdAt": datetime.now(timezone.utc),
|
||||
"updatedAt": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
existing = self.agents_collection.find_one(
|
||||
{"user": self.system_user_id, "name": agent_config["name"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Updating existing agent: {agent_config['name']}")
|
||||
self.agents_collection.update_one(
|
||||
{"_id": existing["_id"]}, {"$set": agent_data}
|
||||
)
|
||||
agent_id = existing["_id"]
|
||||
else:
|
||||
self.logger.info(f"Creating new agent: {agent_config['name']}")
|
||||
result = self.agents_collection.insert_one(agent_data)
|
||||
agent_id = result.inserted_id
|
||||
self.logger.info(
|
||||
f"Successfully processed agent: {agent_config['name']} (ID: {agent_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Error processing agent {agent_config['name']}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
self.logger.info("✅ Database seeding completed")
|
||||
|
||||
def _handle_source(self, agent_config: Dict) -> Union[ObjectId, None, bool]:
|
||||
"""Handle source ingestion and return source ID"""
|
||||
if not agent_config.get("source"):
|
||||
self.logger.info(
|
||||
"No source provided for agent - will create agent without source"
|
||||
)
|
||||
return None
|
||||
source_config = agent_config["source"]
|
||||
self.logger.info(f"Ingesting source: {source_config['url']}")
|
||||
|
||||
try:
|
||||
existing = self.sources_collection.find_one(
|
||||
{"user": self.system_user_id, "remote_data": source_config["url"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Source already exists: {existing['_id']}")
|
||||
return existing["_id"]
|
||||
# Ingest new source using worker
|
||||
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_config["url"],
|
||||
job_name=source_config["name"],
|
||||
user=self.system_user_id,
|
||||
loader=source_config.get("loader", "url"),
|
||||
)
|
||||
|
||||
result = task.get(timeout=300)
|
||||
|
||||
if not task.successful():
|
||||
raise Exception(f"Source ingestion failed: {result}")
|
||||
source_id = None
|
||||
if isinstance(result, dict) and "id" in result:
|
||||
source_id = result["id"]
|
||||
else:
|
||||
raise Exception(f"Source ingestion result missing 'id': {result}")
|
||||
self.logger.info(f"Source ingested successfully: {source_id}")
|
||||
return source_id
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to ingest source: {str(e)}")
|
||||
return False
|
||||
|
||||
def _handle_tools(self, agent_config: Dict) -> List[ObjectId]:
|
||||
"""Handle tool creation and return list of tool IDs"""
|
||||
tool_ids = []
|
||||
if not agent_config.get("tools"):
|
||||
return tool_ids
|
||||
for tool_config in agent_config["tools"]:
|
||||
try:
|
||||
tool_name = tool_config["name"]
|
||||
processed_config = self._process_config(tool_config.get("config", {}))
|
||||
self.logger.info(f"Processing tool: {tool_name}")
|
||||
|
||||
existing = self.tools_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"config": processed_config,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Tool already exists: {existing['_id']}")
|
||||
tool_ids.append(existing["_id"])
|
||||
continue
|
||||
tool_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"displayName": tool_config.get("display_name", tool_name),
|
||||
"description": tool_config.get("description", ""),
|
||||
"actions": tool_manager.tools[tool_name].get_actions_metadata(),
|
||||
"config": processed_config,
|
||||
"status": True,
|
||||
}
|
||||
|
||||
result = self.tools_collection.insert_one(tool_data)
|
||||
tool_ids.append(result.inserted_id)
|
||||
self.logger.info(f"Created new tool: {result.inserted_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
|
||||
continue
|
||||
return tool_ids
|
||||
|
||||
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
|
||||
"""Handle prompt creation and return prompt ID"""
|
||||
if not agent_config.get("prompt"):
|
||||
return None
|
||||
|
||||
prompt_config = agent_config["prompt"]
|
||||
prompt_name = prompt_config.get("name", f"{agent_config['name']} Prompt")
|
||||
prompt_content = prompt_config.get("content", "")
|
||||
|
||||
if not prompt_content:
|
||||
self.logger.warning(
|
||||
f"No prompt content provided for agent {agent_config['name']}"
|
||||
)
|
||||
return None
|
||||
|
||||
self.logger.info(f"Processing prompt: {prompt_name}")
|
||||
|
||||
try:
|
||||
existing = self.prompts_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Prompt already exists: {existing['_id']}")
|
||||
return str(existing["_id"])
|
||||
|
||||
prompt_data = {
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
"user": self.system_user_id,
|
||||
}
|
||||
|
||||
result = self.prompts_collection.insert_one(prompt_data)
|
||||
prompt_id = str(result.inserted_id)
|
||||
self.logger.info(f"Created new prompt: {prompt_id}")
|
||||
return prompt_id
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process prompt {prompt_name}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _process_config(self, config: Dict) -> Dict:
|
||||
"""Process config values to replace environment variables"""
|
||||
processed = {}
|
||||
for key, value in config.items():
|
||||
if (
|
||||
isinstance(value, str)
|
||||
and value.startswith("${")
|
||||
and value.endswith("}")
|
||||
):
|
||||
env_var = value[2:-1]
|
||||
processed[key] = os.getenv(env_var, "")
|
||||
else:
|
||||
processed[key] = value
|
||||
return processed
|
||||
|
||||
def _is_already_seeded(self) -> bool:
|
||||
"""Check if premade agents already exist"""
|
||||
return self.agents_collection.count_documents({"user": self.system_user_id}) > 0
|
||||
|
||||
@classmethod
|
||||
def initialize_from_env(cls, worker=None):
|
||||
"""Factory method to create seeder from environment"""
|
||||
mongo_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
|
||||
db_name = os.getenv("MONGO_DB_NAME", "docsgpt")
|
||||
client = MongoClient(mongo_uri)
|
||||
db = client[db_name]
|
||||
return cls(db)
|
||||
@@ -26,7 +26,7 @@ class LocalStorage(BaseStorage):
|
||||
return path
|
||||
return os.path.join(self.base_dir, path)
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
def save_file(self, file_data: BinaryIO, path: str) -> dict:
|
||||
"""Save a file to local storage."""
|
||||
full_path = self._get_full_path(path)
|
||||
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NamespaceBuilder(ABC):
|
||||
"""Base class for building template context namespaces"""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Build namespace context dictionary"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def namespace_name(self) -> str:
|
||||
"""Name of this namespace for template access"""
|
||||
pass
|
||||
|
||||
|
||||
class SystemNamespace(NamespaceBuilder):
|
||||
"""System metadata namespace: {{ system.* }}"""
|
||||
|
||||
@property
|
||||
def namespace_name(self) -> str:
|
||||
return "system"
|
||||
|
||||
def build(
|
||||
self, request_id: Optional[str] = None, user_id: Optional[str] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build system context with metadata.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier
|
||||
user_id: Current user identifier
|
||||
|
||||
Returns:
|
||||
Dictionary with system variables
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"timestamp": now.isoformat(),
|
||||
"request_id": request_id or str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
|
||||
class PassthroughNamespace(NamespaceBuilder):
|
||||
"""Request parameters namespace: {{ passthrough.* }}"""
|
||||
|
||||
@property
|
||||
def namespace_name(self) -> str:
|
||||
return "passthrough"
|
||||
|
||||
def build(
|
||||
self, passthrough_data: Optional[Dict[str, Any]] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build passthrough context from request parameters.
|
||||
|
||||
Args:
|
||||
passthrough_data: Dictionary of parameters from web request
|
||||
|
||||
Returns:
|
||||
Dictionary with passthrough variables
|
||||
"""
|
||||
if not passthrough_data:
|
||||
return {}
|
||||
safe_data = {}
|
||||
for key, value in passthrough_data.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
safe_data[key] = value
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping non-serializable passthrough value for key '{key}': {type(value)}"
|
||||
)
|
||||
return safe_data
|
||||
|
||||
|
||||
class SourceNamespace(NamespaceBuilder):
|
||||
"""RAG source documents namespace: {{ source.* }}"""
|
||||
|
||||
@property
|
||||
def namespace_name(self) -> str:
|
||||
return "source"
|
||||
|
||||
def build(
|
||||
self, docs: Optional[list] = None, docs_together: Optional[str] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build source context from RAG retrieval results.
|
||||
|
||||
Args:
|
||||
docs: List of retrieved documents
|
||||
docs_together: Concatenated document content (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Dictionary with source variables
|
||||
"""
|
||||
context = {}
|
||||
|
||||
if docs:
|
||||
context["documents"] = docs
|
||||
context["count"] = len(docs)
|
||||
if docs_together:
|
||||
context["docs_together"] = docs_together # Add docs_together for custom templates
|
||||
context["content"] = docs_together
|
||||
context["summaries"] = docs_together
|
||||
return context
|
||||
|
||||
|
||||
class ToolsNamespace(NamespaceBuilder):
|
||||
"""Pre-executed tools namespace: {{ tools.* }}"""
|
||||
|
||||
@property
|
||||
def namespace_name(self) -> str:
|
||||
return "tools"
|
||||
|
||||
def build(
|
||||
self, tools_data: Optional[Dict[str, Any]] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tools context with pre-executed tool results.
|
||||
|
||||
Args:
|
||||
tools_data: Dictionary of pre-fetched tool results organized by tool name
|
||||
e.g., {"memory": {"notes": "content", "tasks": "list"}}
|
||||
|
||||
Returns:
|
||||
Dictionary with tool results organized by tool name
|
||||
"""
|
||||
if not tools_data:
|
||||
return {}
|
||||
|
||||
safe_data = {}
|
||||
for tool_name, tool_result in tools_data.items():
|
||||
if isinstance(tool_result, (str, dict, list, int, float, bool, type(None))):
|
||||
safe_data[tool_name] = tool_result
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping non-serializable tool result for '{tool_name}': {type(tool_result)}"
|
||||
)
|
||||
return safe_data
|
||||
|
||||
|
||||
class NamespaceManager:
|
||||
"""Manages all namespace builders and context assembly"""
|
||||
|
||||
def __init__(self):
|
||||
self._builders = {
|
||||
"system": SystemNamespace(),
|
||||
"passthrough": PassthroughNamespace(),
|
||||
"source": SourceNamespace(),
|
||||
"tools": ToolsNamespace(),
|
||||
}
|
||||
|
||||
def build_context(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Build complete template context from all namespaces.
|
||||
|
||||
Args:
|
||||
**kwargs: Parameters to pass to namespace builders
|
||||
|
||||
Returns:
|
||||
Complete context dictionary for template rendering
|
||||
"""
|
||||
context = {}
|
||||
|
||||
for namespace_name, builder in self._builders.items():
|
||||
try:
|
||||
namespace_context = builder.build(**kwargs)
|
||||
# Always include namespace, even if empty, to prevent undefined errors
|
||||
context[namespace_name] = namespace_context if namespace_context else {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build {namespace_name} namespace: {str(e)}")
|
||||
# Include empty namespace on error to prevent template failures
|
||||
context[namespace_name] = {}
|
||||
return context
|
||||
|
||||
def get_builder(self, namespace_name: str) -> Optional[NamespaceBuilder]:
|
||||
"""Get specific namespace builder"""
|
||||
return self._builders.get(namespace_name)
|
||||
@@ -1,161 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from jinja2 import (
|
||||
ChainableUndefined,
|
||||
Environment,
|
||||
nodes,
|
||||
select_autoescape,
|
||||
TemplateSyntaxError,
|
||||
)
|
||||
from jinja2.exceptions import UndefinedError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderError(Exception):
|
||||
"""Raised when template rendering fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TemplateEngine:
|
||||
"""Jinja2-based template engine for dynamic prompt rendering"""
|
||||
|
||||
def __init__(self):
|
||||
self._env = Environment(
|
||||
undefined=ChainableUndefined,
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
autoescape=select_autoescape(default_for_string=True, default=True),
|
||||
)
|
||||
|
||||
def render(self, template_content: str, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Render template with provided context.
|
||||
|
||||
Args:
|
||||
template_content: Raw template string with Jinja2 syntax
|
||||
context: Dictionary of variables to inject into template
|
||||
|
||||
Returns:
|
||||
Rendered template string
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template syntax is invalid or variables undefined
|
||||
"""
|
||||
if not template_content:
|
||||
return ""
|
||||
try:
|
||||
template = self._env.from_string(template_content)
|
||||
return template.render(**context)
|
||||
except TemplateSyntaxError as e:
|
||||
error_msg = f"Template syntax error at line {e.lineno}: {e.message}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
except UndefinedError as e:
|
||||
error_msg = f"Undefined variable in template: {e.message}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Template rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def validate_template(self, template_content: str) -> bool:
|
||||
"""
|
||||
Validate template syntax without rendering.
|
||||
|
||||
Args:
|
||||
template_content: Template string to validate
|
||||
|
||||
Returns:
|
||||
True if template is syntactically valid
|
||||
"""
|
||||
if not template_content:
|
||||
return True
|
||||
try:
|
||||
self._env.from_string(template_content)
|
||||
return True
|
||||
except TemplateSyntaxError as e:
|
||||
logger.debug(f"Template syntax invalid at line {e.lineno}: {e.message}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(f"Template validation error: {type(e).__name__}: {str(e)}")
|
||||
return False
|
||||
|
||||
def extract_variables(self, template_content: str) -> Set[str]:
|
||||
"""
|
||||
Extract all variable names from template.
|
||||
|
||||
Args:
|
||||
template_content: Template string to analyze
|
||||
|
||||
Returns:
|
||||
Set of variable names found in template
|
||||
"""
|
||||
if not template_content:
|
||||
return set()
|
||||
try:
|
||||
ast = self._env.parse(template_content)
|
||||
return set(self._env.get_template_module(ast).make_module().keys())
|
||||
except TemplateSyntaxError as e:
|
||||
logger.debug(f"Cannot extract variables - syntax error at line {e.lineno}")
|
||||
return set()
|
||||
except Exception as e:
|
||||
logger.debug(f"Cannot extract variables: {type(e).__name__}")
|
||||
return set()
|
||||
|
||||
def extract_tool_usages(
|
||||
self, template_content: str
|
||||
) -> Dict[str, Set[Optional[str]]]:
|
||||
"""Extract tool and action references from a template"""
|
||||
if not template_content:
|
||||
return {}
|
||||
try:
|
||||
ast = self._env.parse(template_content)
|
||||
except TemplateSyntaxError as e:
|
||||
logger.debug(f"extract_tool_usages - syntax error at line {e.lineno}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.debug(f"extract_tool_usages - parse error: {type(e).__name__}")
|
||||
return {}
|
||||
|
||||
usages: Dict[str, Set[Optional[str]]] = {}
|
||||
|
||||
def record(path: List[str]) -> None:
|
||||
if not path:
|
||||
return
|
||||
tool_name = path[0]
|
||||
action_name = path[1] if len(path) > 1 else None
|
||||
if not tool_name:
|
||||
return
|
||||
tool_entry = usages.setdefault(tool_name, set())
|
||||
tool_entry.add(action_name)
|
||||
|
||||
for node in ast.find_all(nodes.Getattr):
|
||||
path = []
|
||||
current = node
|
||||
while isinstance(current, nodes.Getattr):
|
||||
path.append(current.attr)
|
||||
current = current.node
|
||||
if isinstance(current, nodes.Name) and current.name == "tools":
|
||||
path.reverse()
|
||||
record(path)
|
||||
|
||||
for node in ast.find_all(nodes.Getitem):
|
||||
path = []
|
||||
current = node
|
||||
while isinstance(current, nodes.Getitem):
|
||||
key = current.arg
|
||||
if isinstance(key, nodes.Const) and isinstance(key.value, str):
|
||||
path.append(key.value)
|
||||
else:
|
||||
path = []
|
||||
break
|
||||
current = current.node
|
||||
if path and isinstance(current, nodes.Name) and current.name == "tools":
|
||||
path.reverse()
|
||||
record(path)
|
||||
|
||||
return usages
|
||||
@@ -1,31 +1,84 @@
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from application.tts.base import BaseTTS
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class ElevenlabsTTS(BaseTTS):
|
||||
def __init__(self):
|
||||
from elevenlabs.client import ElevenLabs
|
||||
|
||||
self.client = ElevenLabs(
|
||||
api_key=settings.ELEVENLABS_API_KEY,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = 'ELEVENLABS_API_KEY'# here you should put your api key
|
||||
self.model = "eleven_flash_v2_5"
|
||||
self.voice = "VOICE_ID" # this is the hash code for the voice not the name!
|
||||
self.write_audio = 1
|
||||
|
||||
def text_to_speech(self, text):
|
||||
lang = "en"
|
||||
audio = self.client.text_to_speech.convert(
|
||||
voice_id="nPczCjzI2devNBz1zQrb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
text=text,
|
||||
output_format="mp3_44100_128"
|
||||
)
|
||||
audio_data = BytesIO()
|
||||
for chunk in audio:
|
||||
audio_data.write(chunk)
|
||||
audio_bytes = audio_data.getvalue()
|
||||
asyncio.run(self._text_to_speech_websocket(text))
|
||||
|
||||
# Encode to base64
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
return audio_base64, lang
|
||||
async def _text_to_speech_websocket(self, text):
|
||||
uri = f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice}/stream-input?model_id={self.model}"
|
||||
websocket = await websockets.connect(uri)
|
||||
payload = {
|
||||
"text": " ",
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.8,
|
||||
},
|
||||
"xi_api_key": self.api_key,
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps(payload))
|
||||
|
||||
async def listen():
|
||||
while 1:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
data = json.loads(msg)
|
||||
|
||||
if data.get("audio"):
|
||||
print("audio received")
|
||||
yield base64.b64decode(data["audio"])
|
||||
elif data.get("isFinal"):
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print("websocket closed")
|
||||
break
|
||||
listen_task = asyncio.create_task(self.stream(listen()))
|
||||
|
||||
await websocket.send(json.dumps({"text": text}))
|
||||
# this is to signal the end of the text, either use this or flush
|
||||
await websocket.send(json.dumps({"text": ""}))
|
||||
|
||||
await listen_task
|
||||
|
||||
async def stream(self, audio_stream):
|
||||
if self.write_audio:
|
||||
audio_bytes = BytesIO()
|
||||
async for chunk in audio_stream:
|
||||
if chunk:
|
||||
audio_bytes.write(chunk)
|
||||
with open("output_audio.mp3", "wb") as f:
|
||||
f.write(audio_bytes.getvalue())
|
||||
|
||||
else:
|
||||
async for chunk in audio_stream:
|
||||
pass # depends on the streamer!
|
||||
|
||||
|
||||
def test_elevenlabs_websocket():
|
||||
"""
|
||||
Tests the ElevenlabsTTS text_to_speech method with a sample prompt.
|
||||
Prints out the base64-encoded result and writes it to 'output_audio.mp3'.
|
||||
"""
|
||||
# Instantiate your TTS class
|
||||
tts = ElevenlabsTTS()
|
||||
|
||||
# Call the method with some sample text
|
||||
tts.text_to_speech("Hello from ElevenLabs WebSocket!")
|
||||
|
||||
print("Saved audio to output_audio.mp3.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_elevenlabs_websocket()
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.elevenlabs import ElevenlabsTTS
|
||||
from application.tts.base import BaseTTS
|
||||
|
||||
|
||||
|
||||
class TTSCreator:
|
||||
tts_providers = {
|
||||
"google_tts": GoogleTTS,
|
||||
"elevenlabs": ElevenlabsTTS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
|
||||
tts_class = cls.tts_providers.get(tts_type.lower())
|
||||
if not tts_class:
|
||||
raise ValueError(f"No tts class found for type {tts_type}")
|
||||
return tts_class(*args, **kwargs)
|
||||
@@ -7,8 +7,6 @@ import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@@ -23,7 +21,7 @@ def get_encoding():
|
||||
|
||||
|
||||
def get_gpt_model() -> str:
|
||||
"""Get GPT model based on provider"""
|
||||
"""Get the appropriate GPT model based on provider"""
|
||||
model_map = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-2",
|
||||
@@ -34,7 +32,16 @@ def get_gpt_model() -> str:
|
||||
|
||||
|
||||
def safe_filename(filename):
|
||||
"""Create safe filename, preserving extension. Handles non-Latin characters."""
|
||||
"""
|
||||
Creates a safe filename that preserves the original extension.
|
||||
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
|
||||
|
||||
Args:
|
||||
filename (str): The original filename
|
||||
|
||||
Returns:
|
||||
str: A safe filename that can be used for storage
|
||||
"""
|
||||
if not filename:
|
||||
return str(uuid.uuid4())
|
||||
_, extension = os.path.splitext(filename)
|
||||
@@ -76,23 +83,8 @@ def count_tokens_docs(docs):
|
||||
return tokens
|
||||
|
||||
|
||||
def calculate_doc_token_budget(
|
||||
model_id: str = "gpt-4o", history_token_limit: int = 2000
|
||||
) -> int:
|
||||
total_context = get_token_limit(model_id)
|
||||
reserved = sum(settings.RESERVED_TOKENS.values())
|
||||
doc_budget = total_context - history_token_limit - reserved
|
||||
return max(doc_budget, 1000)
|
||||
|
||||
|
||||
def get_missing_fields(data, required_fields):
|
||||
"""Check for missing required fields. Returns list of missing field names."""
|
||||
return [field for field in required_fields if field not in data]
|
||||
|
||||
|
||||
def check_required_fields(data, required_fields):
|
||||
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
|
||||
missing_fields = get_missing_fields(data, required_fields)
|
||||
missing_fields = [field for field in required_fields if field not in data]
|
||||
if missing_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -106,8 +98,7 @@ def check_required_fields(data, required_fields):
|
||||
return None
|
||||
|
||||
|
||||
def get_field_validation_errors(data, required_fields):
|
||||
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
|
||||
def validate_required_fields(data, required_fields):
|
||||
missing_fields = []
|
||||
empty_fields = []
|
||||
|
||||
@@ -116,24 +107,12 @@ def get_field_validation_errors(data, required_fields):
|
||||
missing_fields.append(field)
|
||||
elif not data[field]:
|
||||
empty_fields.append(field)
|
||||
if missing_fields or empty_fields:
|
||||
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
|
||||
return None
|
||||
|
||||
|
||||
def validate_required_fields(data, required_fields):
|
||||
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
|
||||
errors_dict = get_field_validation_errors(data, required_fields)
|
||||
if errors_dict:
|
||||
errors = []
|
||||
if errors_dict["missing_fields"]:
|
||||
errors.append(
|
||||
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
|
||||
)
|
||||
if errors_dict["empty_fields"]:
|
||||
errors.append(
|
||||
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
|
||||
)
|
||||
errors = []
|
||||
if missing_fields:
|
||||
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
if empty_fields:
|
||||
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
||||
if errors:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
||||
)
|
||||
@@ -144,13 +123,19 @@ def get_hash(data):
|
||||
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
|
||||
"""Limit chat history to fit within token limit."""
|
||||
model_token_limit = get_token_limit(model_id)
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
"""
|
||||
Limits chat history based on token count.
|
||||
Returns a list of messages that fit within the token limit.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
max_token_limit = (
|
||||
max_token_limit
|
||||
if max_token_limit and max_token_limit < model_token_limit
|
||||
else model_token_limit
|
||||
if max_token_limit
|
||||
and max_token_limit
|
||||
< settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
)
|
||||
|
||||
if not history:
|
||||
@@ -176,17 +161,13 @@ def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
|
||||
|
||||
|
||||
def validate_function_name(function_name):
|
||||
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
|
||||
"""Validates if a function name matches the allowed pattern."""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def generate_image_url(image_path):
|
||||
if isinstance(image_path, str) and (
|
||||
image_path.startswith("http://") or image_path.startswith("https://")
|
||||
):
|
||||
return image_path
|
||||
strategy = getattr(settings, "URL_STRATEGY", "backend")
|
||||
if strategy == "s3":
|
||||
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
|
||||
@@ -195,69 +176,3 @@ def generate_image_url(image_path):
|
||||
else:
|
||||
base_url = getattr(settings, "API_URL", "http://localhost:7091")
|
||||
return f"{base_url}/api/images/{image_path}"
|
||||
|
||||
|
||||
def calculate_compression_threshold(
|
||||
model_id: str, threshold_percentage: float = 0.8
|
||||
) -> int:
|
||||
"""
|
||||
Calculate token threshold for triggering compression.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
threshold_percentage: Percentage of context window (default 80%)
|
||||
|
||||
Returns:
|
||||
Token count threshold
|
||||
"""
|
||||
total_context = get_token_limit(model_id)
|
||||
threshold = int(total_context * threshold_percentage)
|
||||
return threshold
|
||||
|
||||
|
||||
def clean_text_for_tts(text: str) -> str:
|
||||
"""
|
||||
clean text for Text-to-Speech processing.
|
||||
"""
|
||||
# Handle code blocks and links
|
||||
|
||||
text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text) ## ```mermaid...```
|
||||
text = re.sub(r"```[\s\S]*?```", " code block, ", text) ## ```code```
|
||||
text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) ## [text](url)
|
||||
text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text) ## 
|
||||
|
||||
# Remove markdown formatting
|
||||
|
||||
text = re.sub(r"`([^`]+)`", r"\1", text) ## `code`
|
||||
text = re.sub(r"\{([^}]*)\}", r" \1 ", text) ## {text}
|
||||
text = re.sub(r"[{}]", " ", text) ## unmatched {}
|
||||
text = re.sub(r"\[([^\]]+)\]", r" \1 ", text) ## [text]
|
||||
text = re.sub(r"[\[\]]", " ", text) ## unmatched []
|
||||
text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text) ## **bold** __bold__
|
||||
text = re.sub(r"(\*|_)(.*?)\1", r"\2", text) ## *italic* _italic_
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) ## # headers
|
||||
text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE) ## > blockquotes
|
||||
text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE) ## - * + lists
|
||||
text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE) ## 1. numbered lists
|
||||
text = re.sub(
|
||||
r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
|
||||
) ## --- *** ___ rules
|
||||
text = re.sub(r"<[^>]*>", "", text) ## <html> tags
|
||||
|
||||
# Remove non-ASCII (emojis, special Unicode)
|
||||
|
||||
text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)
|
||||
|
||||
# Replace special sequences
|
||||
|
||||
text = re.sub(r"-->", ", ", text) ## -->
|
||||
text = re.sub(r"<--", ", ", text) ## <--
|
||||
text = re.sub(r"=>", ", ", text) ## =>
|
||||
text = re.sub(r"::", " ", text) ## ::
|
||||
|
||||
# Normalize whitespace
|
||||
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
|
||||
return text
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user