Compare commits

...

76 Commits

Author SHA1 Message Date
dependabot[bot]
4c7a6a78aa chore(deps): bump docker/setup-qemu-action from 3 to 4
Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 3 to 4.
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](https://github.com/docker/setup-qemu-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: docker/setup-qemu-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-04 20:54:17 +00:00
Alex
a6625ec5de fix: mini workflow fixes 2026-02-22 11:10:42 +00:00
Alex
1a2104f474 fix: token calc (#2285) 2026-02-20 17:37:47 +00:00
Alex
444abb8283 fix search nextra 2026-02-18 18:03:03 +00:00
Alex
ee86537f21 docs: add llms.txt and enable copy code button in nextra 2026-02-18 17:54:25 +00:00
Alex
17a736a927 docs: migrate to Nextra 4 and Next.js App Router 2026-02-18 17:13:24 +00:00
Alex
6b5779054d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-02-17 18:46:35 +00:00
Alex
14296632ef build(docs): upgrade nextra to v3 and update config 2026-02-17 18:46:28 +00:00
Siddhant Rai
2a3f0e455a feat: condition node functionality with CEL evaluation in Workflows (#2280)
* feat: add condition node functionality with CEL evaluation

- Introduced ConditionNode to support conditional branching in workflows.
- Implemented CEL evaluation for state updates and condition expressions.
- Updated WorkflowEngine to handle condition nodes and their execution logic.
- Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions.
- Modified frontend components to support new condition node type and its configuration.
- Added necessary types and interfaces for condition cases and state operations.
- Updated requirements to include cel-python for expression evaluation.
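
The bullets above introduce CEL-based condition evaluation via cel-python. As a rough, hypothetical sketch only (the function name, state layout, and example expression below are illustrative assumptions, not the repository's actual WorkflowEngine code), a condition-node expression can be evaluated against workflow state like this:

# Hypothetical illustration of evaluating a condition expression with cel-python (celpy).
import celpy

def evaluate_condition(expression: str, state: dict) -> bool:
    """Compile a CEL expression and evaluate it against the current workflow state."""
    env = celpy.Environment()
    program = env.program(env.compile(expression))
    # Expose each top-level state key as a CEL variable.
    activation = {key: celpy.json_to_cel(value) for key, value in state.items()}
    return bool(program.evaluate(activation))

# Example: take the "summarize" branch only for long English documents.
evaluate_condition("doc.word_count > 500 && doc.language == 'en'",
                   {"doc": {"word_count": 1200, "language": "en"}})  # -> True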

* mini-fixes

* feat(workflow): improve UX

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-17 17:29:48 +00:00
Pavel
8aa44c415b Advanced settings (#2281)
Add additional settings to setup scripts
2026-02-17 11:54:59 +00:00
Manish Madan
0566c41a32 Merge pull request #2279 from yusufazam225/main
Fix XSS vulnerability: replace dangerouslySetInnerHTML with safe React rendering in PromptsModal
2026-02-15 19:27:37 +05:30
Manish Madan
876b04c058 Artifacts-backed persistence for Agent “Self” tools (Notes / Todo) + streaming artifact_id support (#2267)
* (feat:memory) use fs/storage for files

* (feat:todo) artifact_id via sse

* (feat:notes) artifact id return

* (feat:artifact) add get endpoint, store todos with conv id

* (feat: artifacts) fe integration

* feat(artifacts): ui enhancements, notes as mkdwn

* chore(artifacts) updated artifact tests

* (feat:todo_tool) return all todo items

* (feat:tools) use specific tool names in bubble

* feat: add conversationId prop to artifact components in Conversation

* Revert "(feat:memory) use fs/storage for files"

This reverts commit d1ce3bea31.

* (fix:fe) build fail
2026-02-15 00:08:37 +00:00
Yusuf Azam
b49a5934e2 Fix XSS vulnerability: replace dangerouslySetInnerHTML with safe React rendering 2026-02-14 17:18:52 +05:30
Alex
5fb063914e fix: frontend lib 2026-02-12 16:01:49 +00:00
Pavel
b9941e29a9 Scrollbar normalize styles (#2277)
* normalize styles

* Fix agent subhead width

* Dark scrollbar color adjust

* different browser support

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-12 15:19:42 +00:00
Siddhant Rai
8ef321d784 feat: agent workflow builder (#2264)
* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution

* refactor: workflow schemas and introduce WorkflowEngine

- Updated schemas in `schemas.py` to include new agent types and configurations.
- Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution.
- Enhanced `StreamProcessor` to handle workflow-related data.
- Added new routes and utilities for managing workflows in the user API.
- Implemented validation and serialization functions for workflows.
- Established MongoDB collections and indexes for workflows and related entities.

* refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine

* feat: workflow builder and managing in frontend

- Added new endpoints for workflows in `endpoints.ts`.
- Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`.
- Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects.
- Enhanced styling in `index.css` with new theme variables and animations.
- Refactored modal components for better layout and styling.
- Configured TypeScript paths and Vite aliases for cleaner imports.

* feat: add workflow preview component and related state management

- Implemented WorkflowPreview component for displaying workflow execution.
- Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps.
- Added WorkflowMiniMap for visual representation of workflow nodes and their statuses.
- Integrated conversation handling with the ability to fetch answers and manage query states.
- Introduced reusable Sheet component for UI overlays.
- Updated Redux store to include workflowPreview reducer.

* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview

* feat: enhance workflow components with improved UI and functionality

- Updated WorkflowPreview to allow text truncation for better display of long names.
- Enhanced BaseNode with connectable handles and improved styling for better visibility.
- Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder.
- Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition.

* feat(workflow): add owner validation and graph version support

* fix: ruff lint

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-11 14:15:24 +00:00
Alex
8353f9c649 docs: add guide for OCR configuration and usage 2026-02-10 15:54:27 +00:00
Alex
cb6b3aa406 fix: test google ai 2026-02-09 14:37:36 +00:00
Alex
36c7bd9206 Thinking stream (#2276)
* feat: stream thinking tokens

* fix: retry bug

* fix test
2026-02-09 14:27:53 +00:00
Alex
fea94379d7 feat: stream thinking tokens (#2275) 2026-02-09 13:46:27 +00:00
Alex
e602d941ca fix: sources display (#2274)
* fix: sources display

* fix: sources display2
2026-02-05 19:40:35 +00:00
Alex
f41f69a268 docs: add specific Celery startup command for macOS users 2026-02-03 17:33:13 +00:00
Manish Madan
ff72251878 Merge pull request #2268 from IRjSI/frontend-fix
fix(frontend): fix the input styling on renaming chat
2026-02-03 16:57:43 +05:30
Alex
7751fb52dd fix: neon docs mention 2026-02-03 00:23:37 +00:00
Alex
87a44d101d Update link in README for Neon documentation 2026-02-03 00:16:49 +00:00
Alex
80148f25b6 Update image source in README.md 2026-02-03 00:11:53 +00:00
Alex
8e3e4a8b09 fix: stdio mcp 2026-02-02 15:46:19 +00:00
IRjSI
9389b4a1e8 fix(frontend): fix the input styling on renaming chat 2026-01-27 22:33:28 +05:30
Pavel
4245e5bd2e End 2 end tests (#2266)
* All endpoints covered

test_integration.py kept for backwards compatibility.
tests/integration/run_all.py proposed as alternative to cover all endpoints.

* Linter fixes
2026-01-22 13:11:24 +02:00
Pavel
e7d2af2405 Setup plus env fixes (#2265)
* fixes setup scripts

fixes to env handling in setup script plus other minor fixes

* Remove var declarations

Declarations such as `LLM_PROVIDER=$LLM_PROVIDER` override .env variables in compose

A similar issue is present in the frontend; we need to choose either to switch to a separate frontend env or to keep it as is.

* Manage apikeys in settings

1. More pydantic management of api keys.
2. Clean up of variable declarations from docker compose files, which used to block .env imports. These should now be managed either by settings.py defaults or .env
2026-01-22 12:21:01 +02:00
Alex
4c32a96370 rearrange settings 2026-01-22 00:51:59 +02:00
Alex
f61d112cea feat: process pdfs synthetically if model does not support files natively (#2263)
* feat: process pdfs synthetically if model does not support files natively

* fix: small code optimisations
2026-01-15 02:30:33 +02:00
Alex
2c55c6cd9a fix: tiktoken import in markdown parser 2026-01-12 23:04:20 +00:00
Alex
f1d714b5c1 fix(frontend): replace crypto.randomUUID with custom ID generator 2026-01-12 14:58:37 +00:00
Alex
69d9dc672a chore(settings): disable Docling OCR by default for text parsing 2026-01-12 12:01:54 +00:00
Alex
9192e010e8 build(docker): add g++ and python3.12-dev to system dependencies 2026-01-12 11:59:37 +00:00
Alex
f24cea0877 docs(models): update provider examples and add native llama.cpp info 2026-01-12 11:56:10 +00:00
Alex
a29bfa7489 fix simple routing (#2261) 2026-01-12 13:51:19 +02:00
Manish Madan
2246866a09 Feat: Agents grouped under folders (#2245)
* chore(dependabot): add react-widget npm dependency updates

* refactor(prompts): init on load, mv to pref slice

* (refactor): searchable dropdowns are separate

* (fix/ui) prompts adjust

* feat(changelog): dancing stars

* (fix)conversation: re-blink bubble past stream

* (fix)endless GET sources, eslint err

* (feat:Agents) folders metadata

* (feat:agents) create new folder

* (feat:agent-management) ui

* feat:(agent folders) nesting/sub-folders

* feat:(agent folders) closer to the figma, inline folder inputs

* fix(delete behaviour) refetch agents on delete

* (fix:search) folder context missing

* fix(newAgent) preserve folder context

* feat(agent folders) id preserved in query, navigate

* feat(agents) mobile responsive

* feat(search/agents) lookup for nested agents as well

* (fix/modals) close on outside click

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2026-01-08 18:46:40 +02:00
Alex
7b17fde34a bump reqs 2026-01-08 11:53:07 +00:00
Alex
df57053613 feat: improve crawlers and update chunk filtering (#2250) 2026-01-06 00:52:12 +02:00
Alex
5662be12b5 feat(agent): implement context validation and message truncation (#2249) 2026-01-05 19:49:28 +02:00
Ankit Matth
d3e9d66b07 Fixed issue in models name (#2247) 2026-01-05 02:02:54 +02:00
Alex
e0bdbcbe38 Update README.md 2026-01-01 16:36:42 +02:00
Alex
05c835ed02 feat: enable OCR for docling when parsing attachments and update file extractor (#2246) 2025-12-31 02:08:49 +02:00
Alex
9e7f1ad1c0 Add Amazon S3 support and synchronization features (#2244)
* Add Amazon S3 support and synchronization features

* refactor: remove unused variable in load_data test
2025-12-30 20:26:51 +02:00
Alex
f910a82683 feat: add unauthorized response handling in StreamResource and bump deps 2025-12-27 14:23:37 +00:00
Alex
d8b7e86f8d bump dump (#2233) 2025-12-26 17:49:03 +02:00
Alex
aef3e0b4bb chore: update workflow permissions and fix paths in settings (#2227)
* chore: update workflow permissions and fix paths in settings

* dep

* dep upgrades
2025-12-25 14:26:01 +02:00
Alex
b0eee7be24 Patches (#2225)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes

* fix: standardize error messages across API responses

* fix: improve error handling and standardize error messages across multiple routes

* fix: enhance JavaScript string safety in ConnectorCallbackStatus

* fix: improve OAuth error handling and message formatting in MCPOAuthCallback
2025-12-25 02:57:25 +02:00
Alex
197e94302b Patches (#2219)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes

* fix: standardize error messages across API responses
2025-12-24 18:35:57 +02:00
Alex
98e949d2fd Patches (#2218)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes
2025-12-24 17:05:35 +02:00
Alex
83e7a928f1 bump deps 2025-12-23 23:36:15 +00:00
Alex
ccd29b7d4e feat: implement Docling parsers (#2202)
* feat: implement Docling parsers

* fix office

* docling-ocr-fix

* Docling smart ocr

* ruff fix

---------

Co-authored-by: Pavel <pabin@yandex.ru>
2025-12-23 18:33:51 +02:00
Siddhant Rai
5b6cfa6ecc feat: enhance API tool with body serialization and content type handling (#2192)
* feat: enhance API tool with body serialization and content type handling

* feat: enhance ToolConfig with import functionality and user action management

- Added ImportSpecModal to allow importing actions into the tool configuration.
- Implemented search functionality for user actions with expandable action details.
- Introduced method colors for better visual distinction of HTTP methods.
- Updated APIActionType and ParameterGroupType to include optional 'required' field.
- Refactored action rendering to improve usability and maintainability.

* feat: add base URL input to ImportSpecModal for action URL customization

* feat: update TestBaseAgentTools to include 'required' field for parameters

* feat: standardize API call timeout to DEFAULT_TIMEOUT constant

* feat: add import specification functionality and related translations for multiple languages

---------

Co-authored-by: Alex <a@tushynski.me>
2025-12-23 15:37:44 +02:00
Akash Bhadana
f91846ce2d docs: Update VECTOR_STORE comment to include pgvector (#2211)
Co-authored-by: root <root@MD-CG-010>
2025-12-22 18:53:30 +02:00
dependabot[bot]
87e24ab96e chore(deps): bump elevenlabs from 2.26.1 to 2.27.0 in /application (#2203)
Bumps [elevenlabs](https://github.com/elevenlabs/elevenlabs-python) from 2.26.1 to 2.27.0.
- [Release notes](https://github.com/elevenlabs/elevenlabs-python/releases)
- [Commits](https://github.com/elevenlabs/elevenlabs-python/compare/v2.26.1...v2.27.0)

---
updated-dependencies:
- dependency-name: elevenlabs
  dependency-version: 2.27.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-22 12:25:45 +02:00
Alex
40c3e5568c fix search (#2210)
* fix search

* fix ruff
2025-12-22 00:51:06 +02:00
Yash
7958d29e13 Fix: Import external-link.svg properly in AgentDetailsModal (#2191) 2025-12-19 19:08:56 +02:00
Rahul Badade
a6fafa6a4d Fix: Autoselect input text box on pageload and conversation reset (#2177) (#2194)
* Fix: Autoselect input text box on pageload and conversation reset

- Added autoFocus to useEffect dependency array in MessageInput
- Added key prop to MessageInput to force remount on conversation reset
- Implemented refocus after message submission
- Removed duplicate input clearing logic in handleKeyDown

Fixes #2177

* fix: optimize input handling

---------

Co-authored-by: Alex <a@tushynski.me>
2025-12-19 18:57:57 +02:00
JustACodeA
3ad38f53fd fix: update Node.js version to 22 for Vite compatibility (#2169)
Updates the frontend Dockerfile from Node 20.6.1 to Node 22 to resolve
compatibility issues with Vite dependencies.

Closes #2157

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-19 18:29:58 +02:00
JustACodeA
d90b1c57e5 feat: add hover animation to conversation context menu button (#2168)
* feat: add hover animation to conversation context menu button

Adds visual feedback when hovering over the three-dot menu button in conversation tiles.
This makes it clear that the submenu is being targeted rather than the parent item.

Changes:
- Added rounded hover background with smooth transition
- Increased clickable area for better UX
- Supports both light and dark themes

Closes #2097

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Update frontend/src/conversation/ConversationTile.tsx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-19 18:25:29 +02:00
Alex
a69a0e100f fix: update dependencies in requirements.txt (#2201) 2025-12-19 02:17:59 +02:00
Alex
b0d4576a95 fix: improve error handling in agent webhook worker 2025-12-18 13:27:40 +00:00
Alex
2a4ab3aca1 Fix history leftover (#2198)
* fix: history leftover

* fix: unbound result
2025-12-17 16:07:14 +02:00
Alex
e0fd11a86e fix: bump next 2025-12-17 14:06:53 +00:00
Alex
de369f8b5e fix: history leftover (#2197) 2025-12-17 13:07:50 +02:00
Alex
af3e16c4fc fix: count history tokens from chunks, remove old UI setting limit (#2196) 2025-12-17 03:34:17 +02:00
Alex
aacf281222 fix: improve remote embeds (#2193) 2025-12-16 13:59:17 +02:00
AbbasSalloum
6d8f083c6f Adding a feature to paste files with Ctrl+V (#2183) 2025-12-12 11:55:16 +00:00
Mohamed-Abuali
909bc421c0 Bugfix/docs gpt widget behavior (#2172)
* style(DocsGPTWidget): improve message bubbles and markdown styling

- Adjust max-width for message bubbles to 90% for answers and 80% for questions
- Add overflow-wrap to prevent text overflow in messages
- Update list styling with proper spacing and positioning
- Add responsive font sizing for headings using clamp()
- Implement custom table styling with proper borders and spacing
- Add custom markdown renderer rules for tables

* feat(widget): replace input with textarea for prompt input

Add support for multi-line input and custom scrollbar styling. Implement Enter key submission handling while allowing Shift+Enter for new lines.

* feat(widget): improve textarea auto-resizing and table styling

- Add auto-resizing functionality for prompt textarea with min/max height constraints
- Fix table cell markup (th/td) and improve scrollbar styling
- Add promptRef to manage textarea state and reset after submission

* fix(widget): correct table cell styling and prevent empty submissions

- Fix swapped td/th elements in markdown renderer
- Adjust font weights for table headers and cells
- Add validation to prevent empty message submissions

* (fix) name mkdwn rule as the returned element

---------

Co-authored-by: ManishMadan2882 <manishmadan321@gmail.com>
2025-12-11 01:35:55 +00:00
Alex
d14f04d79c Update requirements.txt 2025-12-11 00:54:58 +02:00
Alex
e0a9f08632 refactor and deps (#2184) 2025-12-10 23:53:59 +02:00
Manish Madan
09e7c1b97f Fixes: re-blink in conversation, (refactor) prompts and validate LocalStorage prompts (#2181)
* chore(dependabot): add react-widget npm dependency updates

* refactor(prompts): init on load, mv to pref slice

* (refactor): searchable dropdowns are separate

* (fix/ui) prompts adjust

* feat(changelog): dancing stars

* (fix)conversation: re-blink bubble past stream

* (fix)endless GET sources, eslint err

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-12-10 23:53:40 +02:00
Alex
4adffe762a Update README.md 2025-12-08 16:59:08 +02:00
Alex
9a937d2686 Feat/small optimisation (#2182)
* optimised ram use + celery

* Remove VITE_EMBEDDINGS_NAME

* fix: timeout on remote embeds
2025-12-05 20:57:39 +02:00
289 changed files with 35412 additions and 5912 deletions

View File

@@ -1,6 +1,12 @@
API_KEY=<LLM api key (for example, open ai key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -27,7 +27,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -7,6 +7,9 @@ on:
pull_request:
types: [ opened, synchronize ]
permissions:
contents: read
jobs:
ruff:
runs-on: ubuntu-latest

View File

@@ -1,5 +1,9 @@
name: Run python tests with pytest
on: [push, pull_request]
permissions:
contents: read
jobs:
pytest_and_coverage:
name: Run tests and count coverage

View File

@@ -9,6 +9,10 @@ on:
- '.vale.ini'
- '.github/styles/**'
permissions:
contents: read
pull-requests: write
jobs:
vale:
runs-on: ubuntu-latest

.gitignore
View File

@@ -71,6 +71,7 @@ instance/
# Sphinx documentation
docs/_build/
docs/public/_pagefind/
# PyBuilder
target/
@@ -147,6 +148,10 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**
frontend/node_modules
frontend/dist
frontend/dist-ssr

View File

@@ -1,2 +1,6 @@
# Allow lines to be as long as 120 characters.
line-length = 120
line-length = 120
[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]

View File

@@ -46,24 +46,11 @@
</ul>
## Roadmap
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] SharePoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full api tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -158,9 +145,17 @@ We as members, contributors, and leaders, pledge to make participation in our co
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
<p>This project is supported by:</p>
## This project is supported by:
<p>
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
</a>
</p>
<p>
<a href="https://get.neon.com/docsgpt">
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
</a>
</p>

View File

@@ -1,12 +0,0 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
INTERNAL_KEY=your_internal_key
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -7,7 +7,7 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
apt-get update && apt-get install -y --no-install-recommends \
python3.12 \
libgl1 \
libglib2.0-0 \
poppler-utils \
&& \
ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*

View File

@@ -1,6 +1,8 @@
import logging
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
import logging
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -9,6 +11,7 @@ class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}
@classmethod
@@ -16,5 +19,4 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
@@ -23,6 +27,7 @@ class BaseAgent(ABC):
llm_name: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -40,6 +45,7 @@ class BaseAgent(ABC):
self.llm_name = llm_name
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
@@ -54,13 +60,19 @@ class BaseAgent(ABC):
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
@@ -120,10 +132,10 @@ class BaseAgent(ABC):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
if key not in ("filled_by_llm", "value", "required")
}
params["required"].append(k)
if v.get("required", False):
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
@@ -219,7 +231,11 @@ class BaseAgent(ABC):
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
@@ -232,36 +248,80 @@ class BaseAgent(ABC):
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"url": action_config["url"],
"method": action_config["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
print(
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
@@ -269,7 +329,11 @@ class BaseAgent(ABC):
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
@@ -334,12 +398,83 @@ class BaseAgent(ABC):
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
"""
Truncate text by removing content from the middle, preserving start and end.
Args:
text: Text to truncate
max_tokens: Maximum tokens allowed
Returns:
Truncated text with middle removed if needed
"""
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
# Estimate chars per token (roughly 4 chars per token for English)
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin
if target_chars <= 0:
return ""
# Split: keep 40% from start, 40% from end, remove middle
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
def _build_messages(
self,
system_prompt: str,
query: str,
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
@@ -351,9 +486,34 @@ class BaseAgent(ABC):
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
)
messages = [{"role": "system", "content": system_prompt}]
for i in self.chat_history:
for i in working_history:
if "prompt" in i and "response" in i:
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
@@ -385,8 +545,69 @@ class BaseAgent(ABC):
messages.append({"role": "user", "content": query})
return messages
def _truncate_history_to_fit(
self,
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
return []
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
if "prompt" in message and "response" in message:
message_tokens += num_tokens_from_string(message["prompt"])
message_tokens += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_str = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
message_tokens += num_tokens_from_string(tool_str)
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message) # Maintain chronological order
else:
break
if len(truncated) < len(history):
logger.info(
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
f"to fit within {max_tokens:,} token budget"
)
return truncated
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
# Pre-flight context validation - fail fast if over limit
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")
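
To make the budgeting in _build_messages above concrete, here is a small worked example with hypothetical numbers (the 8,000-token limit, 1,000-token system prompt, and 2,000-token query are made up for illustration; they are not values from the diff):

# Worked example of the context-budget arithmetic used by _build_messages.
context_limit = 8_000
system_tokens = 1_000
safety_buffer = int(context_limit * 0.1)                                # 800 tokens reserved for response/tools
available_after_system = context_limit - system_tokens - safety_buffer  # 6,200
max_query_tokens = int(available_after_system * 0.8)                    # 4,960 tokens allowed for the query

query_tokens = 2_000                                                    # under the cap, so no middle-truncation
available_for_history = max(available_after_system - query_tokens, 0)   # 4,200 left for chat history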

View File

@@ -0,0 +1,323 @@
import base64
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, Union
from urllib.parse import quote, urlencode
logger = logging.getLogger(__name__)
class ContentType(str, Enum):
"""Supported content types for request bodies."""
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
TEXT_PLAIN = "text/plain"
XML = "application/xml"
OCTET_STREAM = "application/octet-stream"
class RequestBodySerializer:
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
@staticmethod
def serialize(
body_data: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Union[str, bytes], Dict[str, str]]:
"""
Serialize body data to appropriate format.
Args:
body_data: Dictionary of body parameters
content_type: Content-Type header value
encoding_rules: OpenAPI Encoding Object rules per field
Returns:
Tuple of (serialized_body, updated_headers_dict)
Raises:
ValueError: If serialization fails
"""
if not body_data:
return None, {}
try:
content_type_lower = content_type.lower().split(";")[0].strip()
if content_type_lower == ContentType.JSON:
return RequestBodySerializer._serialize_json(body_data)
elif content_type_lower == ContentType.FORM_URLENCODED:
return RequestBodySerializer._serialize_form_urlencoded(
body_data, encoding_rules
)
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
return RequestBodySerializer._serialize_multipart_form_data(
body_data, encoding_rules
)
elif content_type_lower == ContentType.TEXT_PLAIN:
return RequestBodySerializer._serialize_text_plain(body_data)
elif content_type_lower == ContentType.XML:
return RequestBodySerializer._serialize_xml(body_data)
elif content_type_lower == ContentType.OCTET_STREAM:
return RequestBodySerializer._serialize_octet_stream(body_data)
else:
logger.warning(
f"Unknown content type: {content_type}, treating as JSON"
)
return RequestBodySerializer._serialize_json(body_data)
except Exception as e:
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
raise ValueError(f"Failed to serialize request body: {str(e)}")
@staticmethod
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as JSON per OpenAPI spec."""
try:
serialized = json.dumps(
body_data, separators=(",", ":"), ensure_ascii=False
)
headers = {"Content-Type": ContentType.JSON.value}
return serialized, headers
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
@staticmethod
def _serialize_form_urlencoded(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
encoding_rules = encoding_rules or {}
params = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
style = rule.get("style", "form")
explode = rule.get("explode", style == "form")
content_type = rule.get("contentType", "text/plain")
serialized_value = RequestBodySerializer._serialize_form_value(
value, style, explode, content_type, key
)
if isinstance(serialized_value, list):
for sv in serialized_value:
params.append((key, sv))
else:
params.append((key, serialized_value))
# Use standard urlencode (replaces space with +)
serialized = urlencode(params, safe="")
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
return serialized, headers
@staticmethod
def _serialize_form_value(
value: Any, style: str, explode: bool, content_type: str, key: str
) -> Union[str, list]:
"""Serialize individual form value with encoding rules."""
if isinstance(value, dict):
if content_type == "application/json":
return json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
return RequestBodySerializer._dict_to_xml(value)
else:
if style == "deepObject" and explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
elif explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
else:
pairs = [f"{k},{v}" for k, v in value.items()]
return RequestBodySerializer._percent_encode(",".join(pairs))
elif isinstance(value, (list, tuple)):
if explode:
return [
RequestBodySerializer._percent_encode(str(item)) for item in value
]
else:
return RequestBodySerializer._percent_encode(
",".join(str(v) for v in value)
)
else:
return RequestBodySerializer._percent_encode(str(value))
@staticmethod
def _serialize_multipart_form_data(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[bytes, Dict[str, str]]:
"""
Serialize body as multipart/form-data per RFC7578.
Supports file uploads and encoding rules.
"""
import secrets
encoding_rules = encoding_rules or {}
boundary = f"----DocsGPT{secrets.token_hex(16)}"
parts = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
content_type = rule.get("contentType", "text/plain")
headers_rule = rule.get("headers", {})
part = RequestBodySerializer._create_multipart_part(
key, value, content_type, headers_rule
)
parts.append(part)
body_bytes = f"--{boundary}\r\n".encode("utf-8")
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
}
return body_bytes, headers
@staticmethod
def _create_multipart_part(
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
) -> str:
"""Create a single multipart/form-data part."""
headers = [
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
]
if isinstance(value, bytes):
if content_type == "application/octet-stream":
value_encoded = base64.b64encode(value).decode("utf-8")
else:
value_encoded = value.decode("utf-8", errors="replace")
headers.append(f"Content-Type: {content_type}")
headers.append("Content-Transfer-Encoding: base64")
elif isinstance(value, dict):
if content_type == "application/json":
value_encoded = json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
value_encoded = RequestBodySerializer._dict_to_xml(value)
else:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
elif isinstance(value, str) and content_type != "text/plain":
try:
if content_type == "application/json":
json.loads(value)
value_encoded = value
elif content_type == "application/xml":
value_encoded = value
else:
value_encoded = str(value)
except json.JSONDecodeError:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
else:
value_encoded = str(value)
if content_type != "text/plain":
headers.append(f"Content-Type: {content_type}")
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
return part
@staticmethod
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as plain text."""
if len(body_data) == 1:
value = list(body_data.values())[0]
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
else:
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
@staticmethod
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as XML."""
xml_str = RequestBodySerializer._dict_to_xml(body_data)
return xml_str, {"Content-Type": ContentType.XML.value}
@staticmethod
def _serialize_octet_stream(
body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
"""Serialize body as binary octet stream."""
if isinstance(body_data, bytes):
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
elif isinstance(body_data, str):
return body_data.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
else:
serialized = json.dumps(body_data)
return serialized.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
@staticmethod
def _percent_encode(value: str, safe_chars: str = "") -> str:
"""
Percent-encode per RFC3986.
Args:
value: String to encode
safe_chars: Additional characters to not encode
"""
return quote(value, safe=safe_chars)
@staticmethod
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
"""
Convert dict to simple XML format.
"""
def build_xml(obj: Any, name: str) -> str:
if isinstance(obj, dict):
inner = "".join(build_xml(v, k) for k, v in obj.items())
return f"<{name}>{inner}</{name}>"
elif isinstance(obj, (list, tuple)):
items = "".join(
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
for item in obj
)
return items
else:
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
root = build_xml(data, root_name)
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
@staticmethod
def _escape_xml(value: str) -> str:
"""Escape XML special characters."""
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
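
A brief usage sketch of the serializer above; the sample body is illustrative, but the outputs follow from the code shown in this diff:

# Not part of the diff: example of calling RequestBodySerializer directly.
from application.agents.tools.api_body_serializer import ContentType, RequestBodySerializer

body = {"name": "DocsGPT", "tags": ["rag", "agents"]}

# JSON (the default): compact json.dumps output plus the matching Content-Type header.
payload, headers = RequestBodySerializer.serialize(body, ContentType.JSON)
# payload == '{"name":"DocsGPT","tags":["rag","agents"]}'
# headers == {"Content-Type": "application/json"}

# Form-encoded: list values are exploded into repeated fields by default.
payload, headers = RequestBodySerializer.serialize(body, ContentType.FORM_URLENCODED)
# payload == 'name=DocsGPT&tags=rag&tags=agents'
# headers == {"Content-Type": "application/x-www-form-urlencoded"}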

View File

@@ -1,72 +1,280 @@
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 90 # seconds
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.headers = config.get("headers", {})
self.query_params = config.get("query_params", {})
self.body_content_type = config.get("body_content_type", ContentType.JSON)
self.body_encoding_rules = config.get("body_encoding_rules", {})
def execute_action(self, action_name, **kwargs):
"""Execute an API action with the given arguments."""
return self._make_api_call(
self.url, self.method, self.headers, self.query_params, kwargs
self.url,
self.method,
self.headers,
self.query_params,
kwargs,
self.body_content_type,
self.body_encoding_rules,
)
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
def _make_api_call(
self,
url: str,
method: str,
headers: Dict[str, str],
query_params: Dict[str, Any],
body: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
Make an API call with proper body serialization and error handling.
Args:
url: API endpoint URL
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
headers: Request headers dict
query_params: URL query parameters
body: Request body as dict
content_type: Content-Type for serialization
encoding_rules: OpenAPI encoding rules
Returns:
Dict with status_code, data, and message
"""
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
k: v for k, v in query_params.items() if k not in path_params_used
}
if remaining_params:
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
serialized_body, body_headers = RequestBodySerializer.serialize(
body, content_type, encoding_rules
)
request_headers.update(body_headers)
except ValueError as e:
logger.error(f"Body serialization failed: {str(e)}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
"status_code": None,
"message": f"Body serialization error: {str(e)}",
"data": None,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
serialized_body = None
if "Content-Type" not in request_headers and method not in [
"GET",
"HEAD",
"DELETE",
]:
request_headers["Content-Type"] = ContentType.JSON
logger.debug(
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response.raise_for_status()
data = self._parse_response(response)
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {
"status_code": None,
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
"data": None,
}
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {str(e)}")
return {
"status_code": None,
"message": f"Connection error: {str(e)}",
"data": None,
}
except requests.exceptions.HTTPError as e:
logger.error(f"HTTP error {response.status_code}: {str(e)}")
try:
error_data = response.json()
except (json.JSONDecodeError, ValueError):
error_data = response.text
return {
"status_code": response.status_code,
"message": f"HTTP Error {response.status_code}",
"data": error_data,
}
except requests.exceptions.RequestException as e:
logger.error(f"Request failed: {str(e)}")
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
"data": None,
}
except Exception as e:
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
return {
"status_code": None,
"message": f"Unexpected error: {str(e)}",
"data": None,
}
def _parse_response(self, response: requests.Response) -> Any:
"""
Parse response based on Content-Type header.
Supports: JSON, XML, plain text, binary data.
"""
content_type = response.headers.get("Content-Type", "").lower()
if not response.content:
return None
# JSON response
if "application/json" in content_type:
try:
return response.json()
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON response: {str(e)}")
return response.text
# XML response
elif "application/xml" in content_type or "text/xml" in content_type:
return response.text
# Plain text response
elif "text/plain" in content_type or "text/html" in content_type:
return response.text
# Binary/unknown response
else:
# Try to decode as text first, fall back to base64
try:
return response.text
except (UnicodeDecodeError, AttributeError):
import base64
return base64.b64encode(response.content).decode("utf-8")
def get_actions_metadata(self):
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
return []
def get_config_requirements(self):
"""Return configuration requirements for the tool."""
return {}
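
Purely illustrative usage of the rewritten APITool (the module path, endpoint URL, and field names below are placeholders assumed for this sketch, not values from the diff):

# Hypothetical example: configure an action and execute it through APITool.
from application.agents.tools.api_tool import APITool  # module path assumed

tool = APITool({
    "url": "https://api.example.com/v1/users/{user_id}",  # {user_id} is substituted from query_params
    "method": "POST",
    "headers": {"Authorization": "Bearer <token>"},
    "query_params": {"user_id": 42, "verbose": "true"},   # leftover params are appended as ?verbose=true
    "body_content_type": "application/json",
})

result = tool.execute_action("create_note", title="Hello", body="First note")
# result is a dict of the form {"status_code": ..., "data": ..., "message": ...};
# SSRF validation, body serialization, and the 90-second DEFAULT_TIMEOUT are handled inside _make_api_call.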

View File

@@ -169,6 +169,8 @@ class MCPTool(Tool):
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "stdio":
raise ValueError("STDIO transport is disabled")
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
@@ -454,6 +456,12 @@ class MCPTool(Tool):
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
transport_enum = ["auto", "sse", "http"]
transport_help = {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
}
return {
"server_url": {
"type": "string",
@@ -463,14 +471,11 @@ class MCPTool(Tool):
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"enum": transport_enum,
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
**transport_help,
},
},
"auth_type": {

View File

@@ -38,6 +38,8 @@ class NotesTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class NotesTool(Tool):
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "view":
return self._get_note()
@@ -125,6 +129,9 @@ class NotesTool(Tool):
"""Return configuration requirements (none for now)."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
@@ -132,17 +139,22 @@ class NotesTool(Tool):
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
upsert=True,
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
@@ -163,10 +175,13 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
@@ -188,12 +203,21 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
return "No note found to delete."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return "Note deleted."

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
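
The validate_url helper from application.core.url_validation is not part of this diff, so purely as an assumption: a typical SSRF guard normalizes the scheme, resolves the hostname, and refuses loopback, private, link-local, and reserved addresses before any request is sent. An illustrative sketch of such a check (not the DocsGPT implementation):

# Illustrative SSRF guard; the real validate_url may differ. Shown only to
# explain what the call above is protecting against.
import ipaddress
import socket
from urllib.parse import urlparse

class SSRFError(ValueError):
    pass

def validate_url(url: str) -> str:
    if not urlparse(url).scheme:
        url = "http://" + url
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise SSRFError(f"Unsupported scheme: {parsed.scheme}")
    if not parsed.hostname:
        raise SSRFError("URL has no hostname")
    try:
        resolved = socket.getaddrinfo(parsed.hostname, None)
    except socket.gaierror as e:
        raise SSRFError(f"Cannot resolve host: {e}")
    for _family, _type, _proto, _canon, sockaddr in resolved:
        ip = ipaddress.ip_address(sockaddr[0].split("%")[0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            raise SSRFError(f"Blocked address: {ip}")
    return url

print(validate_url("example.com"))       # "http://example.com"
# validate_url("http://127.0.0.1:8080")  # raises SSRFError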

View File

@@ -0,0 +1,342 @@
"""
API Specification Parser
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
to API Tool action definitions for use in DocsGPT.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import yaml
logger = logging.getLogger(__name__)
SUPPORTED_METHODS = frozenset(
{"get", "post", "put", "delete", "patch", "head", "options"}
)
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Parse an API specification and convert operations to action definitions.
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
Args:
spec_content: Raw specification content as string
Returns:
Tuple of (metadata dict, list of action dicts)
Raises:
ValueError: If the spec is invalid or uses an unsupported format
"""
spec = _load_spec(spec_content)
_validate_spec(spec)
is_swagger = "swagger" in spec
metadata = _extract_metadata(spec, is_swagger)
actions = _extract_actions(spec, is_swagger)
return metadata, actions
def _load_spec(content: str) -> Dict[str, Any]:
"""Parse spec content from JSON or YAML string."""
content = content.strip()
if not content:
raise ValueError("Empty specification content")
try:
if content.startswith("{"):
return json.loads(content)
return yaml.safe_load(content)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e.msg}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML format: {e}")
def _validate_spec(spec: Dict[str, Any]) -> None:
"""Validate spec version and required fields."""
if not isinstance(spec, dict):
raise ValueError("Specification must be a valid object")
openapi_version = spec.get("openapi", "")
swagger_version = spec.get("swagger", "")
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
raise ValueError(
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
)
if "paths" not in spec or not spec["paths"]:
raise ValueError("No API paths defined in the specification")
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
"""Extract API metadata from specification."""
info = spec.get("info", {})
base_url = _get_base_url(spec, is_swagger)
return {
"title": info.get("title", "Untitled API"),
"description": (info.get("description", "") or "")[:500],
"version": info.get("version", ""),
"base_url": base_url,
}
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
if is_swagger:
schemes = spec.get("schemes", ["https"])
host = spec.get("host", "")
base_path = spec.get("basePath", "")
if host:
scheme = schemes[0] if schemes else "https"
return f"{scheme}://{host}{base_path}".rstrip("/")
return ""
servers = spec.get("servers", [])
if servers and isinstance(servers, list) and servers[0].get("url"):
return servers[0]["url"].rstrip("/")
return ""
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
"""Extract all API operations as action definitions."""
actions = []
paths = spec.get("paths", {})
base_url = _get_base_url(spec, is_swagger)
components = spec.get("components", {})
definitions = spec.get("definitions", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
path_params = path_item.get("parameters", [])
for method in SUPPORTED_METHODS:
operation = path_item.get(method)
if not isinstance(operation, dict):
continue
try:
action = _build_action(
path=path,
method=method,
operation=operation,
path_params=path_params,
base_url=base_url,
components=components,
definitions=definitions,
is_swagger=is_swagger,
)
actions.append(action)
except Exception as e:
logger.warning(
f"Failed to parse operation {method.upper()} {path}: {e}"
)
continue
return actions
def _build_action(
path: str,
method: str,
operation: Dict[str, Any],
path_params: List[Dict],
base_url: str,
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Dict[str, Any]:
"""Build a single action from an API operation."""
action_name = _generate_action_name(operation, method, path)
full_url = f"{base_url}{path}" if base_url else path
all_params = path_params + operation.get("parameters", [])
query_params, headers = _categorize_parameters(all_params, components, definitions)
body, body_content_type = _extract_request_body(
operation, components, definitions, is_swagger
)
description = operation.get("summary", "") or operation.get("description", "")
return {
"name": action_name,
"url": full_url,
"method": method.upper(),
"description": (description or "")[:500],
"query_params": {"type": "object", "properties": query_params},
"headers": {"type": "object", "properties": headers},
"body": {"type": "object", "properties": body},
"body_content_type": body_content_type,
"active": True,
}
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
"""Generate a valid action name from operationId or method+path."""
if operation.get("operationId"):
name = operation["operationId"]
else:
path_slug = re.sub(r"[{}]", "", path)
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
name = f"{method}_{path_slug}"
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
return name[:64]
def _categorize_parameters(
parameters: List[Dict],
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
"""Categorize parameters into query params and headers."""
query_params = {}
headers = {}
for param in parameters:
resolved = _resolve_ref(param, components, definitions)
if not resolved or "name" not in resolved:
continue
location = resolved.get("in", "query")
prop = _param_to_property(resolved)
if location in ("query", "path"):
query_params[resolved["name"]] = prop
elif location == "header":
headers[resolved["name"]] = prop
return query_params, headers
def _param_to_property(param: Dict) -> Dict[str, Any]:
"""Convert an API parameter to an action property definition."""
schema = param.get("schema", {})
param_type = schema.get("type", param.get("type", "string"))
mapped_type = "integer" if param_type in ("integer", "number") else "string"
return {
"type": mapped_type,
"description": (param.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": param.get("required", False),
"required": param.get("required", False),
}
def _extract_request_body(
operation: Dict[str, Any],
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Tuple[Dict, str]:
"""Extract request body schema and content type."""
content_types = [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
"application/xml",
]
if is_swagger:
consumes = operation.get("consumes", [])
body_param = next(
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
)
if not body_param:
return {}, "application/json"
selected_type = consumes[0] if consumes else "application/json"
schema = body_param.get("schema", {})
else:
request_body = operation.get("requestBody", {})
if not request_body:
return {}, "application/json"
request_body = _resolve_ref(request_body, components, definitions)
content = request_body.get("content", {})
selected_type = "application/json"
schema = {}
for ct in content_types:
if ct in content:
selected_type = ct
schema = content[ct].get("schema", {})
break
if not schema and content:
first_type = next(iter(content))
selected_type = first_type
schema = content[first_type].get("schema", {})
properties = _schema_to_properties(schema, components, definitions)
return properties, selected_type
def _schema_to_properties(
schema: Dict,
components: Dict[str, Any],
definitions: Dict[str, Any],
depth: int = 0,
) -> Dict[str, Any]:
"""Convert schema to action body properties (limited depth to prevent recursion)."""
if depth > 3:
return {}
schema = _resolve_ref(schema, components, definitions)
if not schema or not isinstance(schema, dict):
return {}
properties = {}
schema_type = schema.get("type", "object")
if schema_type == "object":
required_fields = set(schema.get("required", []))
for prop_name, prop_schema in schema.get("properties", {}).items():
resolved = _resolve_ref(prop_schema, components, definitions)
if not isinstance(resolved, dict):
continue
prop_type = resolved.get("type", "string")
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
properties[prop_name] = {
"type": mapped_type,
"description": (resolved.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": prop_name in required_fields,
"required": prop_name in required_fields,
}
return properties
def _resolve_ref(
obj: Any,
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Optional[Dict]:
"""Resolve $ref references in the specification."""
if not isinstance(obj, dict):
return obj if isinstance(obj, dict) else None
if "$ref" not in obj:
return obj
ref_path = obj["$ref"]
if ref_path.startswith("#/components/"):
parts = ref_path.replace("#/components/", "").split("/")
return _traverse_path(components, parts)
elif ref_path.startswith("#/definitions/"):
parts = ref_path.replace("#/definitions/", "").split("/")
return _traverse_path(definitions, parts)
logger.debug(f"Unsupported ref path: {ref_path}")
return None
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
"""Traverse a nested dictionary using path parts."""
try:
for part in parts:
obj = obj[part]
return obj if isinstance(obj, dict) else None
except (KeyError, TypeError):
return None
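
A short usage sketch for the parser above, assuming parse_spec is imported from this module: feed it a minimal, made-up OpenAPI 3 document and inspect the derived metadata and the single generated action.

# Usage sketch for parse_spec with a minimal, made-up OpenAPI 3 spec.
import json

spec = {
    "openapi": "3.0.0",
    "info": {"title": "Petstore", "version": "1.0.0"},
    "servers": [{"url": "https://api.example.com/v1"}],
    "paths": {
        "/pets/{petId}": {
            "get": {
                "operationId": "getPet",
                "summary": "Fetch a pet by id",
                "parameters": [
                    {"name": "petId", "in": "path", "required": True,
                     "schema": {"type": "integer"}}
                ],
            }
        }
    },
}

metadata, actions = parse_spec(json.dumps(spec))
print(metadata["title"], metadata["base_url"])   # Petstore https://api.example.com/v1
print(actions[0]["name"], actions[0]["method"])  # getPet GET
print(actions[0]["query_params"]["properties"].keys())  # dict_keys(['petId'])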

View File

@@ -38,6 +38,8 @@ class TodoListTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "list":
return self._list()
@@ -165,6 +169,9 @@ class TodoListTool(Tool):
"""Return configuration requirements."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers
# -----------------------------
@@ -190,11 +197,8 @@ class TodoListTool(Tool):
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find the maximum todo_id
max_id = 0
@@ -207,8 +211,8 @@ class TodoListTool(Tool):
def _list(self) -> str:
"""List all todos for the user."""
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
if not todos:
return "No todos found."
@@ -242,7 +246,10 @@ class TodoListTool(Tool):
"created_at": now,
"updated_at": now,
}
self.collection.insert_one(doc)
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -251,15 +258,15 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
@@ -277,14 +284,17 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
@@ -293,14 +303,17 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
@@ -309,13 +322,12 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -0,0 +1,231 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
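
_parse_embedded_workflow accepts the raw nodes/edges payload passed alongside a request, so a minimal shape that would parse here is a start node wired straight to an end node. The field names below follow the parsing code above; the ids, titles, and output text are illustrative:

# Smallest embedded-workflow payload that _parse_embedded_workflow would
# accept; all values are illustrative.
embedded_workflow = {
    "name": "Echo workflow",
    "description": "Start -> End, no agent nodes",
    "nodes": [
        {"id": "n1", "type": "start", "title": "Start",
         "position": {"x": 0, "y": 0}, "data": {}},
        {"id": "n2", "type": "end", "title": "End",
         "position": {"x": 200, "y": 0},
         "data": {"config": {"output_template": "Workflow finished."}}},
    ],
    "edges": [
        {"id": "e1", "source": "n1", "target": "n2"},
    ],
}
# WorkflowAgent(..., workflow=embedded_workflow) would build a WorkflowGraph
# with two nodes and one edge from this payload.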

View File

@@ -0,0 +1,64 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value
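
A quick usage sketch for evaluate_cel above (requires the cel-python package): workflow state is converted to CEL types, the expression is compiled and evaluated, and the result is converted back to plain Python values. Expected outputs are shown as comments:

# Usage sketch for evaluate_cel defined above.
state = {"score": 42, "status": "open", "tags": ["urgent", "backend"]}

print(evaluate_cel('score > 10 && status == "open"', state))  # True
print(evaluate_cel('tags[0] + "-ticket"', state))             # 'urgent-ticket'
print(evaluate_cel("score * 2", state))                        # 84

try:
    evaluate_cel("undefined_var > 1", state)
except CelEvaluationError as e:
    print("rejected:", e)  # unknown variables surface as CelEvaluationError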

View File

@@ -0,0 +1,109 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)
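
A hedged sketch of a factory call, mirroring how the workflow engine invokes it: the endpoint, model, key, and tool ids are placeholders, and the extra kwargs (prompt, chat_history, decoded_token) are simply forwarded to the underlying ClassicAgent/ReActAgent constructor:

# Illustrative call into WorkflowNodeAgentFactory; all values are placeholders.
node_agent = WorkflowNodeAgentFactory.create(
    agent_type=AgentType.REACT,
    endpoint="openai",
    llm_name="openai",
    model_id="gpt-4o-mini",
    api_key="sk-...",                        # placeholder
    tool_ids=["66f1c0ffee0000000000abcd"],   # only these tools survive filtering
    prompt="You are a focused research agent.",
    chat_history=[],
    decoded_token={"sub": "user-1"},
)
# ToolFilterMixin drops every fetched tool whose _id is not in tool_ids,
# and returns no tools at all when tool_ids is empty.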

View File

@@ -0,0 +1,235 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}
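
A small sketch that builds a WorkflowGraph directly from the models above and exercises the lookup helpers; the ids and titles are illustrative:

# Building a WorkflowGraph by hand from the schemas above (illustrative ids).
workflow = Workflow(name="Demo", user="user-1")
nodes = [
    WorkflowNode(id="start", workflow_id="wf-1", type=NodeType.START, title="Start"),
    WorkflowNode(id="agent", workflow_id="wf-1", type=NodeType.AGENT, title="Answer",
                 config={"agent_type": "classic", "stream_to_user": True}),
    WorkflowNode(id="end", workflow_id="wf-1", type=NodeType.END, title="End"),
]
edges = [
    WorkflowEdge(id="e1", workflow_id="wf-1", source="start", target="agent"),
    WorkflowEdge(id="e2", workflow_id="wf-1", source="agent", target="end"),
]
graph = WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)

print(graph.get_start_node().id)                                  # 'start'
print([e.target_id for e in graph.get_outgoing_edges("start")])   # ['agent']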

View File

@@ -0,0 +1,453 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": str(e),
}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]
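
For the condition handling above: each case's CEL expression is evaluated against the state, the first match decides the sourceHandle, and _get_next_node_id follows the edge with that source_handle, falling back to an edge whose handle is "else" when nothing matches. A data-only sketch of a two-case condition node and its edges (field names follow ConditionNodeConfig; handle values are illustrative):

# Illustrative condition-node config for the engine above. With
# state = {"score": 3}, neither case matches, so routing falls back to the
# edge whose sourceHandle is "else".
condition_config = {
    "config": {
        "mode": "simple",
        "cases": [
            {"name": "high",   "expression": "score >= 8", "sourceHandle": "case-high"},
            {"name": "medium", "expression": "score >= 5", "sourceHandle": "case-medium"},
        ],
    }
}
# Matching edge layout (sketch): one outgoing edge per handle.
condition_edges = [
    {"source": "cond", "target": "praise",   "sourceHandle": "case-high"},
    {"source": "cond", "target": "nudge",    "sourceHandle": "case-medium"},
    {"source": "cond", "target": "fallback", "sourceHandle": "else"},
]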

View File

@@ -3,6 +3,7 @@ from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource
@@ -14,6 +15,7 @@ api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
api.add_resource(SearchResource, "/api/search")
init_answer_routes()

View File

@@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -101,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
@@ -138,5 +141,5 @@ class AnswerResource(Resource, BaseAnswerResource):
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response({"error": "An error occurred processing your request"}, 500)
return make_response(result, 200)

View File

@@ -46,6 +46,27 @@ class BaseAnswerResource:
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
@@ -246,6 +267,7 @@ class BaseAnswerResource:
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
if should_save_conversation:
@@ -292,14 +314,20 @@ class BaseAnswerResource:
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
@@ -330,6 +358,7 @@ class BaseAnswerResource:
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.conversation_service.save_conversation(
conversation_id,
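
_prepare_tool_calls_for_logging caps the result and result_full fields at max_chars (10000 by default) before they reach the activity log; everything else in the tool call passes through untouched. A quick sketch of its effect on one oversized tool call (values are made up):

# Effect of the truncation helper above on an oversized tool call (sketch).
oversized = [{
    "tool": "read_webpage",
    "result": "x" * 25000,        # well past the 10000-char default
    "result_full": "x" * 25000,
    "status": "ok",
}]
prepared = BaseAnswerResource._prepare_tool_calls_for_logging(oversized)
print(len(prepared[0]["result"]))       # 10000
print(len(prepared[0]["result_full"]))  # 10000
print(prepared[0]["status"])            # 'ok' (untouched)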

View File

@@ -0,0 +1,186 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
search_model = answer_ns.model(
"SearchModel",
{
"question": fields.String(
required=True, description="Search query"
),
"api_key": fields.String(
required=True, description="API key for authentication"
),
"chunks": fields.Integer(
required=False, default=5, description="Number of results to return"
),
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if sources and isinstance(sources, list) and len(sources) > 0:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
question = data.get("question")
api_key = data.get("api_key")
chunks = data.get("chunks", 5)
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response({"error": "Search failed"}, 500)

View File

@@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -81,6 +81,12 @@ class StreamResource(Resource, BaseAnswerResource):
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
@@ -102,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,

View File

@@ -134,6 +134,7 @@ class CompressionOrchestrator:
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
)
# Create compression service with DB update capability

View File

@@ -62,6 +62,8 @@ class ConversationService:
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
@@ -148,9 +150,12 @@ class ConversationService:
]
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=30
model=model_id, messages=messages_summary, max_tokens=500
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
conversation_data = {
"user": user_id,
"date": current_time,

View File

@@ -90,6 +90,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -150,9 +151,7 @@ class StreamProcessor:
)
if not result.success:
logger.error(
f"Compression failed: {result.error}, using full history"
)
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
@@ -225,7 +224,11 @@ class StreamProcessor:
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
+ (
f" and {len(available_models) - 5} more"
if len(available_models) > 5
else ""
)
)
self.model_id = requested_model
else:
@@ -353,10 +356,13 @@ class StreamProcessor:
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -370,6 +376,9 @@ class StreamProcessor:
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -382,6 +391,8 @@ class StreamProcessor:
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -398,6 +409,9 @@ class StreamProcessor:
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -409,10 +423,19 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
):
agent_type = "workflow"
self.agent_config["workflow"] = self.data["workflow"]
if isinstance(self.decoded_token, dict):
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"agent_type": agent_type,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
@@ -420,16 +443,12 @@ class StreamProcessor:
)
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, history_token_limit=history_token_limit
)
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
"history_token_limit": history_token_limit,
}
api_key = self.data.get("api_key") or self.agent_key
@@ -446,6 +465,7 @@ class StreamProcessor:
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
)
@@ -733,21 +753,37 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent = AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
compressed_summary=self.compressed_summary,
)
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,
"retrieved_docs": self.retrieved_docs,
"decoded_token": self.decoded_token,
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
}
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config
elif isinstance(workflow_config, dict):
agent_kwargs["workflow"] = workflow_config
workflow_owner = self.agent_config.get("workflow_owner")
if workflow_owner:
agent_kwargs["workflow_owner"] = workflow_owner
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
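
A minimal sketch of the kwargs routing introduced above, assuming an agent_config dict shaped the way StreamProcessor builds it; the helper name is illustrative, not part of the codebase.

def build_agent_kwargs(agent_config: dict, base_kwargs: dict) -> dict:
    kwargs = dict(base_kwargs)
    if agent_config.get("agent_type") == "workflow":
        workflow = agent_config.get("workflow")
        if isinstance(workflow, str):
            kwargs["workflow_id"] = workflow   # stored workflow reference (id string)
        elif isinstance(workflow, dict):
            kwargs["workflow"] = workflow      # inline workflow definition
        owner = agent_config.get("workflow_owner")
        if owner:
            kwargs["workflow_owner"] = owner
    return kwargs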

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
@@ -35,6 +37,18 @@ connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
# Fixed (constant) callback status path to prevent open redirects
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
def build_callback_redirect(params: dict) -> str:
"""Build a safe redirect URL to the callback status page.
Uses a fixed path and properly URL-encodes all parameters
to prevent URL injection and open redirect vulnerabilities.
"""
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
@connectors_ns.route("/api/connectors/auth")
@@ -75,8 +89,8 @@ class ConnectorAuth(Resource):
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
@connectors_ns.route("/api/connectors/callback")
@@ -93,18 +107,37 @@ class ConnectorsCallback(Resource):
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
provider = state_dict.get("provider")
state_object_id = state_dict.get("object_id")
# Validate provider
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
return redirect(build_callback_redirect({
"status": "error",
"message": "Invalid provider"
}))
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
return redirect(build_callback_redirect({
"status": "cancelled",
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
"provider": provider
}))
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
try:
auth = ConnectorCreator.create_auth(provider)
@@ -141,15 +174,28 @@ class ConnectorsCallback(Resource):
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
return redirect(build_callback_redirect({
"status": "success",
"message": "Authentication successful",
"provider": provider,
"session_token": session_token,
"user_email": user_email
}))
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
}))
@connectors_ns.route("/api/connectors/files")
@@ -228,8 +274,8 @@ class ConnectorFiles(Resource):
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
@@ -289,8 +335,8 @@ class ConnectorValidateSession(Resource):
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@connectors_ns.route("/api/connectors/disconnect")
@@ -311,8 +357,8 @@ class ConnectorDisconnect(Resource):
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
@connectors_ns.route("/api/connectors/sync")
@@ -418,8 +464,8 @@ class ConnectorSync(Resource):
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
"error": "Failed to sync connector source"
}),
400
)
@@ -430,17 +476,32 @@ class ConnectorCallbackStatus(Resource):
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
# Validate and sanitize status to a known value
status_raw = request.args.get('status', 'error')
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
# Escape all user-controlled values for HTML context
message = html.escape(request.args.get('message', ''))
provider_raw = request.args.get('provider', 'connector')
provider = html.escape(provider_raw.replace('_', ' ').title())
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
user_email = html.escape(request.args.get('user_email', ''))
def safe_js_string(value: str) -> str:
"""Safely encode a string for embedding in inline JavaScript."""
js_encoded = json.dumps(value)
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
js_status = safe_js_string(status)
js_session_token = safe_js_string(session_token)
js_user_email = safe_js_string(user_email)
js_provider_type = safe_js_string(provider_raw)
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<title>{provider} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
@@ -450,13 +511,14 @@ class ConnectorCallbackStatus(Resource):
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
const status = {js_status};
const sessionToken = {js_session_token};
const userEmail = {js_user_email};
const providerType = {js_provider_type};
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
type: providerType + '_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
@@ -470,17 +532,17 @@ class ConnectorCallbackStatus(Resource):
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<h2>{provider} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")
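
The two helpers above, restated standalone so their encoding behaviour is easy to verify; this is a sketch for illustration, the real implementations live in the connectors blueprint shown in this diff.

from urllib.parse import urlencode
import json

CALLBACK_STATUS_PATH = "/api/connectors/callback-status"

def build_callback_redirect(params: dict) -> str:
    # urlencode escapes every key and value, so attacker-controlled text cannot
    # break out of the query string or change the redirect path.
    return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"

def safe_js_string(value: str) -> str:
    # json.dumps yields a quoted, escaped JS string literal; the extra replaces stop
    # "</script>" and "<!--" from terminating the inline <script> block.
    return json.dumps(value).replace('</', '<\\/').replace('<!--', '<\\!--')

print(build_callback_redirect({"status": "error", "message": "Authentication failed. Please try again."}))
# /api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again.
print(safe_js_string('</script>'))
# prints "<\/script>" (note the escaped slash and the surrounding quotes)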

View File

@@ -61,6 +61,7 @@ def upload_index_files():
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
@@ -70,6 +71,14 @@ def upload_index_files():
directory_structure = {}
else:
directory_structure = {}
if file_name_map:
try:
file_name_map = json.loads(file_name_map)
except Exception:
logger.error("Error parsing file_name_map")
file_name_map = None
else:
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -97,41 +106,43 @@ def upload_index_files():
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
{"$set": update_fields},
)
else:
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
return {"status": "ok"}

View File

@@ -3,5 +3,6 @@
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]

View File

@@ -0,0 +1,261 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Create a new folder")
@api.expect(
api.model(
"CreateFolder",
{
"name": fields.String(required=True, description="Folder name"),
"parent_id": fields.String(required=False, description="Parent folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
try:
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
@api.doc(description="Get a specific folder with its agents")
def get(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
return make_response(
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
200,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Update a folder")
def put(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
@api.doc(description="Move an agent to a folder or remove from folder")
@api.expect(
api.model(
"MoveAgent",
{
"agent_id": fields.String(required=True, description="Agent ID to move"),
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
@api.doc(description="Move multiple agents to a folder")
@api.expect(
api.model(
"BulkMoveAgents",
{
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
"folder_id": fields.String(required=False, description="Target folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_ids"):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id = data.get("folder_id")
try:
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
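
A hypothetical client session against the new folder routes; the base URL and auth header are placeholders for whatever populates request.decoded_token in your deployment, and the agent id is made up.

import requests

BASE = "http://localhost:7091/api/agents/folders"
HEADERS = {"Authorization": "Bearer <jwt>"}

folder = requests.post(f"{BASE}/", json={"name": "Support bots"}, headers=HEADERS).json()
requests.post(
    f"{BASE}/move_agent",
    json={"agent_id": "0123456789abcdef01234567", "folder_id": folder["id"]},
    headers=HEADERS,
)
listing = requests.get(f"{BASE}/{folder['id']}", headers=HEADERS).json()
print(listing["agents"], listing["subfolders"])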

View File

@@ -11,6 +11,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
db,
ensure_user_doc,
@@ -18,6 +19,13 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.utils import (
@@ -30,6 +38,189 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
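
A few illustrative inputs for normalize_workflow_reference, showing the shapes it accepts; the ids are made up.

print(normalize_workflow_reference("0123456789abcdef01234567"))                      # passed through unchanged
print(normalize_workflow_reference({"id": "0123456789abcdef01234567"}))              # dict payload
print(normalize_workflow_reference('{"workflow_id": "0123456789abcdef01234567"}'))   # JSON-encoded string payload
print(normalize_workflow_reference("   "))                                           # "" (blank input)
print(normalize_workflow_reference(None))                                            # None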
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -67,7 +258,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -97,6 +288,8 @@ class GetAgent(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -146,7 +339,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -176,9 +369,13 @@ class GetAgents(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent or "retriever" in agent
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -206,16 +403,22 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -242,6 +445,9 @@ class CreateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -277,40 +483,15 @@ class CreateAgent(Resource):
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
# Validate and normalize JSON schema if provided
if "json_schema" in data:
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -323,18 +504,34 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not data.get("source") and not data.get("sources"):
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
return make_response(
jsonify(
{
@@ -344,10 +541,9 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -359,74 +555,50 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify({"success": False, "message": "Invalid folder ID format"}),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"limited_token_mode": (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
),
"token_limit": int(
data.get(
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
),
"limited_request_mode": (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
),
"request_limit": int(
data.get(
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -455,16 +627,22 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -491,6 +669,9 @@ class UpdateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -529,6 +710,8 @@ class UpdateAgent(Resource):
),
400,
)
if data.get("json_schema") == "":
data["json_schema"] = None
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
@@ -584,6 +767,8 @@ class UpdateAgent(Resource):
"request_limit",
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -687,17 +872,15 @@ class UpdateAgent(Resource):
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
try:
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
elif field == "limited_token_mode":
@@ -740,10 +923,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -768,6 +951,42 @@ class UpdateAgent(Resource):
),
400,
)
elif field == "folder_id":
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid folder ID format",
}
),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -796,46 +1015,82 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
if agent_type == "workflow":
required_published_fields = {
"name": "Agent name",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
if not workflow_id:
missing_published_fields.append("Workflow")
elif not ObjectId.is_valid(workflow_id):
missing_published_fields.append("Valid workflow")
else:
workflow = workflows_collection.find_one(
{"_id": ObjectId(workflow_id), "user": user}
)
if not workflow:
missing_published_fields.append("Workflow access")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
else:
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -907,6 +1162,29 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1015,19 +1293,16 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]

View File

@@ -255,8 +255,8 @@ class ShareAgent(Resource):
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200

View File

@@ -99,11 +99,8 @@ class StoreAttachment(Resource):
})
if not tasks:
error_msg = "No valid files to upload"
if errors:
error_msg += f". Errors: {errors}"
return make_response(
jsonify({"status": "error", "message": error_msg, "errors": errors}),
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
)
@@ -135,7 +132,7 @@ class StoreAttachment(Resource):
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/images/<path:image_path>")

View File

@@ -31,12 +31,17 @@ sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -46,6 +51,25 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(

View File

@@ -5,8 +5,7 @@ Main user API routes - registers all namespace modules.
from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -31,10 +31,11 @@ api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks)
# Agents (main, sharing, webhooks, folders)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
api.add_namespace(agents_webhooks_ns)
api.add_namespace(agents_folders_ns)
# Prompts
api.add_namespace(prompts_ns)
@@ -50,3 +51,6 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -220,8 +220,23 @@ class GetPubliclySharedConversations(Resource):
shared
and "conversation_id" in shared
):
# conversation_id is now stored as an ObjectId, not a DBRef
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)
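
A standalone sketch of the id normalisation above, so each legacy shape is visible in isolation; the route performs the same branching inline.

from bson import DBRef, ObjectId

def to_object_id(conversation_id):
    if isinstance(conversation_id, DBRef):
        return conversation_id.id
    if isinstance(conversation_id, dict):
        raw = conversation_id.get("$id", conversation_id.get("_id"))
        if isinstance(raw, dict) and "$oid" in raw:
            return ObjectId(raw["$oid"])
        return ObjectId(raw)
    if isinstance(conversation_id, str):
        return ObjectId(conversation_id)
    return conversation_id  # already an ObjectId

print(to_object_id({"$ref": "conversations", "$id": {"$oid": "0123456789abcdef01234567"}}))
# 0123456789abcdef01234567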

View File

@@ -55,9 +55,14 @@ class GetChunks(Resource):
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
if not chunk_source or not chunk_source.endswith(path):
if not (source_match or file_path_match):
continue
# Filter by search term if provided

View File

@@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
@@ -20,6 +21,21 @@ sources_ns = Namespace(
)
def _get_provider_from_remote_data(remote_data):
if not remote_data:
return None
if isinstance(remote_data, dict):
return remote_data.get("provider")
if isinstance(remote_data, str):
try:
remote_data_obj = json.loads(remote_data)
except Exception:
return None
if isinstance(remote_data_obj, dict):
return remote_data_obj.get("provider")
return None
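
A couple of illustrative inputs for the helper above, which accepts either the stored dict or the legacy JSON-string form of remote_data:

print(_get_provider_from_remote_data({"provider": "google_drive", "folder_id": "abc"}))   # google_drive
print(_get_provider_from_remote_data('{"provider": "github", "repo": "arc53/DocsGPT"}'))  # github
print(_get_provider_from_remote_data("not-json"))                                         # None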
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
@@ -41,6 +57,7 @@ class CombinedJson(Resource):
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["_id"]),
@@ -51,6 +68,7 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
@@ -107,6 +125,7 @@ class PaginatedSources(Resource):
paginated_docs = []
for doc in documents:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
@@ -116,6 +135,7 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
@@ -240,7 +260,7 @@ class ManageSync(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
@@ -269,6 +289,72 @@ class ManageSync(Resource):
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/sync_source")
class SyncSource(Resource):
sync_source_model = api.model(
"SyncSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(sync_source_model)
@api.doc(description="Trigger an immediate sync for a source")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
if not ObjectId.is_valid(source_id):
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type.startswith("connector"):
return make_response(
jsonify(
{
"success": False,
"message": "Connector sources must be synced via /api/connectors/sync",
}
),
400,
)
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
)
try:
task = sync_source.delay(
source_data=source_data,
job_name=doc.get("name", ""),
user=user,
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=source_id,
)
except Exception as err:
current_app.logger.error(
f"Error starting sync for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
@@ -320,4 +406,4 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)

View File

@@ -64,19 +64,27 @@ class UploadFile(Resource):
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
file_name_map = {}
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
original_filename = os.path.basename(file.filename)
safe_file = safe_filename(original_filename)
if original_filename:
file_name_map[safe_file] = original_filename
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
if zipfile.is_zipfile(temp_file_path):
# Only extract actual .zip files, not Office/OpenDocument/EPUB formats
# (.docx, .xlsx, .pptx, .odt, .ods, .odp, .epub), which are technically
# zip archives but should be processed as-is
is_office_format = safe_file.lower().endswith(
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
)
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
@@ -137,6 +145,7 @@ class UploadFile(Resource):
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -182,6 +191,8 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] == "s3":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
@@ -334,6 +345,14 @@ class ManageSourceFiles(Resource):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
file_name_map = source.get("file_name_map") or {}
if isinstance(file_name_map, str):
try:
file_name_map = json.loads(file_name_map)
except Exception:
file_name_map = {}
if not isinstance(file_name_map, dict):
file_name_map = {}
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
@@ -355,19 +374,35 @@ class ManageSourceFiles(Resource):
400,
)
added_files = []
map_updated = False
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
original_filename = os.path.basename(file.filename)
safe_filename_str = safe_filename(original_filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
if original_filename:
relative_key = (
f"{parent_dir}/{safe_filename_str}"
if parent_dir
else safe_filename_str
)
file_name_map[relative_key] = original_filename
map_updated = True
if map_updated:
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -414,6 +449,7 @@ class ManageSourceFiles(Resource):
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
@@ -422,6 +458,15 @@ class ManageSourceFiles(Resource):
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
if file_path in file_name_map:
file_name_map.pop(file_path, None)
map_updated = True
if map_updated and isinstance(file_name_map, dict):
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -504,6 +549,20 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
if directory_path and file_name_map:
prefix = f"{directory_path.rstrip('/')}/"
keys_to_remove = [
key
for key in file_name_map.keys()
if key == directory_path or key.startswith(prefix)
]
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
@@ -574,8 +633,9 @@ class TaskStatus(Resource):
):
task_meta = str(task_meta) # Convert to a string representation
except ConnectionError as err:
current_app.logger.error(f"Connection error getting task status: {err}")
return make_response(
jsonify({"success": False, "message": str(err)}), 503
jsonify({"success": False, "message": "Service unavailable"}), 503
)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)

View File

@@ -8,13 +8,25 @@ from application.worker import (
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
@celery.task(bind=True)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
def ingest(
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
directory,
formats,
job_name,
file_path,
filename,
user,
file_name_map=file_name_map,
)
return resp
@@ -38,6 +50,30 @@ def schedule_syncs(self, frequency):
return resp
@celery.task(bind=True)
def sync_source(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
):
resp = sync(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
)
return resp
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)

View File

@@ -1,7 +1,7 @@
"""Tool management MCP server integration."""
import json
from email.quoprimime import unquote
from urllib.parse import unquote, urlencode
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
@@ -43,6 +43,16 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
auth_type = config.get("auth_type", "none")
@@ -64,12 +74,17 @@ class TestMCPServerConfig(Resource):
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
# Sanitize the response to avoid exposing internal error details
if not result.get("success") and "message" in result:
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
result["message"] = "Connection test failed"
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Connection test failed: {str(e)}"}
{"success": False, "error": "Connection test failed"}
),
500,
)
@@ -110,6 +125,16 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
auth_type = config.get("auth_type", "none")
@@ -234,7 +259,7 @@ class MCPServerSave(Resource):
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
{"success": False, "error": "Failed to save MCP server"}
),
500,
)
@@ -263,9 +288,12 @@ class MCPOAuthCallback(Resource):
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool"
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
@@ -292,7 +320,7 @@ class MCPOAuthCallback(Resource):
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
"/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
)
@@ -326,8 +354,8 @@ class MCPOAuthStatus(Resource):
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
)
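The switch to urlencode in the OAuth callback redirect above matters because provider-supplied error text is escaped, so characters such as '&' or '#' cannot inject extra query parameters; a minimal illustration:

from urllib.parse import urlencode

params = {
    "status": "error",
    "message": "OAuth error: access_denied & more. Please try again.",
    "provider": "mcp_tool",
}
redirect_url = f"/api/connectors/callback-status?{urlencode(params)}"
# The '&' inside the message is encoded as %26 instead of starting a new parameter.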

View File

@@ -4,6 +4,7 @@ from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
@@ -414,3 +415,136 @@ class DeleteTool(Resource):
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
@api.doc(
description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if "file" in request.files:
file = request.files["file"]
if not file.filename:
return make_response(
jsonify({"success": False, "message": "No file selected"}), 400
)
try:
spec_content = file.read().decode("utf-8")
except UnicodeDecodeError:
return make_response(
jsonify({"success": False, "message": "Invalid file encoding"}), 400
)
elif request.is_json:
data = request.get_json()
spec_content = data.get("spec_content", "")
else:
return make_response(
jsonify({"success": False, "message": "No spec provided"}), 400
)
if not spec_content or not spec_content.strip():
return make_response(
jsonify({"success": False, "message": "Empty spec content"}), 400
)
try:
metadata, actions = parse_spec(spec_content)
return make_response(
jsonify(
{
"success": True,
"metadata": metadata,
"actions": actions,
}
),
200,
)
except ValueError as e:
current_app.logger.error(f"Spec validation error: {e}")
return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
except Exception as err:
current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
@api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
def get(self, artifact_id: str):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
# Return all todos for the tool
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404
)
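The two response shapes the artifact endpoint above can return, shown with illustrative values:

note_artifact = {
    "success": True,
    "artifact": {
        "artifact_type": "note",
        "data": {"content": "# Meeting notes", "line_count": 1, "updated_at": "2026-02-15T00:00:00"},
    },
}
todo_artifact = {
    "success": True,
    "artifact": {
        "artifact_type": "todo_list",
        "data": {"items": [], "total_count": 0, "open_count": 0, "completed_count": 0},
    },
}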

View File

@@ -0,0 +1,378 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order
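A minimal sketch of how these helpers compose inside a flask_restx resource; example_collection and the inline serializer are assumptions, not part of the module:

from flask_restx import Resource

class ExampleList(Resource):
    @require_auth
    def get(self):
        user_id = get_user_id()
        limit, skip, err = validate_pagination()
        if err:
            return err
        return paginated_response(
            example_collection,  # assumed MongoDB collection
            {"user": user_id},
            lambda doc: serialize_object_id(dict(doc)),
            limit,
            skip,
            response_key="examples",
        )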

View File

@@ -0,0 +1,3 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -0,0 +1,541 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Validate and normalize optional JSON schema payload for structured output."""
if json_schema is None:
return None, None
try:
return normalize_json_schema_payload(json_schema), None
except JsonSchemaValidationError as exc:
return None, str(exc)
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
"""Normalize agent-node JSON schema payloads before persistence."""
normalized_nodes: List[Dict] = []
for node in nodes:
if not isinstance(node, dict):
normalized_nodes.append(node)
continue
normalized_node = dict(node)
if normalized_node.get("type") != "agent":
normalized_nodes.append(normalized_node)
continue
raw_config = normalized_node.get("data")
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
normalized_nodes.append(normalized_node)
continue
normalized_config = dict(raw_config)
try:
normalized_config["json_schema"] = normalize_json_schema_payload(
raw_config.get("json_schema")
)
except JsonSchemaValidationError:
# Validation runs before normalization; keep original on unexpected shape.
normalized_config["json_schema"] = raw_config.get("json_schema")
normalized_node["data"] = normalized_config
normalized_nodes.append(normalized_node)
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
node_map = {n.get("id"): n for n in nodes}
end_ids = {n.get("id") for n in end_nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
for cnode in condition_nodes:
cnode_id = cnode.get("id")
cnode_title = cnode.get("title", cnode_id)
outgoing = [e for e in edges if e.get("source") == cnode_id]
if len(outgoing) < 2:
errors.append(
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
)
node_data = cnode.get("data", {}) or {}
cases = node_data.get("cases", [])
if not isinstance(cases, list):
cases = []
if not cases or not any(
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
):
errors.append(
f"Condition node '{cnode_title}' must have at least one case with an expression"
)
case_handles: Set[str] = set()
duplicate_case_handles: Set[str] = set()
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
errors.append(
f"Condition node '{cnode_title}' has a case without a branch handle"
)
continue
if handle in case_handles:
duplicate_case_handles.add(handle)
case_handles.add(handle)
for handle in duplicate_case_handles:
errors.append(
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
)
outgoing_by_handle: Dict[str, List[Dict]] = {}
for out_edge in outgoing:
raw_handle = out_edge.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
outgoing_by_handle.setdefault(handle, []).append(out_edge)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
errors.append(
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
)
continue
if handle != "else" and handle not in case_handles:
errors.append(
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
)
if len(handle_edges) > 1:
errors.append(
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
)
if "else" not in outgoing_by_handle:
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
continue
raw_expression = case.get("expression", "")
has_expression = isinstance(raw_expression, str) and bool(
raw_expression.strip()
)
has_outgoing = bool(outgoing_by_handle.get(handle))
if has_expression and not has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
)
if not has_expression and has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
continue
for out_edge in handle_edges:
target = out_edge.get("target")
if target and not _can_reach_end(target, edges, node_map, end_ids):
errors.append(
f"Branch '{handle}' of condition '{cnode_title}' "
f"must eventually reach an end node"
)
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
for agent_node in agent_nodes:
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
raw_config = agent_node.get("data", {}) or {}
if not isinstance(raw_config, dict):
errors.append(f"Agent node '{agent_title}' has invalid configuration")
continue
normalized_schema, schema_error = validate_json_schema_payload(
raw_config.get("json_schema")
)
has_json_schema = normalized_schema is not None
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
)
if schema_error:
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
def _can_reach_end(
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
) -> bool:
if visited is None:
visited = set()
if node_id in end_ids:
return True
if node_id in visited or node_id not in node_map:
return False
visited.add(node_id)
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()
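A minimal payload that passes validate_workflow_structure above, assuming the function is in scope (ids and positions are illustrative):

payload = {
    "name": "Example workflow",
    "nodes": [
        {"id": "n1", "type": "start", "position": {"x": 0, "y": 0}, "data": {}},
        {"id": "n2", "type": "agent", "position": {"x": 200, "y": 0}, "data": {}},
        {"id": "n3", "type": "end", "position": {"x": 400, "y": 0}, "data": {}},
    ],
    "edges": [
        {"id": "e1", "source": "n1", "target": "n2"},
        {"id": "e2", "source": "n2", "target": "n3"},
    ],
}
assert validate_workflow_structure(payload["nodes"], payload["edges"]) == []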

View File

@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
options={"verify_exp": False},
)
return decoded_token
except Exception as e:
except Exception:
return {
"message": f"Authentication error: {str(e)}",
"message": "Authentication error: invalid token",
"error": "invalid_token",
}
else:

View File

@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional
class JsonSchemaValidationError(ValueError):
"""Raised when a JSON schema payload is invalid."""
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
"""
Normalize accepted JSON schema payload shapes to a plain schema object.
Accepted inputs:
- None
- A raw schema object with a top-level "type"
- A wrapped payload with a top-level "schema" object
"""
if json_schema is None:
return None
if not isinstance(json_schema, dict):
raise JsonSchemaValidationError("must be a valid JSON object")
wrapped_schema = json_schema.get("schema")
if wrapped_schema is not None:
if not isinstance(wrapped_schema, dict):
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
return wrapped_schema
if "type" not in json_schema:
raise JsonSchemaValidationError(
'must include either a "type" or "schema" field'
)
return json_schema
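Accepted shapes for the helper above (imported here under the path used by the workflow routes): the wrapped form is unwrapped to the inner schema, and anything else raises:

from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
)

assert normalize_json_schema_payload(None) is None
assert normalize_json_schema_payload({"schema": {"type": "string"}}) == {"type": "string"}
try:
    normalize_json_schema_payload({"properties": {}})  # neither "type" nor "schema"
except JsonSchemaValidationError as exc:
    print(exc)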

View File

@@ -8,8 +8,8 @@ from application.core.model_settings import (
ModelProvider,
)
OPENAI_ATTACHMENTS = [
"application/pdf",
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
@@ -17,14 +17,15 @@ OPENAI_ATTACHMENTS = [
"image/gif",
]
GOOGLE_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
@@ -63,6 +64,7 @@ ANTHROPIC_MODELS = [
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -73,6 +75,7 @@ ANTHROPIC_MODELS = [
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -83,6 +86,7 @@ ANTHROPIC_MODELS = [
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -93,6 +97,7 @@ ANTHROPIC_MODELS = [
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -151,28 +156,43 @@ GROQ_MODELS = [
),
),
AvailableModel(
id="llama-3.1-8b-instant",
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="Llama 3.1 8B",
description="Ultra-fast inference",
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
display_name="Mixtral 8x7B",
description="High-speed inference with tools",
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=32768,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
@@ -187,3 +207,18 @@ AZURE_OPENAI_MODELS = [
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)
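Example use of the new factory; the model name and base URL are placeholders for a local OpenAI-compatible server such as Ollama or LM Studio:

model = create_custom_openai_model("llama3:8b", "http://localhost:11434/v1")
# model.provider is ModelProvider.OPENAI and its attachment types are images only
# (OPENAI_ATTACHMENTS == IMAGE_ATTACHMENTS), matching the constants above.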

View File

@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
@@ -84,9 +85,13 @@ class ModelRegistry:
self.models.clear()
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
@@ -105,39 +110,69 @@ class ModelRegistry:
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
elif settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
else:
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import OPENAI_MODELS
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
@@ -194,6 +229,21 @@ class ModelRegistry:
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
@@ -223,6 +273,15 @@ class ModelRegistry:
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
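The effect of the new comma-separated LLM_NAME parsing (example values only); the first parsed name that resolves to a registered model becomes the default:

llm_name = "deepseek-r1:1.5b, gemma:2b"
parsed = [name.strip() for name in llm_name.split(",") if name.strip()]
assert parsed == ["deepseek-r1:1.5b", "gemma:2b"]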

View File

@@ -9,6 +9,7 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -2,7 +2,8 @@ import os
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,12 +11,19 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -35,8 +43,10 @@ class Settings(BaseSettings):
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
@@ -72,10 +82,8 @@ class Settings(BaseSettings):
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
@@ -125,7 +133,7 @@ class Settings(BaseSettings):
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
@@ -152,6 +160,37 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
@field_validator(
"API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",
"ELEVENLABS_API_KEY",
"INTERNAL_KEY",
mode="before",
)
@classmethod
def normalize_api_key(cls, v: Optional[str]) -> Optional[str]:
"""
Normalize API keys: convert 'None', 'none', empty strings,
and whitespace-only strings to actual None.
Handles Pydantic loading 'None' from .env as string "None".
"""
if v is None:
return None
if not isinstance(v, str):
return v
stripped = v.strip()
if stripped == "" or stripped.lower() == "none":
return None
return stripped
path = Path(__file__).parent.parent.absolute()
# Project root is one level above application/
path = Path(__file__).parent.parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
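A standalone restatement of the rule the API-key validator applies, useful for seeing the edge cases at a glance (the real logic lives in Settings.normalize_api_key):

from typing import Optional

def normalize_api_key(value: Optional[str]) -> Optional[str]:
    if value is None or not isinstance(value, str):
        return value
    stripped = value.strip()
    return None if stripped == "" or stripped.lower() == "none" else stripped

assert normalize_api_key("None") is None
assert normalize_api_key("   ") is None
assert normalize_api_key(" sk-abc ") == "sk-abc"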

View File

@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))
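For illustration, a minimal sketch of how a caller might use the two helpers (the URLs below are placeholders, not taken from the codebase):

# Strict form: raises SSRFError for private ranges, metadata IPs, or non-HTTP(S) schemes.
try:
    safe_url = validate_url("https://example.com/docs")
except SSRFError as exc:
    safe_url = None  # refuse to fetch

# Non-throwing form: returns (is_valid, url, error_message) for graceful handling.
ok, checked_url, error = validate_url_safe("http://169.254.169.254/latest/meta-data")
if not ok:
    print(f"Blocked: {error}")  # metadata service access is rejected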

View File

@@ -1,7 +1,13 @@
import base64
import logging
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
@@ -20,6 +26,7 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
self.storage = StorageCreator.get_storage()
def _raw_gen(
self,
@@ -70,3 +77,115 @@ class AnthropicLLM(BaseLLM):
finally:
if hasattr(stream_response, "close"):
stream_response.close()
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Anthropic Claude for file uploads.
Claude supports images but not PDFs natively.
PDFs are synthetically supported via PDF-to-image conversion in the handler.
Returns:
list: List of supported MIME types
"""
return [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments for Anthropic Claude API.
Formats images using Claude's vision message format.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with image content for Claude API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the last user message to attach images to
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
# Convert content to list format if it's a string
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith("image/"):
try:
# Check if this is a pre-converted image (from PDF-to-image conversion)
# These have 'data' key with base64 already
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
# Claude uses a specific format for images
prepared_messages[user_message_index]["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
},
}
)
except Exception as e:
logger.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
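As a rough sketch of the output, a single image attachment would leave the last user message with a content list of this shape (the text and data values are made up):

prepared = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": "<base64-encoded bytes>",
                },
            },
        ],
    }
]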

View File

@@ -13,10 +13,12 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
agent_id=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=None,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"

View File

@@ -1,75 +1,19 @@
import json
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.openai import OpenAILLM
DOCSGPT_API_KEY = "sk-docsgpt-public"
DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = "sk-docsgpt-public"
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
class DocsGPTAPILLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,
user_api_key=user_api_key,
base_url=DOCSGPT_BASE_URL,
*args,
**kwargs,
)
def _raw_gen(
self,
@@ -79,23 +23,19 @@ class DocsGPTAPILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
return super()._raw_gen(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
def _raw_gen_stream(
self,
@@ -105,34 +45,16 @@ class DocsGPTAPILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True
return super()._raw_gen_stream(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)

View File

@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
super().__init__(decoded_token=decoded_token, *args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
@@ -378,6 +378,22 @@ class GoogleLLM(BaseLLM):
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
@staticmethod
def _get_text_value(part):
"""Get text from both SDK objects and dict-shaped test doubles."""
if isinstance(part, dict):
value = part.get("text")
return value if isinstance(value, str) else ""
value = getattr(part, "text", None)
return value if isinstance(value, str) else ""
@staticmethod
def _is_thought_part(part):
"""Detect Gemini thinking parts when available."""
if isinstance(part, dict):
return bool(part.get("thought"))
return bool(getattr(part, "thought", False))
def _raw_gen(
self,
baseself,
@@ -438,7 +454,6 @@ class GoogleLLM(BaseLLM):
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
@@ -475,10 +490,23 @@ class GoogleLLM(BaseLLM):
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
continue
part_text = self._get_text_value(part)
if not part_text:
continue
if self._is_thought_part(part):
yield {"type": "thought", "thought": part_text}
else:
yield part_text
elif hasattr(chunk, "text"):
yield chunk.text
chunk_text = self._get_text_value(chunk)
if chunk_text:
if self._is_thought_part(chunk):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
finally:
if hasattr(response, "close"):
response.close()
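Downstream code now receives three kinds of chunks from this generator; a minimal sketch of a consumer, assuming stream is the generator returned by _raw_gen_stream:

for chunk in stream:
    if isinstance(chunk, dict) and chunk.get("type") == "thought":
        print("[thought]", chunk["thought"])  # Gemini thinking tokens
    elif isinstance(chunk, str):
        print(chunk, end="")  # normal answer text
    else:
        pass  # function-call parts are yielded as-is for the tool loop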

View File

@@ -1,37 +1,15 @@
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.openai import OpenAILLM
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = OpenAI(
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
class GroqLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or GROQ_BASE_URL,
*args,
**kwargs,
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -105,6 +105,7 @@ class LLMHandler(ABC):
"""
Prepare messages with attachments and provider-specific formatting.
Args:
agent: The agent instance
messages: Original messages
@@ -118,11 +119,40 @@ class LLMHandler(ABC):
logger.info(f"Preparing messages with {len(attachments)} attachments")
supported_types = agent.llm.get_supported_attachment_types()
# Check if provider supports images but not PDF (synthetic PDF support)
supports_images = any(t.startswith("image/") for t in supported_types)
supports_pdf = "application/pdf" in supported_types
# Process attachments, converting PDFs to images if needed
processed_attachments = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
# Synthetic PDF support: convert PDF to images if LLM supports images but not PDF
if mime_type == "application/pdf" and supports_images and not supports_pdf:
logger.info(
f"Converting PDF to images for synthetic PDF support: {attachment.get('path', 'unknown')}"
)
try:
converted_images = self._convert_pdf_to_images(attachment)
processed_attachments.extend(converted_images)
logger.info(
f"Converted PDF to {len(converted_images)} images"
)
except Exception as e:
logger.error(
f"Failed to convert PDF to images, falling back to text: {e}"
)
# Fall back to treating as unsupported (text extraction)
processed_attachments.append(attachment)
else:
processed_attachments.append(attachment)
supported_attachments = [
a for a in attachments if a.get("mime_type") in supported_types
a for a in processed_attachments if a.get("mime_type") in supported_types
]
unsupported_attachments = [
a for a in attachments if a.get("mime_type") not in supported_types
a for a in processed_attachments if a.get("mime_type") not in supported_types
]
# Process supported attachments with the LLM's custom method
@@ -145,6 +175,37 @@ class LLMHandler(ABC):
)
return messages
def _convert_pdf_to_images(self, attachment: Dict) -> List[Dict]:
"""
Convert a PDF attachment to a list of image attachments.
This enables synthetic PDF support for LLMs that support images but not PDFs.
Args:
attachment: PDF attachment dictionary with 'path' and optional 'content'
Returns:
List of image attachment dictionaries with 'data', 'mime_type', and 'page'
"""
from application.utils import convert_pdf_to_images
from application.storage.storage_creator import StorageCreator
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in PDF attachment")
storage = StorageCreator.get_storage()
# Convert PDF to images
images_data = convert_pdf_to_images(
file_path=file_path,
storage=storage,
max_pages=20,
dpi=150,
)
return images_data
def _append_unsupported_attachments(
self, messages: List[Dict], attachments: List[Dict]
) -> List[Dict]:
@@ -506,6 +567,7 @@ class LLMHandler(ABC):
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
)
# Create service without DB persistence capability
@@ -817,6 +879,9 @@ class LLMHandler(ABC):
tool_calls = {}
for chunk in self._iterate_stream(response):
if isinstance(chunk, dict) and chunk.get("type") == "thought":
yield chunk
continue
if isinstance(chunk, str):
yield chunk
continue
@@ -833,7 +898,10 @@ class LLMHandler(ABC):
if call.name:
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
if existing.arguments is None:
existing.arguments = call.arguments
else:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
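The arguments of a streamed tool call arrive as partial JSON fragments, which is what the None guard above protects; a minimal sketch of the accumulation (the fragment values are made up):

fragments = ['{"que', 'ry": "docsgpt', '"}']  # argument deltas for one tool call
arguments = None
for piece in fragments:
    if arguments is None:
        arguments = piece  # the first delta may arrive while arguments is still None
    else:
        arguments += piece
# arguments == '{"query": "docsgpt"}' once the stream finishes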

View File

@@ -1,68 +0,0 @@
from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name="Arc53/DocsGPT-7B",
q=False,
*args,
**kwargs,
):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
)
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
llm_name, quantization_config=bnb_config
)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=2000,
device_map="auto",
eos_token_id=tokenizer.eos_token_id,
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = hf(prompt)
return result.content
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

View File

@@ -4,12 +4,12 @@ from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.open_router import OpenRouterLLM
logger = logging.getLogger(__name__)
@@ -19,7 +19,6 @@ class LLMCreator:
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"huggingface": HuggingFaceLLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
@@ -27,11 +26,20 @@ class LLMCreator:
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"openrouter": OpenRouterLLM,
}
@classmethod
def create_llm(
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
cls,
type,
api_key,
user_api_key,
decoded_token,
model_id=None,
agent_id=None,
*args,
**kwargs,
):
from application.core.model_utils import get_base_url_for_model
@@ -49,6 +57,7 @@ class LLMCreator:
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
*args,
**kwargs,
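For illustration, a hypothetical call that exercises the new agent_id parameter (the key, token, and model values are placeholders):

llm = LLMCreator.create_llm(
    "openai",  # provider key from the creator's registry
    api_key="sk-...",  # placeholder
    user_api_key=None,
    decoded_token={"sub": "user-123"},  # placeholder token payload
    model_id="gpt-4o",  # hypothetical model id
    agent_id="agent-abc",  # now forwarded to the provider constructor
)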

View File

@@ -1,32 +1,15 @@
from application.llm.base import BaseLLM
from openai import OpenAI
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
class NovitaLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,
**kwargs,
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -0,0 +1,15 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or OPEN_ROUTER_BASE_URL,
*args,
**kwargs,
)
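The same thin-subclass pattern could be reused for any other OpenAI-compatible endpoint; a hypothetical sketch (the provider name, base URL, and registry step are illustrative, not part of this change):

EXAMPLE_BASE_URL = "https://api.example-provider.com/v1"  # hypothetical endpoint

class ExampleProviderLLM(OpenAILLM):
    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
        super().__init__(
            api_key=api_key or settings.API_KEY,
            user_api_key=user_api_key,
            base_url=base_url or EXAMPLE_BASE_URL,
            *args,
            **kwargs,
        )

# The new class would also need an entry in LLMCreator's provider mapping to be selectable.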

View File

@@ -9,6 +9,57 @@ from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
def _truncate_base64_for_logging(messages):
"""
Create a copy of messages with base64 data truncated for readable logging.
Args:
messages: List of message dicts
Returns:
Copy of messages with truncated base64 content
"""
import copy
def truncate_content(content):
if isinstance(content, str):
# Check if it looks like a data URL with base64
if content.startswith("data:") and ";base64," in content:
prefix_end = content.index(";base64,") + len(";base64,")
prefix = content[:prefix_end]
return f"{prefix}[BASE64_DATA_TRUNCATED, length={len(content) - prefix_end}]"
return content
elif isinstance(content, list):
return [truncate_item(item) for item in content]
elif isinstance(content, dict):
return {k: truncate_content(v) for k, v in content.items()}
return content
def truncate_item(item):
if isinstance(item, dict):
result = {}
for k, v in item.items():
if k == "url" and isinstance(v, str) and ";base64," in v:
prefix_end = v.index(";base64,") + len(";base64,")
prefix = v[:prefix_end]
result[k] = f"{prefix}[BASE64_DATA_TRUNCATED, length={len(v) - prefix_end}]"
elif k == "data" and isinstance(v, str) and len(v) > 100:
result[k] = f"[BASE64_DATA_TRUNCATED, length={len(v)}]"
else:
result[k] = truncate_content(v)
return result
return truncate_content(item)
truncated = []
for msg in messages:
msg_copy = copy.copy(msg)
if "content" in msg_copy:
msg_copy["content"] = truncate_content(msg_copy["content"])
truncated.append(msg_copy)
return truncated
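As a rough sketch of what the helper produces (the message content is made up):

msgs = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
        ],
    }
]
logging.info(_truncate_base64_for_logging(msgs))
# The "url" value is logged as "data:image/png;base64,[BASE64_DATA_TRUNCATED, length=...]".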
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
@@ -44,12 +95,12 @@ class OpenAILLM(BaseLLM):
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
if "function_call" in item:
# Function calls need their own message
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
@@ -69,6 +120,7 @@ class OpenAILLM(BaseLLM):
}
)
elif "function_response" in item:
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
@@ -81,40 +133,69 @@ class OpenAILLM(BaseLLM):
}
)
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
elif "type" in item and item["type"] == "file" and "file" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
@staticmethod
def _normalize_reasoning_value(value):
"""Normalize reasoning payloads from OpenAI-compatible stream chunks."""
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
return "".join(
OpenAILLM._normalize_reasoning_value(item) for item in value
)
if isinstance(value, dict):
for key in ("text", "content", "value", "reasoning_content", "reasoning"):
normalized = OpenAILLM._normalize_reasoning_value(value.get(key))
if normalized:
return normalized
return ""
for attr in ("text", "content", "value"):
if hasattr(value, attr):
normalized = OpenAILLM._normalize_reasoning_value(getattr(value, attr))
if normalized:
return normalized
return ""
@classmethod
def _extract_reasoning_text(cls, delta):
"""Extract reasoning/thinking tokens from OpenAI-compatible delta chunks."""
if delta is None:
return ""
for key in (
"reasoning_content",
"reasoning",
"thinking",
"thinking_content",
):
value = getattr(delta, key, None)
if value is None and isinstance(delta, dict):
value = delta.get(key)
normalized = cls._normalize_reasoning_value(value)
if normalized:
return normalized
return ""
def _raw_gen(
self,
baseself,
@@ -127,6 +208,7 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
@@ -144,7 +226,7 @@ class OpenAILLM(BaseLLM):
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
logging.info(f"OpenAI response: {response}")
if tools:
return response.choices[0]
else:
@@ -162,6 +244,7 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
logging.info(f"Cleaned messages: {_truncate_base64_for_logging(messages)}")
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
@@ -182,14 +265,27 @@ class OpenAILLM(BaseLLM):
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
logging.debug(f"OpenAI stream line: {line}")
if not getattr(line, "choices", None):
continue
choice = line.choices[0]
delta = getattr(choice, "delta", None)
reasoning_text = self._extract_reasoning_text(delta)
if reasoning_text:
yield {"type": "thought", "thought": reasoning_text}
content = getattr(delta, "content", None)
if isinstance(content, str) and content:
yield content
continue
has_tool_calls = bool(getattr(delta, "tool_calls", None))
finish_reason = getattr(choice, "finish_reason", None)
# Yield non-content chunks only when needed for tool-call handling.
if has_tool_calls or finish_reason == "tool_calls":
yield choice
finally:
if hasattr(response, "close"):
response.close()
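As an illustration of the reasoning normalization, a made-up delta payload and what the classmethod returns:

delta = {"reasoning_content": [{"text": "Checking the docs"}, {"text": " first."}]}
OpenAILLM._extract_reasoning_text(delta)
# -> "Checking the docs first."  (nested lists/dicts are flattened into one string)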
@@ -258,17 +354,14 @@ class OpenAILLM(BaseLLM):
"""
Return a list of MIME types supported by OpenAI for file uploads.
This reads from the model config to ensure consistency.
If no model config is found, it falls back to images only (the safest default).
Returns:
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
from application.core.model_configs import OPENAI_ATTACHMENTS
return OPENAI_ATTACHMENTS
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
@@ -305,10 +398,16 @@ class OpenAILLM(BaseLLM):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
logging.info(f"Processing attachment with mime_type: {mime_type}, has_data: {'data' in attachment}, has_path: {'path' in attachment}")
if mime_type and mime_type.startswith("image/"):
try:
base64_image = self._get_base64_image(attachment)
# Check if this is a pre-converted image (from PDF-to-image conversion)
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
@@ -317,6 +416,7 @@ class OpenAILLM(BaseLLM):
},
}
)
except Exception as e:
logging.error(
f"Error processing image attachment: {e}", exc_info=True
@@ -331,6 +431,7 @@ class OpenAILLM(BaseLLM):
# Handle PDFs using the file API
elif mime_type == "application/pdf":
logging.info(f"Attempting to upload PDF to OpenAI: {attachment.get('path', 'unknown')}")
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append(
@@ -345,6 +446,8 @@ class OpenAILLM(BaseLLM):
"text": f"File content:\n\n{attachment['content']}",
}
)
else:
logging.warning(f"Unsupported attachment type in OpenAI provider: {mime_type}")
return prepared_messages
def _get_base64_image(self, attachment):

View File

@@ -65,6 +65,10 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
if not os.path.exists(folder_name):
os.makedirs(folder_name)
# Validate docs is not empty
if not docs:
raise ValueError("No documents to embed - check file format and extension")
# Initialize vector store
if settings.VECTOR_STORE == "faiss":
docs_init = [docs.pop(0)]

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, List
from langchain.docstore.document import Document as LCDocument
from langchain_core.documents import Document as LCDocument
from application.parser.schema.base import Document

View File

@@ -10,29 +10,97 @@ from application.parser.file.epub_parser import EpubParser
from application.parser.file.html_parser import HTMLParser
from application.parser.file.markdown_parser import MarkdownParser
from application.parser.file.rst_parser import RstParser
from application.parser.file.tabular_parser import PandasCSVParser,ExcelParser
from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
from application.utils import num_tokens_from_string
from application.core.settings import settings
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx":ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json":JSONParser(),
".pptx":PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
def get_default_file_extractor(
ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
"""Get the default file extractor.
Uses docling parsers by default for advanced document processing.
Falls back to standard parsers if docling is not installed.
"""
try:
from application.parser.file.docling_parser import (
DoclingPDFParser,
DoclingDocxParser,
DoclingPPTXParser,
DoclingXLSXParser,
DoclingHTMLParser,
DoclingImageParser,
DoclingCSVParser,
DoclingAsciiDocParser,
DoclingVTTParser,
DoclingXMLParser,
)
if ocr_enabled is None:
ocr_enabled = settings.DOCLING_OCR_ENABLED
return {
# Documents
".pdf": DoclingPDFParser(ocr_enabled=ocr_enabled),
".docx": DoclingDocxParser(),
".pptx": DoclingPPTXParser(),
".xlsx": DoclingXLSXParser(),
# Web formats
".html": DoclingHTMLParser(),
".xhtml": DoclingHTMLParser(),
# Data formats
".csv": DoclingCSVParser(),
".json": JSONParser(), # Keep JSON parser (specialized handling)
# Text/markup formats
".md": MarkdownParser(), # Keep markdown parser (specialized handling)
".mdx": MarkdownParser(),
".rst": RstParser(),
".adoc": DoclingAsciiDocParser(),
".asciidoc": DoclingAsciiDocParser(),
# Images (with OCR) - only use Docling when OCR is enabled
".png": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".jpeg": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tiff": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".tif": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".bmp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
".webp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
# Media/subtitles
".vtt": DoclingVTTParser(),
# Specialized XML formats
".xml": DoclingXMLParser(),
# Formats docling doesn't support - use standard parsers
".epub": EpubParser(),
}
except ImportError:
logging.warning(
"docling is not installed. Using standard parsers. "
"For advanced document parsing, install with: pip install docling"
)
# Fallback to standard parsers
return {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".xlsx": ExcelParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
".json": JSONParser(),
".pptx": PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}
# For backwards compatibility
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = get_default_file_extractor()
class SimpleDirectoryReader(BaseReader):
@@ -83,7 +151,10 @@ class SimpleDirectoryReader(BaseReader):
self.recursive = recursive
self.exclude_hidden = exclude_hidden
self.required_exts = required_exts
# Normalize extensions to lowercase for case-insensitive matching
self.required_exts = (
[ext.lower() for ext in required_exts] if required_exts else None
)
self.num_files_limit = num_files_limit
if input_files:
@@ -112,7 +183,7 @@ class SimpleDirectoryReader(BaseReader):
continue
elif (
self.required_exts is not None
and input_file.suffix not in self.required_exts
and input_file.suffix.lower() not in self.required_exts
):
continue
else:
@@ -149,8 +220,9 @@ class SimpleDirectoryReader(BaseReader):
self.file_token_counts = {}
for input_file in self.input_files:
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
suffix_lower = input_file.suffix.lower()
if suffix_lower in self.file_extractor:
parser = self.file_extractor[suffix_lower]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
@@ -232,7 +304,7 @@ class SimpleDirectoryReader(BaseReader):
if subtree:
result[item.name] = subtree
else:
if self.required_exts is not None and item.suffix not in self.required_exts:
if self.required_exts is not None and item.suffix.lower() not in self.required_exts:
continue
full_path = str(item.resolve())
@@ -251,4 +323,4 @@ class SimpleDirectoryReader(BaseReader):
return result
return build_tree(Path(base_path))
return build_tree(Path(base_path))
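A minimal sketch of driving the reader with the now case-insensitive extension filtering (the paths are hypothetical):

reader = SimpleDirectoryReader(
    input_files=["./docs/Guide.PDF", "./docs/setup.md"],  # hypothetical paths
    required_exts=[".PDF", ".md"],  # normalized to lowercase internally
)
documents = reader.load_data()  # ".PDF" now matches the lowercase ".pdf" extractor entry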

View File

@@ -0,0 +1,330 @@
"""Docling parser.
Uses docling library for advanced document parsing with layout detection,
table structure recognition, and unified document representation.
Supports: PDF, DOCX, PPTX, XLSX, HTML, XHTML, CSV, Markdown, AsciiDoc,
images (PNG, JPEG, TIFF, BMP, WEBP), WebVTT, and specialized XML formats.
"""
import importlib.util
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
Docling provides:
- Advanced PDF layout analysis
- Table structure recognition
- Reading order detection
- OCR for scanned documents (supports RapidOCR)
- Unified DoclingDocument format
- Export to Markdown
Uses a hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
export_format: str = "markdown",
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
"""Initialize DoclingParser.
Args:
ocr_enabled: Enable OCR for bitmap/image regions in documents
table_structure: Enable table structure recognition
export_format: Output format ('markdown', 'text', 'html')
use_rapidocr: Use RapidOCR engine (default True, works well in Docker)
ocr_languages: List of OCR languages (default: ['english'])
force_full_page_ocr: Force OCR on entire page (False = smart hybrid OCR)
"""
super().__init__()
self.ocr_enabled = ocr_enabled
self.table_structure = table_structure
self.export_format = export_format
self.use_rapidocr = use_rapidocr
self.ocr_languages = ocr_languages or ["english"]
self.force_full_page_ocr = force_full_page_ocr
self._converter = None
def _create_converter(self):
"""Create a docling converter with hybrid OCR configuration.
Uses smart OCR approach:
- When ocr_enabled=True and force_full_page_ocr=False (default):
Layout model detects text vs bitmap regions, OCR only runs on bitmaps
- When ocr_enabled=True and force_full_page_ocr=True:
OCR runs on entire page (for scanned documents/images)
- When ocr_enabled=False:
No OCR, only native text extraction
Returns:
DocumentConverter instance
"""
from docling.document_converter import (
DocumentConverter,
ImageFormatOption,
InputFormat,
PdfFormatOption,
)
from docling.datamodel.pipeline_options import PdfPipelineOptions
pipeline_options = PdfPipelineOptions(
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()
if ocr_options is not None:
pipeline_options.ocr_options = ocr_options
return DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options,
),
InputFormat.IMAGE: ImageFormatOption(
pipeline_options=pipeline_options,
),
}
)
def _init_parser(self) -> Dict:
"""Initialize the docling converter with hybrid OCR."""
logger.info("Initializing DoclingParser...")
logger.info(f" ocr_enabled={self.ocr_enabled}")
logger.info(f" force_full_page_ocr={self.force_full_page_ocr}")
logger.info(f" use_rapidocr={self.use_rapidocr}")
if importlib.util.find_spec("docling.document_converter") is None:
raise ImportError(
"docling is required for DoclingParser. "
"Install it with: pip install docling"
)
# Create converter with hybrid OCR (smart: text direct, bitmaps OCR'd)
self._converter = self._create_converter()
logger.info("DoclingParser initialized successfully")
return {
"ocr_enabled": self.ocr_enabled,
"table_structure": self.table_structure,
"export_format": self.export_format,
"use_rapidocr": self.use_rapidocr,
"ocr_languages": self.ocr_languages,
"force_full_page_ocr": self.force_full_page_ocr,
}
def _get_ocr_options(self):
"""Get OCR options based on configuration.
Returns RapidOcrOptions if use_rapidocr is True and available,
otherwise returns None to use docling defaults.
"""
if not self.use_rapidocr:
return None
try:
from docling.datamodel.pipeline_options import RapidOcrOptions
return RapidOcrOptions(
lang=self.ocr_languages,
force_full_page_ocr=self.force_full_page_ocr,
)
except ImportError as e:
logger.warning(f"Failed to import RapidOcrOptions: {e}")
return None
except Exception as e:
logger.error(f"Error creating RapidOcrOptions: {e}")
return None
def _export_content(self, document) -> str:
"""Export document content in the configured format.
Handles edge case where text is nested under picture elements (e.g., OCR'd
images). If the standard export returns minimal content but document.texts
contains extracted text, falls back to direct text extraction.
"""
if self.export_format == "markdown":
content = document.export_to_markdown()
elif self.export_format == "html":
content = document.export_to_html()
else:
content = document.export_to_text()
# Handle case where text is nested under pictures (common with OCR'd images)
# Standard exports may return just "<!-- image -->" while actual text exists
stripped_content = content.strip()
is_minimal = len(stripped_content) < 50 or stripped_content == "<!-- image -->"
if is_minimal and hasattr(document, "texts") and document.texts:
# Extract text directly from document.texts
extracted_texts = [t.text for t in document.texts if t.text]
if extracted_texts:
logger.info(
f"Standard export minimal ({len(stripped_content)} chars), "
f"extracting {len(extracted_texts)} texts directly"
)
return "\n\n".join(extracted_texts)
return content
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file using docling with hybrid OCR.
Uses a smart OCR approach where the layout model detects text vs. bitmap
regions: text is extracted directly, and bitmaps are OCR'd only when needed.
Args:
file: Path to the file to parse
errors: Error handling mode (ignored, docling handles internally)
Returns:
Parsed document content as markdown string
"""
logger.info(f"parse_file called for: {file}")
if self._converter is None:
self._init_parser()
try:
logger.info(f"Converting file with hybrid OCR: {file}")
result = self._converter.convert(str(file))
content = self._export_content(result.document)
logger.info(f"Parse complete, content length: {len(content)} chars")
return content
except Exception as e:
logger.error(f"Error parsing file with docling: {e}", exc_info=True)
if errors == "ignore":
return f"[Error parsing file with docling: {str(e)}]"
raise
class DoclingPDFParser(DoclingParser):
"""Docling-based PDF parser with advanced features and RapidOCR support.
Uses a hybrid OCR approach by default:
- Text regions: Direct PDF text extraction (fast)
- Bitmap/image regions: OCR only these areas (smart)
Set force_full_page_ocr=True only for fully scanned documents.
"""
def __init__(
self,
ocr_enabled: bool = True,
table_structure: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = False,
):
super().__init__(
ocr_enabled=ocr_enabled,
table_structure=table_structure,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingDocxParser(DoclingParser):
"""Docling-based DOCX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingPPTXParser(DoclingParser):
"""Docling-based PPTX parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXLSXParser(DoclingParser):
"""Docling-based XLSX parser with table structure."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingHTMLParser(DoclingParser):
"""Docling-based HTML parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingImageParser(DoclingParser):
"""Docling-based image parser with OCR and RapidOCR support.
For images, force_full_page_ocr=True is used since images are entirely
visual and require full OCR to extract any text.
"""
def __init__(
self,
ocr_enabled: bool = True,
use_rapidocr: bool = True,
ocr_languages: Optional[List[str]] = None,
force_full_page_ocr: bool = True,
):
super().__init__(
ocr_enabled=ocr_enabled,
export_format="markdown",
use_rapidocr=use_rapidocr,
ocr_languages=ocr_languages,
force_full_page_ocr=force_full_page_ocr,
)
class DoclingCSVParser(DoclingParser):
"""Docling-based CSV parser."""
def __init__(self):
super().__init__(table_structure=True, export_format="markdown")
class DoclingMarkdownParser(DoclingParser):
"""Docling-based Markdown parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingAsciiDocParser(DoclingParser):
"""Docling-based AsciiDoc parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingVTTParser(DoclingParser):
"""Docling-based WebVTT (video text tracks) parser."""
def __init__(self):
super().__init__(export_format="markdown")
class DoclingXMLParser(DoclingParser):
"""Docling-based XML parser (USPTO, JATS)."""
def __init__(self):
super().__init__(export_format="markdown")
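For illustration, a minimal sketch of using the PDF parser directly, assuming docling is installed (the file path is hypothetical):

parser = DoclingPDFParser(
    ocr_enabled=True,  # hybrid OCR: native text is extracted directly, bitmap regions are OCR'd
    force_full_page_ocr=False,  # set True only for fully scanned documents
)
markdown = parser.parse_file(Path("./docs/report.pdf"))  # returns Markdown text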

View File

@@ -7,8 +7,8 @@ import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import tiktoken
from application.parser.file.base_parser import BaseParser
from application.utils import num_tokens_from_string
class MarkdownParser(BaseParser):
@@ -38,7 +38,7 @@ class MarkdownParser(BaseParser):
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str],
current_text: str):
"""Append to tups chunk."""
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
num_tokens = num_tokens_from_string(current_text)
if num_tokens > self._max_tokens:
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
for chunk in chunks:

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, List
from langchain.docstore.document import Document as LCDocument
from langchain_core.documents import Document as LCDocument
from application.parser.schema.base import Document

View File

@@ -1,9 +1,11 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -16,9 +18,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Check if the URL scheme is provided, if not, assume http
if not urlparse(url).scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
visited_urls = set()
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname
@@ -30,16 +35,26 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
response = requests.get(current_url)
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()
# Convert the loaded documents to your Document schema
for doc in docs:
metadata = dict(doc.metadata or {})
source_url = metadata.get("source") or current_url
metadata["file_path"] = self._url_to_virtual_path(source_url)
loaded_content.append(
Document(
doc.page_content,
extra_info=doc.metadata
extra_info=metadata
)
)
except Exception as e:
@@ -63,3 +78,29 @@ class CrawlerLoader(BaseRemote):
break
return loaded_content
def _url_to_virtual_path(self, url):
"""
Convert a URL to a virtual file path ending with .md.
Examples:
https://docs.docsgpt.cloud/ -> index.md
https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
https://example.com/page.html -> page.md
"""
parsed = urlparse(url)
path = parsed.path.strip("/")
if not path:
return "index.md"
# Remove common file extensions and add .md
base, ext = os.path.splitext(path)
if ext.lower() in [".html", ".htm", ".php", ".asp", ".aspx", ".jsp"]:
path = base
if not path.endswith(".md"):
path = f"{path}.md"
return path

View File

@@ -2,10 +2,12 @@ import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
import re
from markdownify import markdownify
from application.parser.schema.base import Document
import tldextract
import os
class CrawlerLoader(BaseRemote):
def __init__(self, limit=10, allow_subdomains=False):
@@ -25,9 +27,12 @@ class CrawlerLoader(BaseRemote):
if isinstance(url, list) and url:
url = url[0]
# Ensure the URL has a scheme (if not, default to http)
if not urlparse(url).scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
print(f"URL validation failed: {e}")
return []
# Keep track of visited URLs to avoid revisiting the same page
visited_urls = set()
@@ -53,13 +58,21 @@ class CrawlerLoader(BaseRemote):
# Convert the HTML to Markdown for cleaner text formatting
title, language, processed_markdown = self._process_html_to_markdown(html_content, current_url)
if processed_markdown:
# Generate virtual file path from URL for consistent file-like matching
virtual_path = self._url_to_virtual_path(current_url)
# Create a Document for each visited page
documents.append(
Document(
processed_markdown, # content
None, # doc_id
None, # embedding
{"source": current_url, "title": title, "language": language} # extra_info
{
"source": current_url,
"title": title,
"language": language,
"file_path": virtual_path,
}, # extra_info
)
)
@@ -78,9 +91,14 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
except SSRFError as e:
print(f"URL validation failed for {url}: {e}")
return None
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None
@@ -136,4 +154,31 @@ class CrawlerLoader(BaseRemote):
# Exact domain match
if link_base == base_domain:
filtered.append(link)
return filtered
return filtered
def _url_to_virtual_path(self, url):
"""
Convert a URL to a virtual file path ending with .md.
Examples:
https://docs.docsgpt.cloud/ -> index.md
https://docs.docsgpt.cloud/guides/setup -> guides/setup.md
https://docs.docsgpt.cloud/guides/setup/ -> guides/setup.md
https://example.com/page.html -> page.md
"""
parsed = urlparse(url)
path = parsed.path.strip("/")
if not path:
return "index.md"
# Remove common file extensions and add .md
base, ext = os.path.splitext(path)
if ext.lower() in [".html", ".htm", ".php", ".asp", ".aspx", ".jsp"]:
path = base
# Ensure path ends with .md
if not path.endswith(".md"):
path = path + ".md"
return path

View File

@@ -3,6 +3,7 @@ from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
from application.parser.remote.reddit_loader import RedditPostsLoaderRemote
from application.parser.remote.github_loader import GitHubLoader
from application.parser.remote.s3_loader import S3Loader
class RemoteCreator:
@@ -22,6 +23,7 @@ class RemoteCreator:
"crawler": CrawlerLoader,
"reddit": RedditPostsLoaderRemote,
"github": GitHubLoader,
"s3": S3Loader,
}
@classmethod

View File

@@ -0,0 +1,427 @@
import json
import logging
import os
import tempfile
import mimetypes
from typing import List, Optional
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
try:
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
boto3 = None
logger = logging.getLogger(__name__)
class S3Loader(BaseRemote):
"""Load documents from an AWS S3 bucket."""
def __init__(self):
if boto3 is None:
raise ImportError(
"boto3 is required for S3Loader. Install it with: pip install boto3"
)
self.s3_client = None
def _normalize_endpoint_url(self, endpoint_url: str, bucket: str) -> tuple[str, str]:
"""
Normalize endpoint URL for S3-compatible services.
Detects common mistakes like using bucket-prefixed URLs and extracts
the correct endpoint and bucket name.
Args:
endpoint_url: The provided endpoint URL
bucket: The provided bucket name
Returns:
Tuple of (normalized_endpoint_url, bucket_name)
"""
import re
from urllib.parse import urlparse
if not endpoint_url:
return endpoint_url, bucket
parsed = urlparse(endpoint_url)
host = parsed.netloc or parsed.path
# Check for DigitalOcean Spaces bucket-prefixed URL pattern
# e.g., https://mybucket.nyc3.digitaloceanspaces.com
do_match = re.match(r"^([^.]+)\.([a-z0-9]+)\.digitaloceanspaces\.com$", host)
if do_match:
extracted_bucket = do_match.group(1)
region = do_match.group(2)
correct_endpoint = f"https://{region}.digitaloceanspaces.com"
logger.warning(
f"Detected bucket-prefixed DigitalOcean Spaces URL. "
f"Extracted bucket '{extracted_bucket}' from endpoint. "
f"Using endpoint: {correct_endpoint}"
)
# If bucket wasn't provided or differs, use extracted one
if not bucket or bucket != extracted_bucket:
logger.info(f"Using extracted bucket name: '{extracted_bucket}' (was: '{bucket}')")
bucket = extracted_bucket
return correct_endpoint, bucket
# Check for just "digitaloceanspaces.com" without region
if host == "digitaloceanspaces.com":
logger.error(
"Invalid DigitalOcean Spaces endpoint: missing region. "
"Use format: https://<region>.digitaloceanspaces.com (e.g., https://lon1.digitaloceanspaces.com)"
)
return endpoint_url, bucket
def _init_client(
self,
aws_access_key_id: str,
aws_secret_access_key: str,
region_name: str = "us-east-1",
endpoint_url: Optional[str] = None,
bucket: Optional[str] = None,
) -> Optional[str]:
"""
Initialize the S3 client with credentials.
Returns:
The potentially corrected bucket name if endpoint URL was normalized
"""
from botocore.config import Config
client_kwargs = {
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
"region_name": region_name,
}
logger.info(f"Initializing S3 client with region: {region_name}")
corrected_bucket = bucket
if endpoint_url:
# Normalize the endpoint URL and potentially extract bucket name
normalized_endpoint, corrected_bucket = self._normalize_endpoint_url(endpoint_url, bucket)
logger.info(f"Original endpoint URL: {endpoint_url}")
logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
logger.info(f"Bucket name: '{corrected_bucket}'")
client_kwargs["endpoint_url"] = normalized_endpoint
# Use path-style addressing for S3-compatible services
# (DigitalOcean Spaces, MinIO, etc.)
client_kwargs["config"] = Config(s3={"addressing_style": "path"})
else:
logger.info("Using default AWS S3 endpoint")
self.s3_client = boto3.client("s3", **client_kwargs)
logger.info("S3 client initialized successfully")
return corrected_bucket
def is_text_file(self, file_path: str) -> bool:
"""Determine if a file is a text file based on extension."""
text_extensions = {
".txt",
".md",
".markdown",
".rst",
".json",
".xml",
".yaml",
".yml",
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".html",
".css",
".scss",
".sass",
".less",
".sh",
".bash",
".zsh",
".fish",
".sql",
".r",
".m",
".mat",
".ini",
".cfg",
".conf",
".config",
".env",
".gitignore",
".dockerignore",
".editorconfig",
".log",
".csv",
".tsv",
}
file_lower = file_path.lower()
for ext in text_extensions:
if file_lower.endswith(ext):
return True
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type and (
mime_type.startswith("text")
or mime_type in ["application/json", "application/xml"]
):
return True
return False
def is_supported_document(self, file_path: str) -> bool:
"""Check if file is a supported document type for parsing."""
document_extensions = {
".pdf",
".docx",
".doc",
".xlsx",
".xls",
".pptx",
".ppt",
".epub",
".odt",
".rtf",
}
file_lower = file_path.lower()
for ext in document_extensions:
if file_lower.endswith(ext):
return True
return False
def list_objects(self, bucket: str, prefix: str = "") -> List[str]:
"""
List all objects in the bucket with the given prefix.
Args:
bucket: S3 bucket name
prefix: Optional path prefix to filter objects
Returns:
List of object keys
"""
objects = []
paginator = self.s3_client.get_paginator("list_objects_v2")
logger.info(f"Listing objects in bucket: '{bucket}' with prefix: '{prefix}'")
logger.debug(f"S3 client endpoint: {self.s3_client.meta.endpoint_url}")
try:
page_count = 0
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
page_count += 1
logger.debug(f"Processing page {page_count}, keys in response: {list(page.keys())}")
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
if not key.endswith("/"):
objects.append(key)
logger.debug(f"Found object: {key}")
else:
logger.info(f"Page {page_count} has no 'Contents' key - bucket may be empty or prefix not found")
logger.info(f"Found {len(objects)} objects in bucket '{bucket}'")
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
error_message = e.response.get("Error", {}).get("Message", "")
logger.error(f"ClientError listing objects - Code: {error_code}, Message: {error_message}")
logger.error(f"Full error response: {e.response}")
logger.error(f"Bucket: '{bucket}', Prefix: '{prefix}', Endpoint: {self.s3_client.meta.endpoint_url}")
if error_code == "NoSuchBucket":
raise Exception(f"S3 bucket '{bucket}' does not exist")
elif error_code == "AccessDenied":
raise Exception(
f"Access denied to S3 bucket '{bucket}'. Check your credentials and permissions."
)
elif error_code == "NoSuchKey":
# This is unusual for ListObjectsV2 - may indicate endpoint/bucket configuration issue
logger.error(
"NoSuchKey error on ListObjectsV2 - this may indicate the bucket name "
"is incorrect or the endpoint URL format is wrong. "
"For DigitalOcean Spaces, the endpoint should be like: "
"https://<region>.digitaloceanspaces.com and bucket should be just the space name."
)
raise Exception(
f"S3 error: {e}. For S3-compatible services, verify: "
f"1) Endpoint URL format (e.g., https://nyc3.digitaloceanspaces.com), "
f"2) Bucket name is just the space/bucket name without region prefix"
)
else:
raise Exception(f"S3 error: {e}")
except NoCredentialsError:
raise Exception(
"AWS credentials not found. Please provide valid credentials."
)
return objects
def get_object_content(self, bucket: str, key: str) -> Optional[str]:
"""
Get the content of an S3 object as text.
Args:
bucket: S3 bucket name
key: Object key
Returns:
File content as string, or None if file should be skipped
"""
if not self.is_text_file(key) and not self.is_supported_document(key):
return None
try:
response = self.s3_client.get_object(Bucket=bucket, Key=key)
content = response["Body"].read()
if self.is_text_file(key):
try:
decoded_content = content.decode("utf-8").strip()
if not decoded_content:
return None
return decoded_content
except UnicodeDecodeError:
return None
elif self.is_supported_document(key):
return self._process_document(content, key)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "NoSuchKey":
return None
elif error_code == "AccessDenied":
print(f"Access denied to object: {key}")
return None
else:
print(f"Error fetching object {key}: {e}")
return None
return None
def _process_document(self, content: bytes, key: str) -> Optional[str]:
"""
Process a document file (PDF, DOCX, etc.) and extract text.
Args:
content: File content as bytes
key: Object key (filename)
Returns:
Extracted text content
"""
ext = os.path.splitext(key)[1].lower()
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
try:
from application.parser.file.bulk import SimpleDirectoryReader
reader = SimpleDirectoryReader(input_files=[tmp_path])
documents = reader.load_data()
if documents:
return "\n\n".join(doc.text for doc in documents if doc.text)
return None
except Exception as e:
print(f"Error processing document {key}: {e}")
return None
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
def load_data(self, inputs) -> List[Document]:
"""
Load documents from an S3 bucket.
Args:
inputs: JSON string or dict containing:
- aws_access_key_id: AWS access key ID
- aws_secret_access_key: AWS secret access key
- bucket: S3 bucket name
- prefix: Optional path prefix to filter objects
- region: AWS region (default: us-east-1)
- endpoint_url: Custom S3 endpoint URL (for MinIO, R2, etc.)
Returns:
List of Document objects
"""
if isinstance(inputs, str):
try:
data = json.loads(inputs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON input: {e}")
else:
data = inputs
required_fields = ["aws_access_key_id", "aws_secret_access_key", "bucket"]
missing_fields = [field for field in required_fields if not data.get(field)]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
aws_access_key_id = data["aws_access_key_id"]
aws_secret_access_key = data["aws_secret_access_key"]
bucket = data["bucket"]
prefix = data.get("prefix", "")
region = data.get("region", "us-east-1")
endpoint_url = data.get("endpoint_url", "")
logger.info(f"Loading data from S3 - Bucket: '{bucket}', Prefix: '{prefix}', Region: '{region}'")
if endpoint_url:
logger.info(f"Custom endpoint URL provided: '{endpoint_url}'")
corrected_bucket = self._init_client(
aws_access_key_id, aws_secret_access_key, region, endpoint_url or None, bucket
)
# Use the corrected bucket name if endpoint URL normalization extracted one
if corrected_bucket and corrected_bucket != bucket:
logger.info(f"Using corrected bucket name: '{corrected_bucket}' (original: '{bucket}')")
bucket = corrected_bucket
objects = self.list_objects(bucket, prefix)
documents = []
for key in objects:
content = self.get_object_content(bucket, key)
if content is None:
continue
documents.append(
Document(
text=content,
doc_id=key,
extra_info={
"title": os.path.basename(key),
"source": f"s3://{bucket}/{key}",
"bucket": bucket,
"key": key,
},
)
)
logger.info(f"Loaded {len(documents)} documents from S3 bucket '{bucket}'")
return documents

View File

@@ -1,8 +1,9 @@
import logging
import requests
import re # Import regular expression library
import xml.etree.ElementTree as ET
import defusedxml.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -14,7 +15,14 @@ class SitemapLoader(BaseRemote):
sitemap_url = inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
url = sitemap_url[0]
sitemap_url = sitemap_url[0]
# Validate URL to prevent SSRF attacks
try:
sitemap_url = validate_url(sitemap_url)
except SSRFError as e:
logging.error(f"URL validation failed: {e}")
return []
urls = self._extract_urls(sitemap_url)
if not urls:
@@ -40,8 +48,13 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
response = requests.get(sitemap_url)
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
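As a standalone sketch, the validate-then-fetch pattern this change introduces looks roughly like the following (fetch_sitemap is a hypothetical helper; validate_url and SSRFError come from application.core.url_validation as imported above, and the exact conditions under which SSRFError is raised are assumed rather than shown here):

    import logging
    import requests
    from application.core.url_validation import validate_url, SSRFError

    def fetch_sitemap(url: str):
        try:
            safe_url = validate_url(url)  # presumably rejects private/loopback targets
            response = requests.get(safe_url, timeout=30)
            response.raise_for_status()
            return response.text
        except SSRFError as e:
            logging.error(f"URL validation failed: {e}")
            return None
        except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
            logging.error(f"Failed to fetch sitemap: {url}. Error: {e}")
            return None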

View File

@@ -1,7 +1,7 @@
"""Base schema for readers."""
from dataclasses import dataclass
from langchain.docstore.document import Document as LCDocument
from langchain_core.documents import Document as LCDocument
from application.parser.schema.schema import BaseDocument

View File

@@ -1,91 +1,98 @@
anthropic==0.49.0
boto3==1.38.18
beautifulsoup4==4.13.4
celery==5.4.0
cryptography==42.0.8
anthropic==0.75.0
boto3==1.42.17
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.0
cryptography==46.0.3
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==7.5.2
ebooklib==0.18
defusedxml==0.7.1
docling>=2.16.0
rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
duckduckgo-search==8.1.1
ebooklib==0.20
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
elevenlabs==2.17.0
Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.49.0
google-api-python-client==2.179.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.2
elevenlabs==2.27.0
Flask==3.1.2
faiss-cpu==1.13.2
fastmcp==2.14.1
flask-restx==1.3.2
google-genai==1.54.0
google-api-python-client==2.187.0
google-auth-httplib2==0.3.0
google-auth-oauthlib==1.2.3
gTTS==2.5.4
gunicorn==23.0.0
html2text==2025.4.15
javalang==0.13.0
jinja2==3.1.6
jiter==0.8.2
jiter==0.12.0
jmespath==1.0.1
joblib==1.4.2
joblib==1.5.3
jsonpatch==1.33
jsonpointer==3.0.0
kombu==5.4.2
langchain==0.3.20
langchain-community==0.3.19
langchain-core==0.3.59
langchain-openai==0.3.16
langchain-text-splitters==0.3.8
langsmith==0.3.42
lazy-object-proxy==1.10.0
lxml==5.3.1
markupsafe==3.0.2
marshmallow==3.26.1
kombu==5.6.1
langchain==1.2.0
langchain-community==0.4.1
langchain-core==1.2.5
langchain-openai==1.1.6
langchain-text-splitters==1.1.0
langsmith==0.5.1
lazy-object-proxy==1.12.0
lxml==6.0.2
markupsafe==3.0.3
marshmallow>=3.18.0,<5.0.0
mpmath==1.3.0
multidict==6.4.3
mypy-extensions==1.0.0
networkx==3.4.2
numpy==2.2.1
openai==1.78.1
openapi3-parser==1.1.21
orjson==3.10.14
multidict==6.7.0
mypy-extensions==1.1.0
networkx==3.6.1
numpy==2.4.0
openai==2.14.0
openapi3-parser==1.1.22
orjson==3.11.5
packaging==24.2
pandas==2.2.3
pandas==2.3.3
openpyxl==3.1.5
pathable==0.4.4
pillow==11.1.0
pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<3.0.0
prance==23.6.21.0
prompt-toolkit==3.0.51
protobuf==5.29.3
psycopg2-binary==2.9.10
prance==25.4.8.0
prompt-toolkit==3.0.52
protobuf==6.33.2
psycopg2-binary==2.9.11
py==1.11.0
pydantic
pydantic-core
pydantic-settings
pymongo==4.11.3
pypdf==5.5.0
pymongo==4.15.5
pypdf==6.5.0
python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.4.0
python-jose==3.5.0
python-pptx==1.0.2
redis==5.2.1
referencing>=0.28.0,<0.31.0
regex==2024.11.6
requests==2.32.3
redis==7.1.0
referencing>=0.28.0,<0.38.0
regex==2025.11.3
requests==2.32.5
retry==0.9.2
sentence-transformers==3.3.1
tiktoken==0.8.0
tokenizers==0.21.0
torch==2.7.0
sentence-transformers==5.2.0
tiktoken==0.12.0
tokenizers==0.22.1
torch==2.9.1
tqdm==4.67.1
transformers==4.51.3
typing-extensions==4.12.2
transformers==4.57.3
typing-extensions==4.15.0
typing-inspect==0.9.0
tzdata==2024.2
urllib3==2.3.0
tzdata==2025.3
urllib3==2.6.3
vine==5.1.0
wcwidth==0.2.13
werkzeug>=3.1.0,<3.1.2
yarl==1.20.0
markdownify==1.1.0
tldextract==5.1.3
websockets==14.1
wcwidth==0.2.14
werkzeug>=3.1.0
yarl==1.22.0
markdownify==1.2.2
tldextract==5.3.0
websockets==15.0.1

View File

@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
doc_token_limit=50000,
model_id="docsgpt-local",
user_api_key=None,
agent_id=None,
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
@@ -35,14 +36,15 @@ class ClassicRAG(BaseRetriever):
self.chunks = 2
else:
self.chunks = chunks
user_identifier = user_api_key if user_api_key else "default"
user_id = decoded_token.get("sub") if decoded_token else "default"
logging.info(
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
f"ClassicRAG initialized with chunks={self.chunks}, user_id={user_id}, "
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.model_id = model_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
agent_id=self.agent_id,
)
if "active_docs" in source and source["active_docs"] is not None:

View File

@@ -1,22 +1,104 @@
import sys
import logging
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
usage_collection = db["token_usage"]
def update_token_usage(decoded_token, user_api_key, token_usage):
def _serialize_for_token_count(value):
"""Normalize payloads into token-countable primitives."""
if isinstance(value, str):
# Avoid counting large binary payloads in data URLs as text tokens.
if value.startswith("data:") and ";base64," in value:
return ""
return value
if value is None:
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
if isinstance(value, dict):
serialized = {}
for key, raw in value.items():
key_lower = str(key).lower()
# Skip raw binary-like fields; keep textual tool-call fields.
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
continue
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
continue
serialized[key] = _serialize_for_token_count(raw)
return serialized
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
return _serialize_for_token_count(value.model_dump())
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
return _serialize_for_token_count(value.to_dict())
if hasattr(value, "__dict__"):
return _serialize_for_token_count(vars(value))
return str(value)
def _count_tokens(value):
serialized = _serialize_for_token_count(value)
if isinstance(serialized, str):
return num_tokens_from_string(serialized)
return num_tokens_from_object_or_list(serialized)
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
prompt_tokens = 0
for message in messages or []:
if not isinstance(message, dict):
prompt_tokens += _count_tokens(message)
continue
prompt_tokens += _count_tokens(message.get("content"))
# Include tool-related message fields for providers that use OpenAI-native format.
prompt_tokens += _count_tokens(message.get("tool_calls"))
prompt_tokens += _count_tokens(message.get("tool_call_id"))
prompt_tokens += _count_tokens(message.get("function_call"))
prompt_tokens += _count_tokens(message.get("function_response"))
# Count tool schema payload passed to the model.
prompt_tokens += _count_tokens(tools)
# Count structured-output/schema payloads when provided.
prompt_tokens += _count_tokens(kwargs.get("response_format"))
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
# Optional usage-only attachment context (not forwarded to provider).
prompt_tokens += _count_tokens(usage_attachments)
return prompt_tokens
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
if "pytest" in sys.modules:
return
if decoded_token:
user_id = decoded_token["sub"]
else:
user_id = None
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
normalized_agent_id = str(agent_id) if agent_id else None
if not user_id and not user_api_key and not normalized_agent_id:
logger.warning(
"Skipping token usage insert: missing user_id, api_key, and agent_id"
)
return
usage_data = {
"user_id": user_id,
"api_key": user_api_key,
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
"generated_tokens": token_usage["generated_tokens"],
"timestamp": datetime.now(),
}
if normalized_agent_id:
usage_data["agent_id"] = normalized_agent_id
usage_collection.insert_one(usage_data)
def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
call_usage["prompt_tokens"] += _count_prompt_tokens(
messages,
tools=tools,
usage_attachments=usage_attachments,
**kwargs,
)
result = func(self, model, messages, stream, tools, **kwargs)
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
call_usage["generated_tokens"] += _count_tokens(result)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return result
return wrapper
@@ -49,17 +138,28 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
call_usage["prompt_tokens"] += _count_prompt_tokens(
messages,
tools=tools,
usage_attachments=usage_attachments,
**kwargs,
)
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return wrapper
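A small illustration of how the counting helpers above treat multimodal payloads, using only functions defined in this file (actual totals depend on the tokenizer, so no concrete numbers are claimed):

    messages = [
        {"role": "user", "content": "Summarize the attached chart."},
        {"role": "user", "content": [
            {"type": "image_url", "url": "data:image/png;base64,iVBORw0KGgo..."},  # dropped by _serialize_for_token_count
            {"type": "text", "text": "What trend does it show?"},                  # counted as text
        ]},
    ]
    prompt_tokens = _count_prompt_tokens(messages, tools=None)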

View File

@@ -1,7 +1,11 @@
import base64
import hashlib
import io
import logging
import os
import re
import uuid
from typing import List
import tiktoken
from flask import jsonify, make_response
@@ -11,6 +15,8 @@ from application.core.model_utils import get_token_limit
from application.core.settings import settings
logger = logging.getLogger(__name__)
_encoding = None
@@ -77,11 +83,11 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
model_id: str = "gpt-4o", history_token_limit: int = 2000
model_id: str = "gpt-4o"
) -> int:
total_context = get_token_limit(model_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - history_token_limit - reserved
doc_budget = total_context - reserved
return max(doc_budget, 1000)
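# Worked example (illustrative numbers, not taken from this repo's settings):
# for a model with a 128,000-token context and RESERVED_TOKENS summing to 6,000,
# calculate_doc_token_budget("gpt-4o") now returns 128000 - 6000 = 122000;
# the max(doc_budget, 1000) floor only matters for very small context windows.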
@@ -215,6 +221,93 @@ def calculate_compression_threshold(
return threshold
def convert_pdf_to_images(
file_path: str,
storage=None,
max_pages: int = 20,
dpi: int = 150,
image_format: str = "PNG",
) -> List[dict]:
"""
Convert PDF pages to images for LLMs that support images but not PDFs.
This enables "synthetic PDF support" by converting each PDF page to an image
that can be sent to vision-capable LLMs like Claude.
Args:
file_path: Path to the PDF file (can be storage path)
storage: Optional storage instance for retrieving files
max_pages: Maximum number of pages to convert (default 20 to avoid context overflow)
dpi: Resolution for rendering (default 150 for balance of quality/size)
image_format: Output format (PNG recommended for quality)
Returns:
List of dicts with keys:
- 'data': base64-encoded image data
- 'mime_type': MIME type (e.g., 'image/png')
- 'page': Page number (1-indexed)
Raises:
ImportError: If pdf2image is not installed
FileNotFoundError: If file doesn't exist
Exception: If conversion fails
"""
try:
from pdf2image import convert_from_path, convert_from_bytes
except ImportError:
raise ImportError(
"pdf2image is required for PDF-to-image conversion. "
"Install it with: pip install pdf2image\n"
"Also ensure poppler-utils is installed on your system."
)
images_data = []
mime_type = f"image/{image_format.lower()}"
try:
# Get PDF content either from storage or direct file path
if storage and hasattr(storage, "get_file"):
with storage.get_file(file_path) as pdf_file:
pdf_bytes = pdf_file.read()
pil_images = convert_from_bytes(
pdf_bytes,
dpi=dpi,
fmt=image_format.lower(),
first_page=1,
last_page=max_pages,
)
else:
pil_images = convert_from_path(
file_path,
dpi=dpi,
fmt=image_format.lower(),
first_page=1,
last_page=max_pages,
)
for page_num, pil_image in enumerate(pil_images, start=1):
# Convert PIL image to base64
buffer = io.BytesIO()
pil_image.save(buffer, format=image_format)
buffer.seek(0)
base64_data = base64.b64encode(buffer.read()).decode("utf-8")
images_data.append({
"data": base64_data,
"mime_type": mime_type,
"page": page_num,
})
return images_data
except FileNotFoundError:
logger.error(f"PDF file not found: {file_path}")
raise
except Exception as e:
logger.error(f"Error converting PDF to images: {e}", exc_info=True)
raise
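# Example use (a sketch; assumes pdf2image and poppler are installed and a
# readable PDF path is available; storage-backed retrieval works the same way
# via the optional `storage` argument above):
#     pages = convert_pdf_to_images("manual.pdf", max_pages=5)
#     for page in pages:
#         print(page["page"], page["mime_type"], len(page["data"]))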
def clean_text_for_tts(text: str) -> str:
"""
Clean text for Text-to-Speech processing.

View File

@@ -2,41 +2,77 @@ import logging
import os
from abc import ABC, abstractmethod
import requests
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
try:
kwargs.setdefault("trust_remote_code", True)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs,
class RemoteEmbeddings:
"""
Wrapper for remote embeddings API (OpenAI-compatible).
Used when EMBEDDINGS_BASE_URL is configured.
Sends requests to {base_url}/v1/embeddings in OpenAI format.
"""
def __init__(self, api_url: str, model_name: str, api_key: str = None):
self.api_url = api_url.rstrip("/")
self.model_name = model_name
self.headers = {"Content-Type": "application/json"}
if api_key:
self.headers["Authorization"] = f"Bearer {api_key}"
self.dimension = 768
def _embed(self, inputs):
"""Send embedding request to remote API in OpenAI-compatible format."""
payload = {"input": inputs}
if self.model_name:
payload["model"] = self.model_name
url = f"{self.api_url}/v1/embeddings"
response = requests.post(url, headers=self.headers, json=payload, timeout=180)
response.raise_for_status()
result = response.json()
# Handle OpenAI-compatible response format
if isinstance(result, dict):
if "error" in result:
raise ValueError(f"Remote embeddings API error: {result['error']}")
if "data" in result:
# Sort by index to ensure correct order
data = sorted(result["data"], key=lambda x: x.get("index", 0))
return [item["embedding"] for item in data]
raise ValueError(
f"Unexpected response format from remote embeddings API: {result}"
)
if self.model is None or self.model._first_module() is None:
raise ValueError(
f"SentenceTransformer model failed to load properly for: {model_name}"
)
self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
except Exception as e:
logging.error(
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
exc_info=True,
else:
raise ValueError(
f"Unexpected response format from remote embeddings API: {result}"
)
raise
def embed_query(self, query: str):
return self.model.encode(query).tolist()
"""Embed a single query string."""
embeddings_list = self._embed(query)
if (
isinstance(embeddings_list, list)
and len(embeddings_list) == 1
and isinstance(embeddings_list[0], list)
):
if self.dimension is None:
self.dimension = len(embeddings_list[0])
return embeddings_list[0]
raise ValueError(
f"Unexpected result structure after embedding query: {embeddings_list}"
)
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
"""Embed a list of documents."""
if not documents:
return []
embeddings_list = self._embed(documents)
if self.dimension is None and embeddings_list:
self.dimension = len(embeddings_list[0])
return embeddings_list
def __call__(self, text):
if isinstance(text, str):
@@ -47,6 +83,13 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
def _get_embeddings_wrapper():
"""Lazy import of EmbeddingsWrapper to avoid loading SentenceTransformer when using remote embeddings."""
from application.vectorstore.embeddings_local import EmbeddingsWrapper
return EmbeddingsWrapper
class EmbeddingsSingleton:
_instances = {}
@@ -60,8 +103,13 @@ class EmbeddingsSingleton:
@staticmethod
def _create_instance(embeddings_name, *args, **kwargs):
if embeddings_name == "openai_text-embedding-ada-002":
return OpenAIEmbeddings(*args, **kwargs)
# Lazy import EmbeddingsWrapper only when needed (avoids loading SentenceTransformer)
EmbeddingsWrapper = _get_embeddings_wrapper()
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
@@ -121,6 +169,20 @@ class BaseVectorStore(ABC):
)
def _get_embeddings(self, embeddings_name, embeddings_key=None):
# Check for remote embeddings first
if settings.EMBEDDINGS_BASE_URL:
logging.info(
f"Using remote embeddings API at: {settings.EMBEDDINGS_BASE_URL}"
)
cache_key = f"remote_{settings.EMBEDDINGS_BASE_URL}_{embeddings_name}"
if cache_key not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[cache_key] = RemoteEmbeddings(
api_url=settings.EMBEDDINGS_BASE_URL,
model_name=embeddings_name,
api_key=embeddings_key,
)
return EmbeddingsSingleton._instances[cache_key]
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"

View File

@@ -0,0 +1,48 @@
"""
Local embeddings using SentenceTransformer.
This module is only imported when EMBEDDINGS_BASE_URL is not set,
to avoid loading SentenceTransformer into memory when using remote embeddings.
"""
import logging
from sentence_transformers import SentenceTransformer
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
try:
kwargs.setdefault("trust_remote_code", True)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs,
)
if self.model is None or self.model._first_module() is None:
raise ValueError(
f"SentenceTransformer model failed to load properly for: {model_name}"
)
self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
except Exception as e:
logging.error(
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
exc_info=True,
)
raise
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
elif isinstance(text, list):
return self.embed_documents(text)
else:
raise ValueError("Input must be a string or a list of strings")

View File

@@ -11,6 +11,7 @@ class PGVectorStore(BaseVectorStore):
source_id: str = "",
embeddings_key: str = "embeddings",
table_name: str = "documents",
decoded_token: Optional[str] = None,
vector_column: str = "embedding",
text_column: str = "text",
metadata_column: str = "metadata",
@@ -68,8 +69,7 @@ class PGVectorStore(BaseVectorStore):
# Enable pgvector extension
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Get embedding dimension
embedding_dim = getattr(self._embedding, 'dimension', 1536) # Default to OpenAI dimension
embedding_dim = getattr(self._embedding, 'dimension', 768)
# Create table with vector column
create_table_query = f"""
@@ -152,7 +152,7 @@ class PGVectorStore(BaseVectorStore):
"""Add texts with their embeddings to the vector store"""
if not texts:
return []
embeddings = self._embedding.embed_documents(texts)
metadatas = metadatas or [{}] * len(texts)
@@ -239,15 +239,13 @@ class PGVectorStore(BaseVectorStore):
def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
"""Add a single chunk to the vector store"""
metadata = metadata or {}
# Create a copy to avoid modifying the original metadata
final_metadata = metadata.copy()
# Ensure the source_id is in the metadata so the chunk can be found by filters
final_metadata["source_id"] = self._source_id
embeddings = self._embedding.embed_documents([text])
if not embeddings:
raise ValueError("Could not generate embedding for chunk")

View File

@@ -25,7 +25,7 @@ from application.core.settings import settings
from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.file.bulk import SimpleDirectoryReader
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
from application.parser.remote.remote_creator import RemoteCreator
from application.parser.schema.base import Document
from application.retriever.retriever_creator import RetrieverCreator
@@ -52,6 +52,41 @@ def metadata_from_filename(title):
return {"title": title}
def _normalize_file_name_map(file_name_map):
if not file_name_map:
return {}
if isinstance(file_name_map, str):
try:
file_name_map = json.loads(file_name_map)
except Exception:
return {}
return file_name_map if isinstance(file_name_map, dict) else {}
def _get_display_name(file_name_map, rel_path):
if not file_name_map or not rel_path:
return None
if rel_path in file_name_map:
return file_name_map[rel_path]
base_name = os.path.basename(rel_path)
return file_name_map.get(base_name)
def _apply_display_names_to_structure(structure, file_name_map, prefix=""):
if not isinstance(structure, dict) or not file_name_map:
return structure
for name, node in structure.items():
if isinstance(node, dict) and "type" in node and "size_bytes" in node:
rel_path = f"{prefix}/{name}" if prefix else name
display_name = _get_display_name(file_name_map, rel_path)
if display_name:
node["display_name"] = display_name
elif isinstance(node, dict):
next_prefix = f"{prefix}/{name}" if prefix else name
_apply_display_names_to_structure(node, file_name_map, next_prefix)
return structure
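# Example (illustrative values): given
#     file_name_map = {"docs/report_1.pdf": "Q3 Report.pdf"}
# _get_display_name(file_name_map, "docs/report_1.pdf") returns "Q3 Report.pdf",
# while _get_display_name(file_name_map, "other.pdf") returns None because
# neither the relative path nor its basename is mapped.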
# Define a function to generate a random string of a given length.
@@ -63,10 +98,111 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Zip extraction security limits
MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB max uncompressed size
MAX_FILE_COUNT = 10000 # Maximum number of files to extract
MAX_COMPRESSION_RATIO = 100 # Maximum compression ratio (to detect zip bombs)
class ZipExtractionError(Exception):
"""Raised when zip extraction fails due to security constraints."""
pass
def _is_path_safe(base_path: str, target_path: str) -> bool:
"""
Check if target_path is safely within base_path (prevents zip slip attacks).
Args:
base_path: The base directory where extraction should occur.
target_path: The full path where a file would be extracted.
Returns:
True if the path is safe, False otherwise.
"""
# Resolve to absolute paths and check containment
base_resolved = os.path.realpath(base_path)
target_resolved = os.path.realpath(target_path)
return target_resolved.startswith(base_resolved + os.sep) or target_resolved == base_resolved
def _validate_zip_safety(zip_path: str, extract_to: str) -> None:
"""
Validate a zip file for security issues before extraction.
Checks for:
- Zip bombs (excessive compression ratio or uncompressed size)
- Too many files
- Path traversal attacks (zip slip)
Args:
zip_path: Path to the zip file.
extract_to: Destination directory.
Raises:
ZipExtractionError: If the zip file fails security validation.
"""
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# Get compressed size
compressed_size = os.path.getsize(zip_path)
# Calculate total uncompressed size and file count
total_uncompressed = 0
file_count = 0
for info in zip_ref.infolist():
file_count += 1
# Check file count limit
if file_count > MAX_FILE_COUNT:
raise ZipExtractionError(
f"Zip file contains too many files (>{MAX_FILE_COUNT}). "
"This may be a zip bomb attack."
)
# Accumulate uncompressed size
total_uncompressed += info.file_size
# Check total uncompressed size
if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
raise ZipExtractionError(
f"Zip file uncompressed size exceeds limit "
f"({total_uncompressed / (1024*1024):.1f} MB > "
f"{MAX_UNCOMPRESSED_SIZE / (1024*1024):.1f} MB). "
"This may be a zip bomb attack."
)
# Check for path traversal (zip slip)
target_path = os.path.join(extract_to, info.filename)
if not _is_path_safe(extract_to, target_path):
raise ZipExtractionError(
f"Zip file contains path traversal attempt: {info.filename}"
)
# Check compression ratio (only if compressed size is meaningful)
if compressed_size > 0 and total_uncompressed > 0:
compression_ratio = total_uncompressed / compressed_size
if compression_ratio > MAX_COMPRESSION_RATIO:
raise ZipExtractionError(
f"Zip file has suspicious compression ratio ({compression_ratio:.1f}:1 > "
f"{MAX_COMPRESSION_RATIO}:1). This may be a zip bomb attack."
)
except zipfile.BadZipFile as e:
raise ZipExtractionError(f"Invalid or corrupted zip file: {e}")
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
Recursively extract zip files with a limit on recursion depth.
Recursively extract zip files with security protections.
Security measures:
- Limits recursion depth to prevent infinite loops
- Validates uncompressed size to prevent zip bombs
- Limits number of files to prevent resource exhaustion
- Checks compression ratio to detect zip bombs
- Validates paths to prevent zip slip attacks
Args:
zip_path (str): Path to the zip file to be extracted.
@@ -77,20 +213,33 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
if current_depth > max_depth:
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
try:
# Validate zip file safety before extraction
_validate_zip_safety(zip_path, extract_to)
# Safe to extract
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
except ZipExtractionError as e:
logging.error(f"Zip security validation failed for {zip_path}: {e}")
# Remove the potentially malicious zip file
try:
os.remove(zip_path)
except OSError:
pass
return
except Exception as e:
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
return
# Check for nested zip files and extract them
# Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
@@ -173,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
chunks = int(agent_config.get("chunks", 2))
prompt_id = agent_config.get("prompt_id", "default")
user_api_key = agent_config["key"]
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
agent_type = agent_config.get("agent_type", "classic")
decoded_token = {"sub": agent_config.get("user")}
json_schema = agent_config.get("json_schema")
@@ -190,9 +340,8 @@ def run_agent_logic(agent_config, input_data):
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Calculate proper doc_token_limit based on model's context window
history_token_limit = 2000 # Default for webhooks
doc_token_limit = calculate_doc_token_budget(
model_id=model_id, history_token_limit=history_token_limit
model_id=model_id
)
retriever = RetrieverCreator.create_retriever(
@@ -204,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
doc_token_limit=doc_token_limit,
model_id=model_id,
user_api_key=user_api_key,
agent_id=agent_id,
decoded_token=decoded_token,
)
@@ -222,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
llm_name=provider or settings.LLM_PROVIDER,
model_id=model_id,
api_key=system_api_key,
agent_id=agent_id,
user_api_key=user_api_key,
prompt=prompt,
chat_history=[],
@@ -262,7 +413,15 @@ def run_agent_logic(agent_config, input_data):
def ingest_worker(
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
self,
directory,
formats,
job_name,
file_path,
filename,
user,
retriever="classic",
file_name_map=None,
):
"""
Ingest and process documents.
@@ -276,6 +435,7 @@ def ingest_worker(
filename (str): Original unsanitized filename provided by the user.
user (str): Identifier for the user initiating the ingestion (original, unsanitized).
retriever (str): Type of retriever to use for processing the documents.
file_name_map (dict|str|None): Optional mapping of safe relative paths to original filenames.
Returns:
dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
@@ -355,6 +515,22 @@ def ingest_worker(
directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}")
file_name_map = _normalize_file_name_map(file_name_map)
if file_name_map:
for doc in raw_docs:
extra_info = getattr(doc, "extra_info", None)
if not isinstance(extra_info, dict):
continue
rel_path = extra_info.get("source") or extra_info.get("file_path")
display_name = _get_display_name(file_name_map, rel_path)
if display_name:
display_name = str(display_name)
extra_info["filename"] = display_name
extra_info["file_name"] = display_name
extra_info["title"] = display_name
directory_structure = _apply_display_names_to_structure(
directory_structure, file_name_map
)
chunker = Chunker(
chunking_strategy="classic_chunk",
@@ -391,6 +567,8 @@ def ingest_worker(
"file_path": file_path,
"directory_structure": json.dumps(directory_structure),
}
if file_name_map:
file_data["file_name_map"] = json.dumps(file_name_map)
upload_index(vector_store_path, file_data)
except Exception as e:
@@ -434,6 +612,7 @@ def reingest_source_worker(self, source_id, user):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
file_name_map = _normalize_file_name_map(source.get("file_name_map"))
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
@@ -668,6 +847,14 @@ def reingest_source_worker(self, source_id, user):
)
except Exception:
pass
display_name = _get_display_name(
file_name_map, meta.get("source")
)
if display_name:
display_name = str(display_name)
meta["filename"] = display_name
meta["file_name"] = display_name
meta["title"] = display_name
vector_store.add_chunk(d.text, metadata=meta)
added += 1
@@ -682,6 +869,9 @@ def reingest_source_worker(self, source_id, user):
# 3) Update source directory structure timestamp
try:
total_tokens = sum(reader.file_token_counts.values())
directory_structure = _apply_display_names_to_structure(
directory_structure, file_name_map
)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
@@ -755,6 +945,76 @@ def remote_worker(
tokens = count_tokens_docs(docs)
logging.info("Total tokens calculated: %d", tokens)
# Build directory structure from loaded documents
# Format matches local file uploads: nested structure with type, size_bytes, token_count
directory_structure = {}
for doc in raw_docs:
# Get the file path from extra_info
# For crawlers: file_path is a virtual path like "guides/setup.md"
# For other remotes: use key or title as fallback
file_path = ""
if doc.extra_info:
file_path = (
doc.extra_info.get("file_path", "")
or doc.extra_info.get("key", "")
or doc.extra_info.get("title", "")
)
if not file_path:
file_path = doc.doc_id or ""
if file_path:
# Calculate token count
token_count = num_tokens_from_string(doc.text) if doc.text else 0
# Estimate size in bytes from text content
size_bytes = len(doc.text.encode("utf-8")) if doc.text else 0
# Guess mime type from extension
file_name = (
file_path.split("/")[-1] if "/" in file_path else file_path
)
ext = os.path.splitext(file_name)[1].lower()
mime_types = {
".txt": "text/plain",
".md": "text/markdown",
".pdf": "application/pdf",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".doc": "application/msword",
".html": "text/html",
".json": "application/json",
".csv": "text/csv",
".xml": "application/xml",
".py": "text/x-python",
".js": "text/javascript",
".ts": "text/typescript",
".jsx": "text/jsx",
".tsx": "text/tsx",
}
file_type = mime_types.get(ext, "application/octet-stream")
# Build nested directory structure from path
# e.g., "guides/setup.md" -> {"guides": {"setup.md": {...}}}
path_parts = file_path.split("/")
current_level = directory_structure
for i, part in enumerate(path_parts):
if i == len(path_parts) - 1:
# Last part is the file
current_level[part] = {
"type": file_type,
"size_bytes": size_bytes,
"token_count": token_count,
}
else:
# Intermediate parts are directories
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
logging.info(
f"Built directory structure with {len(directory_structure)} files: "
f"{list(directory_structure.keys())}"
)
if operation_mode == "upload":
id = ObjectId()
embed_and_store_documents(docs, full_path, id, self)
@@ -766,6 +1026,10 @@ def remote_worker(
embed_and_store_documents(docs, full_path, id, self)
self.update_state(state="PROGRESS", meta={"current": 100})
# Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.)
remote_data_serialized = (
json.dumps(source_data) if isinstance(source_data, dict) else source_data
)
file_data = {
"name": name_job,
"user": user,
@@ -773,8 +1037,9 @@ def remote_worker(
"retriever": retriever,
"id": str(id),
"type": loader,
"remote_data": source_data,
"remote_data": remote_data_serialized,
"sync_frequency": sync_frequency,
"directory_structure": json.dumps(directory_structure),
}
if operation_mode == "sync":
@@ -872,10 +1137,16 @@ def attachment_worker(self, file_info, user):
state="PROGRESS", meta={"current": 30, "status": "Processing content"}
)
file_extractor = get_default_file_extractor(
ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
)
content = storage.process_file(
relative_path,
lambda local_path, **kwargs: SimpleDirectoryReader(
input_files=[local_path], exclude_hidden=True, errors="ignore"
input_files=[local_path],
exclude_hidden=True,
errors="ignore",
file_extractor=file_extractor,
)
.load_data()[0]
.text,
@@ -961,13 +1232,14 @@ def agent_webhook_worker(self, agent_id, payload):
result = run_agent_logic(agent_config, input_data)
except Exception as e:
logging.error(f"Error running agent logic: {e}", exc_info=True)
return {"status": "error", "error": str(e)}
finally:
self.update_state(state="PROGRESS", meta={"current": 100})
return {"status": "error"}
else:
logging.info(
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
)
return {"status": "success", "result": result}
finally:
self.update_state(state="PROGRESS", meta={"current": 100})
def ingest_connector(

View File

@@ -11,17 +11,13 @@ services:
backend:
build: ../application
env_file:
- ../.env
environment:
- API_KEY=$OPENAI_API_KEY
- EMBEDDINGS_KEY=$OPENAI_API_KEY
# Override URLs to use docker service names
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- OPENAI_API_KEY=$OPENAI_API_KEY
- OPENAI_API_BASE=$OPENAI_API_BASE
- OPENAI_API_VERSION=$OPENAI_API_VERSION
- AZURE_DEPLOYMENT_NAME=$AZURE_DEPLOYMENT_NAME
- AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$AZURE_EMBEDDINGS_DEPLOYMENT_NAME
ports:
- "7091:7091"
volumes:
@@ -35,18 +31,14 @@ services:
worker:
build: ../application
command: celery -A application.app.celery worker -l INFO
env_file:
- ../.env
environment:
- API_KEY=$OPENAI_API_KEY
- EMBEDDINGS_KEY=$OPENAI_API_KEY
# Override URLs to use docker service names
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- API_URL=http://backend:7091
- OPENAI_API_KEY=$OPENAI_API_KEY
- OPENAI_API_BASE=$OPENAI_API_BASE
- OPENAI_API_VERSION=$OPENAI_API_VERSION
- AZURE_DEPLOYMENT_NAME=$AZURE_DEPLOYMENT_NAME
- AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$AZURE_EMBEDDINGS_DEPLOYMENT_NAME
depends_on:
- redis
- mongo

View File

@@ -5,8 +5,8 @@ services:
image: arc53/docsgpt-fe:develop
environment:
- VITE_API_HOST=http://localhost:7091
- VITE_API_STREAMING=$VITE_API_STREAMING
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
- VITE_API_STREAMING=${VITE_API_STREAMING:-true}
- VITE_GOOGLE_CLIENT_ID=${VITE_GOOGLE_CLIENT_ID:-}
ports:
- "5173:5173"
depends_on:
@@ -16,16 +16,13 @@ services:
backend:
user: root
image: arc53/docsgpt:develop
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- CACHE_REDIS_URL=redis://redis:6379/2
- OPENAI_BASE_URL=$OPENAI_BASE_URL
ports:
- "7091:7091"
volumes:
@@ -41,11 +38,9 @@ services:
user: root
image: arc53/docsgpt:develop
command: celery -A application.app.celery worker -l INFO -B
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt

View File

@@ -16,17 +16,14 @@ services:
backend:
user: root
build: ../application
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
# Override URLs to use docker service names
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- CACHE_REDIS_URL=redis://redis:6379/2
- OPENAI_BASE_URL=$OPENAI_BASE_URL
- INTERNAL_KEY=$INTERNAL_KEY
ports:
- "7091:7091"
volumes:
@@ -41,17 +38,15 @@ services:
user: root
build: ../application
command: celery -A application.app.celery worker -l INFO -B
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
# Override URLs to use docker service names
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- API_URL=http://backend:7091
- CACHE_REDIS_URL=redis://redis:6379/2
- INTERNAL_KEY=$INTERNAL_KEY
volumes:
- ../application/indexes:/app/indexes
- ../application/inputs:/app/inputs

View File

@@ -0,0 +1,25 @@
import { generateStaticParamsFor, importPage } from 'nextra/pages';
import { useMDXComponents } from '../../mdx-components';
export const generateStaticParams = generateStaticParamsFor('mdxPath');
export async function generateMetadata(props) {
const params = await props.params;
const { metadata } = await importPage(params?.mdxPath);
return metadata;
}
const Wrapper = useMDXComponents().wrapper;
export default async function Page(props) {
const params = await props.params;
const result = await importPage(params?.mdxPath);
const { default: MDXContent, metadata, sourceCode, toc } = result;
return (
<Wrapper metadata={metadata} sourceCode={sourceCode} toc={toc}>
<MDXContent {...props} params={params} />
</Wrapper>
);
}

docs/app/layout.jsx Normal file
View File

@@ -0,0 +1,86 @@
import Image from 'next/image';
import { Analytics } from '@vercel/analytics/react';
import { Banner, Head } from 'nextra/components';
import { getPageMap } from 'nextra/page-map';
import { Footer, Layout, Navbar } from 'nextra-theme-docs';
import 'nextra-theme-docs/style.css';
import CuteLogo from '../public/cute-docsgpt.png';
import themeConfig from '../theme.config';
const github = 'https://github.com/arc53/DocsGPT';
export const metadata = {
title: {
default: 'DocsGPT Documentation',
template: '%s - DocsGPT Documentation',
},
description:
'Use DocsGPT to chat with your data. DocsGPT is a GPT-powered chatbot that can answer questions about your data.',
};
const navbar = (
<Navbar
logo={
<div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
<Image src={CuteLogo} alt="DocsGPT logo" width={28} height={28} />
<span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
</div>
}
projectLink={github}
chatLink="https://discord.com/invite/n5BX8dh8rU"
/>
);
const footer = (
<Footer>
<span>MIT {new Date().getFullYear()} © </span>
<a href="https://www.docsgpt.cloud/" target="_blank" rel="noreferrer">
DocsGPT
</a>
{' | '}
<a href="https://github.com/arc53/DocsGPT" target="_blank" rel="noreferrer">
GitHub
</a>
{' | '}
<a href="https://blog.docsgpt.cloud/" target="_blank" rel="noreferrer">
Blog
</a>
</Footer>
);
export default async function RootLayout({ children }) {
return (
<html lang="en" dir="ltr" suppressHydrationWarning>
<Head>
<link
rel="apple-touch-icon"
sizes="180x180"
href="/favicons/apple-touch-icon.png"
/>
<link rel="icon" type="image/png" sizes="32x32" href="/favicons/favicon-32x32.png" />
<link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
<link rel="manifest" href="/favicons/site.webmanifest" />
<meta httpEquiv="Content-Language" content="en" />
</Head>
<body>
<Layout
banner={
<Banner storageKey="docs-launch">
<div className="flex justify-center items-center gap-2">
Welcome to the new DocsGPT docs!
</div>
</Banner>
}
navbar={navbar}
footer={footer}
pageMap={await getPageMap()}
{...themeConfig}
>
{children}
</Layout>
<Analytics />
</body>
</html>
);
}

View File

@@ -1,3 +1,5 @@
'use client';
import Image from 'next/image';
const iconMap = {
@@ -117,4 +119,4 @@ export function DeploymentCards({ items }) {
`}</style>
</>
);
}
}

View File

@@ -1,3 +1,5 @@
'use client';
import Image from 'next/image';
const iconMap = {
@@ -114,4 +116,4 @@ export function ToolCards({ items }) {
`}</style>
</>
);
}
}

Some files were not shown because too many files have changed in this diff Show More