Compare commits

..

1 Commits

Author SHA1 Message Date
Pavel
5a891647bf parser functions change
token_func proposed change to chunking. open_ai_func proposed change to embedding_pipeline. Late chunking first  implementation requires further testing.
2024-11-20 21:40:57 +04:00
507 changed files with 15436 additions and 55166 deletions

View File

@@ -1,15 +0,0 @@
FROM python:3.12-bookworm
# Install Node.js 20.x
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs \
&& rm -rf /var/lib/apt/lists/*
# Install global npm packages
RUN npm install -g husky vite
# Create and activate Python virtual environment
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /workspace

View File

@@ -1,49 +0,0 @@
# Welcome to DocsGPT Devcontainer
Welcome to the DocsGPT development environment! This guide will help you get started quickly.
## Starting Services
To run DocsGPT, you need to start three main services: Flask (backend), Celery (task queue), and Vite (frontend). Here are the commands to start each service within the devcontainer:
### Vite (Frontend)
```bash
cd frontend
npm run dev -- --host
```
### Flask (Backend)
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
### Celery (Task Queue)
```bash
celery -A application.app.celery worker -l INFO
```
## Github Codespaces Instructions
### 1. Make Ports Public:
Go to the "Ports" panel in Codespaces (usually located at the bottom of the VS Code window).
For both port 5173 and 7091, right-click on the port and select "Make Public".
![CleanShot 2025-02-12 at 09 46 14@2x](https://github.com/user-attachments/assets/00a34b16-a7ef-47af-9648-87a7e3008475)
### 2. Update VITE_API_HOST:
After making port 7091 public, copy the public URL provided by Codespaces for port 7091.
Open the file frontend/.env.development.
Find the line VITE_API_HOST=http://localhost:7091.
Replace http://localhost:7091 with the public URL you copied from Codespaces.
![CleanShot 2025-02-12 at 09 46 56@2x](https://github.com/user-attachments/assets/c472242f-1079-4cd8-bc0b-2d78db22b94c)

View File

@@ -1,24 +0,0 @@
{
"name": "DocsGPT Dev Container",
"dockerComposeFile": ["docker-compose-dev.yaml", "docker-compose.override.yaml"],
"service": "dev",
"workspaceFolder": "/workspace",
"postCreateCommand": ".devcontainer/post-create-command.sh",
"forwardPorts": [7091, 5173, 6379, 27017],
"customizations": {
"vscode": {
"extensions": [
"ms-python.python",
"ms-toolsai.jupyter",
"esbenp.prettier-vscode",
"dbaeumer.vscode-eslint"
]
},
"codespaces": {
"openFiles": [
".devcontainer/devc-welcome.md",
"CONTRIBUTING.md"
]
}
}
}

View File

@@ -1,40 +0,0 @@
version: '3.8'
services:
dev:
build:
context: .
dockerfile: Dockerfile
volumes:
- ../:/workspace:cached
command: sleep infinity
depends_on:
redis:
condition: service_healthy
mongo:
condition: service_healthy
environment:
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- CACHE_REDIS_URL=redis://redis:6379/2
networks:
- default
redis:
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 30s
retries: 5
mongo:
healthcheck:
test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"]
interval: 5s
timeout: 30s
retries: 5
networks:
default:
name: docsgpt-dev-network

View File

@@ -1,32 +0,0 @@
#!/bin/bash
set -e # Exit immediately if a command exits with a non-zero status
if [ ! -f frontend/.env.development ]; then
cp -n .env-template frontend/.env.development || true # Assuming .env-template is in the root
fi
# Determine VITE_API_HOST based on environment
if [ -n "$CODESPACES" ]; then
# Running in Codespaces
CODESPACE_NAME=$(echo "$CODESPACES" | cut -d'-' -f1) # Extract codespace name
PUBLIC_API_HOST="https://${CODESPACE_NAME}-7091.${GITHUB_CODESPACES_PORT_FORWARDING_DOMAIN}"
echo "Setting VITE_API_HOST for Codespaces: $PUBLIC_API_HOST in frontend/.env.development"
sed -i "s|VITE_API_HOST=.*|VITE_API_HOST=$PUBLIC_API_HOST|" frontend/.env.development
else
# Not running in Codespaces (local devcontainer)
DEFAULT_API_HOST="http://localhost:7091"
echo "Setting VITE_API_HOST for local dev: $DEFAULT_API_HOST in frontend/.env.development"
sed -i "s|VITE_API_HOST=.*|VITE_API_HOST=$DEFAULT_API_HOST|" frontend/.env.development
fi
mkdir -p model
if [ ! -d model/all-mpnet-base-v2 ]; then
wget -q https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip -O model/mpnet-base-v2.zip
unzip -q model/mpnet-base-v2.zip -d model
rm model/mpnet-base-v2.zip
fi
pip install -r application/requirements.txt
cd frontend
npm install --include=dev

2
.gitattributes vendored
View File

@@ -1,2 +0,0 @@
# Auto detect text files and perform LF normalization
* text=auto

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: arc53

View File

@@ -8,12 +8,12 @@ updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/application" # Location of package manifests
schedule:
interval: "daily"
interval: "weekly"
- package-ecosystem: "npm" # See documentation for possible values
directory: "/frontend" # Location of package manifests
schedule:
interval: "daily"
interval: "weekly"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
interval: "weekly"

View File

@@ -1,40 +0,0 @@
name: Bandit Security Scan
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
bandit_scan:
if: ${{ github.repository == 'arc53/DocsGPT' }}
runs-on: ubuntu-latest
permissions:
security-events: write
actions: read
contents: read
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install bandit # Bandit is needed for this action
if [ -f application/requirements.txt ]; then pip install -r application/requirements.txt; fi
- name: Run Bandit scan
uses: PyCQA/bandit-action@v1
with:
severity: medium
confidence: medium
targets: application/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -5,33 +5,20 @@ on:
types: [published]
jobs:
build:
deploy:
if: github.repository == 'arc53/DocsGPT'
strategy:
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
suffix: amd64
- platform: linux/arm64
runner: ubuntu-24.04-arm
suffix: arm64
runs-on: ${{ matrix.runner }}
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
@@ -46,67 +33,15 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push platform-specific images
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v6
with:
file: './application/Dockerfile'
platforms: ${{ matrix.platform }}
platforms: linux/amd64
context: ./application
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}-${{ matrix.suffix }}
ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}-${{ matrix.suffix }}
provenance: false
sbom: false
${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }},${{ secrets.DOCKER_USERNAME }}/docsgpt:latest
ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }},ghcr.io/${{ github.repository_owner }}/docsgpt:latest
cache-from: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/docsgpt:latest
cache-to: type=inline
manifest:
if: github.repository == 'arc53/DocsGPT'
needs: build
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Create and push manifest for DockerHub
run: |
set -e
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }} \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt:latest \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:${{ github.event.release.tag_name }}-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt:latest
- name: Create and push manifest for ghcr.io
run: |
set -e
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }} \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt:latest \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:${{ github.event.release.tag_name }}-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt:latest

View File

@@ -5,33 +5,20 @@ on:
types: [published]
jobs:
build:
deploy:
if: github.repository == 'arc53/DocsGPT'
strategy:
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
suffix: amd64
- platform: linux/arm64
runner: ubuntu-24.04-arm
suffix: arm64
runs-on: ${{ matrix.runner }}
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
@@ -46,67 +33,16 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push platform-specific images
# Runs a single command using the runners shell
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v6
with:
file: './frontend/Dockerfile'
platforms: ${{ matrix.platform }}
platforms: linux/amd64, linux/arm64
context: ./frontend
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}-${{ matrix.suffix }}
ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}-${{ matrix.suffix }}
provenance: false
sbom: false
${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }},${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:latest
ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }},ghcr.io/${{ github.repository_owner }}/docsgpt-fe:latest
cache-from: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:latest
cache-to: type=inline
manifest:
if: github.repository == 'arc53/DocsGPT'
needs: build
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Create and push manifest for DockerHub
run: |
set -e
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }} \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:latest \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:${{ github.event.release.tag_name }}-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:latest
- name: Create and push manifest for ghcr.io
run: |
set -e
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }} \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt-fe:latest \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:${{ github.event.release.tag_name }}-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt-fe:latest

View File

@@ -1,4 +1,4 @@
name: Build and push multi-arch DocsGPT Docker image
name: Build and push DocsGPT Docker image for development
on:
workflow_dispatch:
@@ -7,36 +7,27 @@ on:
- main
jobs:
build:
deploy:
if: github.repository == 'arc53/DocsGPT'
strategy:
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
suffix: amd64
- platform: linux/arm64
runner: ubuntu-24.04-arm
suffix: arm64
runs-on: ${{ matrix.runner }}
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v3
with:
@@ -44,57 +35,15 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push platform-specific images
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v6
with:
file: './application/Dockerfile'
platforms: ${{ matrix.platform }}
platforms: linux/amd64
context: ./application
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt:develop-${{ matrix.suffix }}
ghcr.io/${{ github.repository_owner }}/docsgpt:develop-${{ matrix.suffix }}
provenance: false
sbom: false
${{ secrets.DOCKER_USERNAME }}/docsgpt:develop
ghcr.io/${{ github.repository_owner }}/docsgpt:develop
cache-from: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/docsgpt:develop
cache-to: type=inline
manifest:
if: github.repository == 'arc53/DocsGPT'
needs: build
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Create and push manifest for DockerHub
run: |
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt:develop \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:develop-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt:develop-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt:develop
- name: Create and push manifest for ghcr.io
run: |
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt:develop \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:develop-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt:develop-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt:develop

View File

@@ -7,33 +7,20 @@ on:
- main
jobs:
build:
deploy:
if: github.repository == 'arc53/DocsGPT'
strategy:
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
suffix: amd64
- platform: linux/arm64
runner: ubuntu-24.04-arm
suffix: arm64
runs-on: ${{ matrix.runner }}
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
@@ -48,57 +35,15 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push platform-specific images
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v6
with:
file: './frontend/Dockerfile'
platforms: ${{ matrix.platform }}
platforms: linux/amd64
context: ./frontend
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop-${{ matrix.suffix }}
ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop-${{ matrix.suffix }}
provenance: false
sbom: false
${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop
ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop
cache-from: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop
cache-to: type=inline
manifest:
if: github.repository == 'arc53/DocsGPT'
needs: build
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
install: true
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Create and push manifest for DockerHub
run: |
docker manifest create ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop-amd64 \
--amend ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop-arm64
docker manifest push ${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:develop
- name: Create and push manifest for ghcr.io
run: |
docker manifest create ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop-amd64 \
--amend ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop-arm64
docker manifest push ghcr.io/${{ github.repository_owner }}/docsgpt-fe:develop

View File

@@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -23,8 +23,8 @@ jobs:
run: |
python -m pytest --cov=application --cov-report=xml
- name: Upload coverage reports to Codecov
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
if: github.event_name == 'pull_request' && matrix.python-version == '3.11'
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

1
.gitignore vendored
View File

@@ -113,7 +113,6 @@ venv.bak/
# Spyder project settings
.spyderproject
.spyproject
.jwt_secret_key
# Rope project settings
.ropeproject

71
.vscode/launch.json vendored
View File

@@ -2,70 +2,15 @@
"version": "0.2.0",
"configurations": [
{
"name": "Frontend Debug (npm)",
"type": "node-terminal",
"name": "Docker Debug Frontend",
"request": "launch",
"command": "npm run dev",
"cwd": "${workspaceFolder}/frontend"
},
{
"name": "Flask Debugger",
"type": "debugpy",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "application/app.py",
"PYTHONPATH": "${workspaceFolder}",
"FLASK_ENV": "development",
"FLASK_DEBUG": "1",
"FLASK_RUN_PORT": "7091",
"FLASK_RUN_HOST": "0.0.0.0"
},
"args": [
"run",
"--no-debugger"
],
"cwd": "${workspaceFolder}",
},
{
"name": "Celery Debugger",
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"PYTHONPATH": "${workspaceFolder}",
},
"args": [
"-A",
"application.app.celery",
"worker",
"-l",
"INFO",
"--pool=solo"
],
"cwd": "${workspaceFolder}"
},
{
"name": "Dev Containers (Mongo + Redis)",
"type": "node-terminal",
"request": "launch",
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
"cwd": "${workspaceFolder}"
}
],
"compounds": [
{
"name": "DocsGPT: Full Stack",
"configurations": [
"Frontend Debug (npm)",
"Flask Debugger",
"Celery Debugger"
],
"presentation": {
"group": "DocsGPT",
"order": 1
}
"type": "chrome",
"preLaunchTask": "docker-compose: debug:frontend",
"url": "http://127.0.0.1:5173",
"webRoot": "${workspaceFolder}/frontend",
"skipFiles": [
"<node_internals>/**"
]
}
]
}

21
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,21 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "docker-compose",
"label": "docker-compose: debug:frontend",
"dockerCompose": {
"up": {
"detached": true,
"services": [
"frontend"
],
"build": true
},
"files": [
"${workspaceFolder}/docker-compose.yaml"
]
}
}
]
}

View File

@@ -27,7 +27,6 @@ Before creating issues, please check out how the latest version of our app looks
### 👨‍💻 If you're interested in contributing code, here are some important things to know:
For instructions on setting up a development environment, please refer to our [Development Deployment Guide](https://docs.docsgpt.cloud/Deploying/Development-Environment).
Tech Stack Overview:
@@ -35,40 +34,19 @@ Tech Stack Overview:
- 🖥 Backend: Developed in Python 🐍
### 🌐 Frontend Contributions (⚛️ React, Vite)
### 🌐 If you are looking to contribute to frontend (⚛React, Vite):
* The updated Figma design can be found [here](https://www.figma.com/file/OXLtrl1EAy885to6S69554/DocsGPT?node-id=0%3A1&t=hjWVuxRg9yi5YkJ9-1). Please try to follow the guidelines.
* **Coding Style:** We follow a strict coding style enforced by ESLint and Prettier. Please ensure your code adheres to the configuration provided in our repository's `fronetend/.eslintrc.js` file. We recommend configuring your editor with ESLint and Prettier to help with this.
* **Component Structure:** Strive for small, reusable components. Favor functional components and hooks over class components where possible.
* **State Management** If you need to add stores, please use Redux.
- The current frontend is being migrated from [`/application`](https://github.com/arc53/DocsGPT/tree/main/application) to [`/frontend`](https://github.com/arc53/DocsGPT/tree/main/frontend) with a new design, so please contribute to the new one.
- Check out this [milestone](https://github.com/arc53/DocsGPT/milestone/1) and its issues.
- The updated Figma design can be found [here](https://www.figma.com/file/OXLtrl1EAy885to6S69554/DocsGPT?node-id=0%3A1&t=hjWVuxRg9yi5YkJ9-1).
### 🖥 Backend Contributions (🐍 Python)
Please try to follow the guidelines.
- Review our issues and contribute to [`/application`](https://github.com/arc53/DocsGPT/tree/main/application)
### 🖥 If you are looking to contribute to Backend (🐍 Python):
- Review our issues and contribute to [`/application`](https://github.com/arc53/DocsGPT/tree/main/application) or [`/scripts`](https://github.com/arc53/DocsGPT/tree/main/scripts) (please disregard old [`ingest_rst.py`](https://github.com/arc53/DocsGPT/blob/main/scripts/old/ingest_rst.py) [`ingest_rst_sphinx.py`](https://github.com/arc53/DocsGPT/blob/main/scripts/old/ingest_rst_sphinx.py) files; these will be deprecated soon).
- All new code should be covered with unit tests ([pytest](https://github.com/pytest-dev/pytest)). Please find tests under [`/tests`](https://github.com/arc53/DocsGPT/tree/main/tests) folder.
- Before submitting your Pull Request, ensure it can be queried after ingesting some test data.
- **Coding Style:** We adhere to the [PEP 8](https://www.python.org/dev/peps/pep-0008/) style guide for Python code. We use `ruff` as our linter and code formatter. Please ensure your code is formatted correctly and passes `ruff` checks before submitting.
- **Type Hinting:** Please use type hints for all function arguments and return values. This improves code readability and helps catch errors early. Example:
```python
def my_function(name: str, count: int) -> list[str]:
...
```
- **Docstrings:** All functions and classes should have docstrings explaining their purpose, parameters, and return values. We prefer the [Google style docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). Example:
```python
def my_function(name: str, count: int) -> list[str]:
"""Does something with a name and a count.
Args:
name: The name to use.
count: The number of times to do it.
Returns:
A list of strings.
"""
...
```
### Testing

View File

@@ -1,13 +1,15 @@
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt and other prizes! 🎉**
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
All contributors with accepted PRs will receive a cool Holopin! 🤩 (Watch out for a reply in your PR to collect it).
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
) after your PR was merged please
### 🏆 Top 50 contributors will receive a special T-shirt
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
### 🏆 [LLM Document analysis by LexEU competition](https://github.com/arc53/DocsGPT/blob/main/lexeu-competition.md):
A separate competition is available for those who submit new retrieval / workflow method that will analyze a Document using EU laws.
With 200$, 100$, 50$ prize for 1st, 2nd and 3rd place respectively.
You can find more information [here](https://github.com/arc53/DocsGPT/blob/main/lexeu-competition.md)
## 📜 Here's How to Contribute:
```text
@@ -21,18 +23,19 @@ https://github.com/arc53/DocsGPT-cli
Non-Code Contributions:
📚 Wiki: Improve our documentation, create a guide.
📚 Wiki: Improve our documentation, create a guide or change existing documentation.
🖥️ Design: Improve the UI/UX or design a new feature.
📝 Blogging or Content Creation: Write articles or create videos to showcase DocsGPT or highlight your contributions!
```
### 📝 Guidelines for Pull Requests:
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
- Before contributing we highly advise that you check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
- Once you are finished with your contribution, please fill in this [form](https://airtable.com/appikMaJwdHhC1SDP/pagoblCJ9W29wf6Hf/form).
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt and other prizes as a token of our appreciation. 🎁 Join us, and let's code together! 🚀
We will publish a t-shirt design later into the October.

234
README.md
View File

@@ -3,11 +3,13 @@
</h1>
<p align="center">
<strong>Private AI for agents, assistants and enterprise search</strong>
<strong>Open-Source Documentation Assistant</strong>
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is a cutting-edge open-source solution that streamlines the process of finding information in the project documentation. With its integration of the powerful <strong>GPT</strong> models, developers can easily ask questions about a project and receive accurate answers.
Say goodbye to time-consuming manual searches, and let <strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> help you quickly find the information you need. Try it out and see how it revolutionizes your project documentation experience. Contribute to its development and be a part of the future of AI-powered assistance.
</p>
<div align="center">
@@ -15,142 +17,174 @@
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Stars number](https://img.shields.io/github/stars/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Forks number](https://img.shields.io/github/forks/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE">![link to license file](https://img.shields.io/github/license/arc53/docsgpt)</a>
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://x.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
<a href="https://twitter.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
</div>
<div align="center">
<br>
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
<br>
<br>
</div>
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
<li><strong>🔗 Actionable Tooling:</strong> Connect to APIs, tools, and other services to enable LLM actions.</li>
<li><strong>🧩 Pre-built Integrations:</strong> Use readily available HTML/React chat widgets, search tools, Discord/Telegram bots, and more.</li>
<li><strong>🔌 Flexible Deployment:</strong> Works with major LLMs (OpenAI, Google, Anthropic) and local models (Ollama, llama_cpp).</li>
<li><strong>🏢 Secure & Scalable:</strong> Run privately and securely with Kubernetes support, designed for enterprise-grade reliability.</li>
</ul>
## Roadmap
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] Sharepoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
### Production Support / Help for Companies:
We're eager to provide personalized assistance when deploying your DocsGPT to a live environment.
[Get a Demo :wave:](https://www.docsgpt.cloud/contact)
[Book a Meeting :wave:](https://cal.com/arc53/docsgpt-demo-b2b)
[Send Email :email:](mailto:support@docsgpt.cloud?subject=DocsGPT%20support%2Fsolutions)
[Send Email :email:](mailto:contact@arc53.com?subject=DocsGPT%20support%2Fsolutions)
## Join the Lighthouse Program 🌟
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
<img src="https://github.com/user-attachments/assets/9a1f21de-7a15-4e42-9424-70d22ba5a913" alt="video-example-of-docs-gpt" width="1000" height="500">
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
## Roadmap
You can find our roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
## Our Open-Source Models Optimized for DocsGPT:
| Name | Base Model | Requirements (or similar) |
| --------------------------------------------------------------------- | ----------- | ------------------------- |
| [Docsgpt-7b-mistral](https://huggingface.co/Arc53/docsgpt-7b-mistral) | Mistral-7b | 1xA10G gpu |
| [Docsgpt-14b](https://huggingface.co/Arc53/docsgpt-14b) | llama-2-14b | 2xA10 gpu's |
| [Docsgpt-40b-falcon](https://huggingface.co/Arc53/docsgpt-40b-falcon) | falcon-40b | 8xA10G gpu's |
If you don't have enough resources to run it, you can use bitsnbytes to quantize.
## End to End AI Framework for Information Retrieval
![Architecture chart](https://github.com/user-attachments/assets/fc6a7841-ddfc-45e6-b5a0-d05fe648cbe2)
## Useful Links
- :mag: :fire: [Cloud Version](https://app.docsgpt.cloud/)
- :speech_balloon: :tada: [Join our Discord](https://discord.gg/n5BX8dh8rU)
- :books: :sunglasses: [Guides](https://docs.docsgpt.cloud/)
- :couple: [Interested in contributing?](https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md)
- :file_folder: :rocket: [How to use any other documentation](https://docs.docsgpt.cloud/Guides/How-to-train-on-other-documentation)
- :house: :closed_lock_with_key: [How to host it locally (so all data will stay on-premises)](https://docs.docsgpt.cloud/Guides/How-to-use-different-LLM)
## Project Structure
- Application - Flask app (main application).
- Extensions - Chrome extension.
- Scripts - Script that creates similarity search index for other libraries.
- Frontend - Frontend uses <a href="https://vitejs.dev/">Vite</a> and <a href="https://react.dev/">React</a>.
## QuickStart
> [!Note]
> Make sure you have [Docker](https://docs.docker.com/engine/install/) installed
A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available in our documentation
On Mac OS or Linux, write:
1. **Clone the repository:**
`./setup.sh`
```bash
git clone https://github.com/arc53/DocsGPT.git
cd DocsGPT
It will install all the dependencies and allow you to download the local model, use OpenAI or use our LLM API.
Otherwise, refer to this Guide for Windows:
1. Download and open this repository with `git clone https://github.com/arc53/DocsGPT.git`
2. Create a `.env` file in your root directory and set the env variables and `VITE_API_STREAMING` to true or false, depending on whether you want streaming answers or not.
It should look like this inside:
```
LLM_NAME=[docsgpt or openai or others]
VITE_API_STREAMING=true
API_KEY=[if LLM_NAME is openai]
```
**For macOS and Linux:**
See optional environment variables in the [/.env-template](https://github.com/arc53/DocsGPT/blob/main/.env-template) and [/application/.env_sample](https://github.com/arc53/DocsGPT/blob/main/application/.env_sample) files.
2. **Run the setup script:**
3. Run [./run-with-docker-compose.sh](https://github.com/arc53/DocsGPT/blob/main/run-with-docker-compose.sh).
4. Navigate to http://localhost:5173/.
```bash
./setup.sh
```
To stop, just run `Ctrl + C`.
**For Windows:**
## Development Environments
2. **Run the PowerShell setup script:**
### Spin up Mongo and Redis
```powershell
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
```
For development, only two containers are used from [docker-compose.yaml](https://github.com/arc53/DocsGPT/blob/main/docker-compose.yaml) (by deleting all services except for Redis and Mongo).
See file [docker-compose-dev.yaml](./docker-compose-dev.yaml).
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
Run
**Navigate to http://localhost:5173/**
To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
```
docker compose -f docker-compose-dev.yaml build
docker compose -f docker-compose-dev.yaml up -d
```
(or use the specific `docker compose down` command shown after running the setup script).
### Run the Backend
> [!Note]
> For development environment setup instructions, please refer to the [Development Environment Guide](https://docs.docsgpt.cloud/Deploying/Development-Environment).
> Make sure you have Python 3.10 or 3.11 installed.
1. Export required environment variables or prepare a `.env` file in the project folder:
- Copy [.env-template](https://github.com/arc53/DocsGPT/blob/main/application/.env-template) and create `.env`.
(check out [`application/core/settings.py`](application/core/settings.py) if you want to see more config options.)
2. (optional) Create a Python virtual environment:
You can follow the [Python official documentation](https://docs.python.org/3/tutorial/venv.html) for virtual environments.
a) On Mac OS and Linux
```commandline
python -m venv venv
. venv/bin/activate
```
b) On Windows
```commandline
python -m venv venv
venv/Scripts/activate
```
3. Download embedding model and save it in the `model/` folder:
You can use the script below, or download it manually from [here](https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip), unzip it and save it in the `model/` folder.
```commandline
wget https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip
unzip mpnet-base-v2.zip -d model
rm mpnet-base-v2.zip
```
4. Install dependencies for the backend:
```commandline
pip install -r application/requirements.txt
```
5. Run the app using `flask --app application/app.py run --host=0.0.0.0 --port=7091`.
6. Start worker with `celery -A application.app.celery worker -l INFO`.
### Start Frontend
> [!Note]
> Make sure you have Node version 16 or higher.
1. Navigate to the [/frontend](https://github.com/arc53/DocsGPT/tree/main/frontend) folder.
2. Install the required packages `husky` and `vite` (ignore if already installed).
```commandline
npm install husky -g
npm install vite -g
```
3. Install dependencies by running `npm install --include=dev`.
4. Run the app using `npm run dev`.
## Contributing
Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information about how to get involved. We welcome issues, questions, and pull requests.
## Architecture
![Architecture chart](https://github.com/user-attachments/assets/fc6a7841-ddfc-45e6-b5a0-d05fe648cbe2)
## Project Structure
- Application - Flask app (main application).
- Extensions - Extensions, like react widget or discord bot.
- Frontend - Frontend uses <a href="https://vitejs.dev/">Vite</a> and <a href="https://react.dev/">React</a>.
- Scripts - Miscellaneous scripts.
## Code Of Conduct
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.

View File

@@ -6,20 +6,21 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
# Install necessary packages and Python
apt-get update && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.11 python3.11-distutils python3.11-venv && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.12 ]; then \
ln -s /usr/bin/python3.12 /usr/bin/python; \
RUN if [ -f /usr/bin/python3.11 ]; then \
ln -s /usr/bin/python3.11 /usr/bin/python; \
else \
echo "Python 3.12 not found"; exit 1; \
echo "Python 3.11 not found"; exit 1; \
fi
# Download and unzip the model
RUN wget https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip && \
unzip mpnet-base-v2.zip -d models && \
unzip mpnet-base-v2.zip -d model && \
rm mpnet-base-v2.zip
# Install Rust
@@ -32,7 +33,7 @@ RUN apt-get remove --purge -y wget unzip && apt-get autoremove -y && rm -rf /var
COPY requirements.txt .
# Setup Python virtual environment
RUN python3.12 -m venv /venv
RUN python3.11 -m venv /venv
# Activate virtual environment and install Python packages
ENV PATH="/venv/bin:$PATH"
@@ -48,8 +49,9 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
ln -s /usr/bin/python3.12 /usr/bin/python && \
# Install Python
apt-get update && apt-get install -y --no-install-recommends python3.11 && \
ln -s /usr/bin/python3.11 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*
# Set working directory
@@ -61,8 +63,7 @@ RUN groupadd -r appuser && \
# Copy the virtual environment and model from the builder stage
COPY --from=builder /venv /venv
COPY --from=builder /models /app/models
COPY --from=builder /model /app/model
# Copy your application code
COPY . /app/application
@@ -84,4 +85,4 @@ EXPOSE 7091
USER appuser
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
CMD ["gunicorn", "-w", "2", "--timeout", "120", "--bind", "0.0.0.0:7091", "application.wsgi:app"]

View File

@@ -1,16 +0,0 @@
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
}
@classmethod
def create_agent(cls, type, *args, **kwargs):
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -1,415 +0,0 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
class BaseAgent(ABC):
def __init__(
self,
endpoint: str,
llm_name: str,
gpt_model: str,
api_key: str,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]] = None,
json_schema: Optional[Dict] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
self.gpt_model = gpt_model
self.api_key = api_key
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
@abstractmethod
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
pass
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
def _build_messages(
self,
system_prompt: str,
query: str,
retrieved_data: List[Dict],
) -> List[Dict]:
docs_with_filenames = []
for doc in retrieved_data:
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if filename:
chunk_header = str(filename)
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
else:
docs_with_filenames.append(doc["text"])
docs_together = "\n\n".join(docs_with_filenames)
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "assistant", "content": i["response"]})
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": query})
return messages_combine
def _retriever_search(
self,
retriever: BaseRetriever,
query: str,
log_context: Optional[LogContext] = None,
) -> List[Dict]:
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
gen_kwargs = {"model": self.gpt_model, "messages": messages}
if (
hasattr(self.llm, "_supports_tools")
and self.llm._supports_tools
and self.tools
):
gen_kwargs["tools"] = self.tools
if (
self.json_schema
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
):
structured_format = self.llm.prepare_structured_output_format(
self.json_schema
)
if structured_format:
if self.llm_name == "openai":
gen_kwargs["response_format"] = structured_format
elif self.llm_name == "google":
gen_kwargs["response_schema"] = structured_format
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
data = build_stack_data(self.llm, exclude_attributes=["client"])
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(
self,
resp,
tools_dict: Dict,
messages: List[Dict],
log_context: Optional[LogContext] = None,
attachments: Optional[List[Dict]] = None,
):
resp = self.llm_handler.process_message_flow(
self, resp, tools_dict, messages, attachments, True
)
if log_context:
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
is_structured_output = (
self.json_schema is not None
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
)
if isinstance(response, str):
answer_data = {"answer": response}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
return
if hasattr(response, "message") and getattr(response.message, "content", None):
answer_data = {"answer": response.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
for event in processed_response_gen:
if isinstance(event, str):
answer_data = {"answer": event}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
elif hasattr(event, "message") and getattr(event.message, "content", None):
answer_data = {"answer": event.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
elif isinstance(event, dict) and "type" in event:
yield event

View File

@@ -1,53 +0,0 @@
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import LogContext
from application.retriever.base import BaseRetriever
import logging
logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
"""A simplified agent with clear execution flow.
Usage:
1. Processes a query through retrieval
2. Sets up available tools
3. Generates responses using LLM
4. Handles tool interactions if needed
5. Returns standardized outputs
Easy to extend by overriding specific steps.
"""
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
# Step 1: Retrieve relevant data
retrieved_data = self._retriever_search(retriever, query, log_context)
# Step 2: Prepare tools
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
self._prepare_tools(tools_dict)
# Step 3: Build and process messages
messages = self._build_messages(self.prompt, query, retrieved_data)
llm_response = self._llm_gen(messages, log_context)
# Step 4: Handle the response
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# Step 5: Return metadata
yield {"sources": retrieved_data}
yield {"tool_calls": self._get_truncated_tool_calls()}
# Log tool calls for debugging
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)

View File

@@ -1,229 +0,0 @@
import os
from typing import Dict, Generator, List, Any
import logging
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
planning_prompt_template = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
"r",
) as f:
final_prompt_template = f.read()
MAX_ITERATIONS_REASONING = 10
class ReActAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _extract_content_from_llm_response(self, resp: Any) -> str:
"""
Helper to extract string content from various LLM response types.
Handles strings, message objects (OpenAI-like), and streams.
Adapt stream handling for your specific LLM client if not OpenAI.
"""
collected_content = []
if isinstance(resp, str):
collected_content.append(resp)
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
collected_content.append(resp.message.content)
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
hasattr(resp, "choices") and resp.choices and
hasattr(resp.choices[0], "message") and
hasattr(resp.choices[0].message, "content") and
resp.choices[0].message.content is not None
):
collected_content.append(resp.choices[0].message.content) # OpenAI
elif ( # Anthropic new SDK non-streaming content block
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
hasattr(resp.content[0], "text")
):
collected_content.append(resp.content[0].text) # Anthropic
else:
# Assume resp is a stream if not a recognized object
try:
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
content_piece = ""
# OpenAI-like stream
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
hasattr(chunk.choices[0], 'delta') and \
hasattr(chunk.choices[0].delta, 'content') and \
chunk.choices[0].delta.content is not None:
content_piece = chunk.choices[0].delta.content
# Anthropic-like stream (ContentBlockDelta)
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
content_piece = chunk.delta.text
elif isinstance(chunk, str): # Simplest case: stream of strings
content_piece = chunk
if content_piece:
collected_content.append(content_piece)
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
except Exception as e:
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
return "".join(collected_content)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
# Reset state for this generation call
self.plan = ""
self.observations = []
retrieved_data = self._retriever_search(retriever, query, log_context)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
iterating_reasoning = 0
while iterating_reasoning < MAX_ITERATIONS_REASONING:
iterating_reasoning += 1
# 1. Create Plan
logger.info("ReActAgent: Creating plan...")
plan_stream = self._create_plan(query, docs_together, log_context)
current_plan_parts = []
yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
for line_chunk in plan_stream:
current_plan_parts.append(line_chunk)
yield {"thought": line_chunk}
self.plan = "".join(current_plan_parts)
if self.plan:
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
max_obs_len = 20000
obs_str = "\n".join(self.observations)
if len(obs_str) > max_obs_len:
obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
execution_prompt_str = (
(self.prompt or "")
+ f"\n\nFollow this plan:\n{self.plan}"
+ f"\n\nObservations:\n{obs_str}"
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
)
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
resp_from_llm_gen = self._llm_gen(messages, log_context)
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
if initial_llm_thought_content:
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
else:
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
observation_string = (
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
)
self.observations.append(observation_string)
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
if content_after_handler:
self.observations.append(f"Response after tool execution: {content_after_handler}")
else:
logger.info("ReActAgent: LLM response after handler had no textual content.")
if log_context:
log_context.stacks.append(
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
)
yield {"sources": retrieved_data}
display_tool_calls = []
for tc in self.tool_calls:
cleaned_tc = tc.copy()
if len(str(cleaned_tc.get("result", ""))) > 50:
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
display_tool_calls.append(cleaned_tc)
if display_tool_calls:
yield {"tool_calls": display_tool_calls}
if "SATISFIED" in content_after_handler:
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
break
# 3. Create Final Answer based on all observations
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
for answer_chunk in final_answer_stream:
yield {"answer": answer_chunk}
logger.info("ReActAgent: Finished generating final answer.")
def _create_plan(
self, query: str, docs_data: str, log_context: LogContext = None
) -> Generator[str, None, None]:
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
if "{summaries}" in plan_prompt_filled:
summaries = docs_data if docs_data else "No documents retrieved."
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
messages = [{"role": "user", "content": plan_prompt_filled}]
plan_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "planning_llm", "data": data})
for chunk in plan_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece
def _create_final_answer(
self, query: str, observations: List[str], log_context: LogContext = None
) -> Generator[str, None, None]:
observation_string = "\n".join(observations)
max_obs_len = 10000
if len(observation_string) > max_obs_len:
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
final_answer_prompt_filled = final_prompt_template.format(
query=query, observations=observation_string
)
messages = [{"role": "user", "content": final_answer_prompt_filled}]
# Final answer should synthesize, not call tools.
final_answer_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "final_answer_llm", "data": data})
for chunk in final_answer_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece

View File

@@ -1,72 +0,0 @@
import json
import requests
from application.agents.tools.base import Tool
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.query_params = config.get("query_params", {})
def execute_action(self, action_name, **kwargs):
return self._make_api_call(
self.url, self.method, self.headers, self.query_params, kwargs
)
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
try:
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.RequestException as e:
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
}
def get_actions_metadata(self):
return []
def get_config_requirements(self):
return {}

View File

@@ -1,21 +0,0 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass
@abstractmethod
def get_actions_metadata(self):
"""
Returns a list of JSON objects describing the actions supported by the tool.
"""
pass
@abstractmethod
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
"""
pass

View File

@@ -1,182 +0,0 @@
import requests
from application.agents.tools.base import Tool
class BraveSearchTool(Tool):
"""
Brave Search
A tool for performing web and image searches using the Brave Search API.
Requires an API key for authentication.
"""
def __init__(self, config):
self.config = config
self.token = config.get("token", "")
self.base_url = "https://api.search.brave.com/res/v1"
def execute_action(self, action_name, **kwargs):
actions = {
"brave_web_search": self._web_search,
"brave_image_search": self._image_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(
self,
query,
country="ALL",
search_lang="en",
count=10,
offset=0,
safesearch="off",
freshness=None,
result_filter=None,
extra_snippets=False,
summary=False,
):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 20),
"offset": min(offset, 9),
"safesearch": safesearch,
}
if freshness:
params["freshness"] = freshness
if result_filter:
params["result_filter"] = result_filter
if extra_snippets:
params["extra_snippets"] = 1
if summary:
params["summary"] = 1
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Search completed successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Search failed with status code: {response.status_code}.",
}
def _image_search(
self,
query,
country="ALL",
search_lang="en",
count=5,
safesearch="off",
spellcheck=False,
):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 100), # API max is 100
"safesearch": safesearch,
"spellcheck": 1 if spellcheck else 0,
}
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Image search completed successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Image search failed with status code: {response.status_code}.",
}
def get_actions_metadata(self):
return [
{
"name": "brave_web_search",
"description": "Perform a web search using Brave Search",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
"search_lang": {
"type": "string",
"description": "The search language preference (default: en)",
},
"freshness": {
"type": "string",
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
},
},
"required": ["query"],
"additionalProperties": False,
},
},
{
"name": "brave_image_search",
"description": "Perform an image search using Brave Search",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
"count": {
"type": "integer",
"description": "Number of results to return (max 100, default: 5)",
},
},
"required": ["query"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {
"type": "string",
"description": "Brave Search API key for authentication",
},
}

View File

@@ -1,76 +0,0 @@
import requests
from application.agents.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
symbol = "BTC"
currency = "USD"
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
if currency.upper() in data:
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
return [
{
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False,
},
}
]
def get_config_requirements(self):
# No specific configuration needed for this tool as it just queries a public endpoint
return {}

View File

@@ -1,114 +0,0 @@
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
"""
DuckDuckGo Search
A tool for performing web and image searches using DuckDuckGo.
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
try:
results = DDGS().text(
query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],
},
},
]
def get_config_requirements(self):
return {}

View File

@@ -1,861 +0,0 @@
import asyncio
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
class MCPTool(Tool):
"""
MCP Tool
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
"""
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
"""
Initialize the MCP Tool with configuration.
Args:
config: Dictionary containing MCP server configuration:
- server_url: URL of the remote MCP server
- transport_type: Transport type (auto, sse, http, stdio)
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
- encrypted_credentials: Encrypted credentials (if available)
- timeout: Request timeout in seconds (default: 30)
- headers: Custom headers for requests
- command: Command for STDIO transport
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
self.custom_headers = config.get("headers", {})
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
self.auth_credentials = config.get("auth_credentials", {})
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
elif self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
auth_key = f"basic:{username}"
else:
auth_key = "none"
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
self._client = cached_data["client"]
return
else:
del _mcp_clients_cache[self._cache_key]
transport = self._create_transport()
auth = None
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
if token:
auth = BearerAuth(token)
self._client = Client(transport, auth=auth)
_mcp_clients_cache[self._cache_key] = {
"client": self._client,
"created_at": time.time(),
}
def _create_transport(self):
"""Create appropriate transport based on configuration."""
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
headers.update(self.custom_headers)
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
headers[header_name] = api_key
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
credentials = base64.b64encode(
f"{username}:{password}".encode()
).decode()
headers["Authorization"] = f"Basic {credentials}"
if self.transport_type == "auto":
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
transport_type = "sse"
else:
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
elif transport_type == "http":
return StreamableHttpTransport(url=self.server_url, headers=headers)
elif transport_type == "stdio":
command = self.config.get("command", "python")
args = self.config.get("args", [])
env = self.auth_credentials if self.auth_credentials else None
return StdioTransport(command=command, args=args, env=env)
else:
return StreamableHttpTransport(url=self.server_url, headers=headers)
def _format_tools(self, tools_response) -> List[Dict]:
"""Format tools response to match expected format."""
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
async def _execute_with_client(self, operation: str, *args, **kwargs):
"""Execute operation with FastMCP client."""
if not self._client:
raise Exception("FastMCP client not initialized")
async with self._client:
if operation == "ping":
return await self._client.ping()
elif operation == "list_tools":
tools_response = await self._client.list_tools()
self.available_tools = self._format_tools(tools_response)
return self.available_tools
elif operation == "call_tool":
tool_name = args[0]
tool_args = kwargs
return await self._client.call_tool(tool_name, tool_args)
elif operation == "list_resources":
return await self._client.list_resources()
elif operation == "list_prompts":
return await self._client.list_prompts()
else:
raise Exception(f"Unknown operation: {operation}")
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
Discover available tools from the MCP server using FastMCP.
Returns:
List of tool definitions from the server
"""
if not self.server_url:
return []
if not self._client:
self._setup_client()
try:
tools = self._run_async_operation("list_tools")
self.available_tools = tools
return self.available_tools
except Exception as e:
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
self._setup_client()
cleaned_kwargs = {}
for key, value in kwargs.items():
if value == "" or value is None:
continue
cleaned_kwargs[key] = value
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
if hasattr(result, "content"):
content_list = []
for content_item in result.content:
if hasattr(content_item, "text"):
content_list.append({"type": "text", "text": content_item.text})
elif hasattr(content_item, "data"):
content_list.append({"type": "data", "data": content_item.data})
else:
content_list.append(
{"type": "unknown", "content": str(content_item)}
)
return {
"content": content_list,
"isError": getattr(result, "isError", False),
}
else:
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
return self._test_regular_connection()
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
Returns:
List of action metadata dictionaries
"""
actions = []
for tool in self.available_tools:
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
or tool.get("schema")
or tool.get("parameters")
)
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
if input_schema:
if isinstance(input_schema, dict):
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
"properties": input_schema.get("properties", {}),
"required": input_schema.get("required", []),
}
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": parameters_schema,
}
actions.append(action)
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
},
}
class DocsGPTOAuth(OAuthClientProvider):
"""
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
"""
def __init__(
self,
mcp_url: str,
redirect_uri: str,
redis_client: Redis | None = None,
redis_prefix: str = "mcp_oauth:",
task_id: str = None,
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
if isinstance(scopes, list):
scopes = " ".join(scopes)
client_metadata = OAuthClientMetadata(
client_name=client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope=scopes,
**(additional_client_metadata or {}),
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
server_url=self.server_base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
self.auth_url = None
self.extracted_state = None
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
"""Process authorization URL to extract state"""
try:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
state_params = query_params.get("state", [])
if state_params:
state = state_params[0]
else:
raise ValueError("No state in auth URL")
return authorization_url, state
except Exception as e:
raise Exception(f"Failed to process auth URL: {e}")
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
if not self.redis_client or not self.extracted_state:
raise Exception("Redis client or state not configured for OAuth")
poll_interval = 1
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
if code_data:
code = code_data.decode()
returned_state = self.extracted_state
self.redis_client.delete(code_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
if error_data:
error_msg = error_data.decode()
self.redis_client.delete(error_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
raise Exception(f"OAuth error: {error_msg}")
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
return info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
"""Manager for handling MCP OAuth callbacks."""
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
self.redis_client = redis_client
self.redis_prefix = redis_prefix
def handle_oauth_callback(
self, state: str, code: str, error: Optional[str] = None
) -> bool:
"""
Handle OAuth callback from provider.
Args:
state: The state parameter from OAuth callback
code: The authorization code from OAuth callback
error: Error message if OAuth failed
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client or not state:
raise Exception("Redis client or state not provided")
if error:
error_key = f"{self.redis_prefix}error:{state}"
self.redis_client.setex(error_key, 300, error)
raise Exception(f"OAuth error received: {error}")
code_key = f"{self.redis_prefix}code:{state}"
self.redis_client.setex(code_key, 300, code)
state_key = f"{self.redis_prefix}state:{state}"
self.redis_client.setex(state_key, 300, "completed")
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)

View File

@@ -1,546 +0,0 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import re
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class MemoryTool(Tool):
"""Memory
Stores and retrieves information across conversations through a memory file directory.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated memories
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["memories"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, create, str_replace, insert, delete, rename.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
kwargs.get("view_range")
)
if action_name == "create":
return self._create(
kwargs.get("path", ""),
kwargs.get("file_text", "")
)
if action_name == "str_replace":
return self._str_replace(
kwargs.get("path", ""),
kwargs.get("old_str", ""),
kwargs.get("new_str", "")
)
if action_name == "insert":
return self._insert(
kwargs.get("path", ""),
kwargs.get("insert_line", 1),
kwargs.get("insert_text", "")
)
if action_name == "delete":
return self._delete(kwargs.get("path", ""))
if action_name == "rename":
return self._rename(
kwargs.get("old_path", ""),
kwargs.get("new_path", "")
)
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Shows directory contents or file contents with optional line ranges.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
},
"view_range": {
"type": "array",
"items": {"type": "integer"},
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
}
},
"required": ["path"]
},
},
{
"name": "create",
"description": "Create or overwrite a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
},
"file_text": {
"type": "string",
"description": "Content to write to the file."
}
},
"required": ["path", "file_text"]
},
},
{
"name": "str_replace",
"description": "Replace text in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"old_str": {
"type": "string",
"description": "String to find."
},
"new_str": {
"type": "string",
"description": "String to replace with."
}
},
"required": ["path", "old_str", "new_str"]
},
},
{
"name": "insert",
"description": "Insert text at a specific line in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"insert_line": {
"type": "integer",
"description": "Line number to insert at (1-indexed)."
},
"insert_text": {
"type": "string",
"description": "Text to insert."
}
},
"required": ["path", "insert_line", "insert_text"]
},
},
{
"name": "delete",
"description": "Delete a file or directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to delete (e.g., /notes.txt or /project/)."
}
},
"required": ["path"]
},
},
{
"name": "rename",
"description": "Rename or move a file/directory.",
"parameters": {
"type": "object",
"properties": {
"old_path": {
"type": "string",
"description": "Current path (e.g., /old.txt)."
},
"new_path": {
"type": "string",
"description": "New path (e.g., /new.txt)."
}
},
"required": ["old_path", "new_path"]
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Path validation
# -----------------------------
def _validate_path(self, path: str) -> Optional[str]:
"""Validate and normalize path.
Args:
path: User-provided path.
Returns:
Normalized path or None if invalid.
"""
if not path:
return None
# Remove any leading/trailing whitespace
path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
# Check for directory traversal patterns
if ".." in path or path.count("//") > 0:
return None
# Normalize the path
try:
# Convert to Path object and resolve to canonical form
normalized = str(Path(path).as_posix())
# Ensure it still starts with /
if not normalized.startswith("/"):
return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized
except Exception:
return None
# -----------------------------
# Internal helpers
# -----------------------------
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View directory contents or file contents."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
# Check if viewing directory (ends with / or is root)
if validated_path == "/" or validated_path.endswith("/"):
return self._view_directory(validated_path)
# Otherwise view file
return self._view_file(validated_path, view_range)
def _view_directory(self, path: str) -> str:
"""List files in a directory."""
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
# Find all files that start with this directory path
query = {
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
}
docs = list(self.collection.find(query, {"path": 1}))
if not docs:
return f"Directory: {path}\n(empty)"
# Extract filenames relative to the directory
files = []
for doc in docs:
file_path = doc["path"]
# Remove the directory prefix
if file_path.startswith(search_path):
relative = file_path[len(search_path):]
if relative:
files.append(relative)
files.sort()
file_list = "\n".join(f"- {f}" for f in files)
return f"Directory: {path}\n{file_list}"
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
content = str(doc["content"])
# Apply view_range if specified
if view_range and len(view_range) == 2:
lines = content.split("\n")
start, end = view_range
# Convert to 0-indexed
start_idx = max(0, start - 1)
end_idx = min(len(lines), end)
if start_idx >= len(lines):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines)
return content
def _create(self, path: str, file_text: str) -> str:
"""Create or overwrite a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": file_text,
"updated_at": datetime.now()
}
},
upsert=True
)
return f"File created: {validated_path}"
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
"""Replace text in a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not old_str:
return "Error: old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Replace the string (case-insensitive)
import re as regex_module
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"File updated: {validated_path}"
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
"""Insert text at a specific line."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not insert_text:
return "Error: insert_text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"Text inserted at line {insert_line} in {validated_path}"
def _delete(self, path: str) -> str:
"""Delete a file or directory."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/":
# Delete all files for this user and tool
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
return f"Deleted {result.deleted_count} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
# Delete all files in directory
result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_path)}"}
})
return f"Deleted directory and {result.deleted_count} file(s)."
# Try to delete as directory first (without trailing slash)
# Check if any files start with this path + /
search_path = validated_path + "/"
directory_result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
})
if directory_result.deleted_count > 0:
return f"Deleted directory and {directory_result.deleted_count} file(s)."
# Delete single file
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_path
})
if result.deleted_count:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
def _rename(self, old_path: str, new_path: str) -> str:
"""Rename or move a file/directory."""
validated_old = self._validate_path(old_path)
validated_new = self._validate_path(new_path)
if not validated_old or not validated_new:
return "Error: Invalid path."
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Check if renaming a directory
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_old)}"}
}))
if not docs:
return f"Error: Directory not found: {validated_old}"
# Update paths for all files
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
self.collection.update_one(
{"_id": doc["_id"]},
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Rename single file
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_old
})
if not doc:
return f"Error: File not found: {validated_old}"
# Check if new path already exists
existing = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new
})
if existing:
return f"Error: File already exists at {validated_new}"
# Delete the old document and create a new one with the new path
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
self.collection.insert_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new,
"content": doc.get("content", ""),
"updated_at": datetime.now()
})
return f"Renamed: {validated_old} -> {validated_new}"

View File

@@ -1,199 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class NotesTool(Tool):
"""Notepad
Single note. Supports viewing, overwriting, string replacement.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, overwrite, str_replace, insert, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
if action_name == "view":
return self._get_note()
if action_name == "overwrite":
return self._overwrite_note(kwargs.get("text", ""))
if action_name == "str_replace":
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
if action_name == "insert":
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
if action_name == "delete":
return self._delete_note()
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Retrieve the user's note.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "overwrite",
"description": "Replace the entire note content (creates if doesn't exist).",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "New note content."}
},
"required": ["text"],
},
},
{
"name": "str_replace",
"description": "Replace occurrences of old_str with new_str in the note.",
"parameters": {
"type": "object",
"properties": {
"old_str": {"type": "string", "description": "String to find."},
"new_str": {"type": "string", "description": "String to replace with."}
},
"required": ["old_str", "new_str"],
},
},
{
"name": "insert",
"description": "Insert text at the specified line number (1-indexed).",
"parameters": {
"type": "object",
"properties": {
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
"text": {"type": "string", "description": "Text to insert."}
},
"required": ["line_number", "text"],
},
},
{
"name": "delete",
"description": "Delete the user's note.",
"parameters": {"type": "object", "properties": {}},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements (none for now)."""
return {}
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
)
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
if not old_str:
return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(doc["note"])
# Case-insensitive search
if old_str.lower() not in current_note.lower():
return f"String '{old_str}' not found in note."
# Case-insensitive replacement
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
if not text:
return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(doc["note"])
lines = current_note.split("\n")
# Convert to 0-indexed and validate
index = line_number - 1
if index < 0 or index > len(lines):
return f"Invalid line number. Note has {len(lines)} lines."
lines.insert(index, text)
updated_note = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."

View File

@@ -1,127 +0,0 @@
import requests
from application.agents.tools.base import Tool
class NtfyTool(Tool):
"""
Ntfy Tool
A tool for sending notifications to ntfy topics on a specified server.
"""
def __init__(self, config):
"""
Initialize the NtfyTool with configuration.
Args:
config (dict): Configuration dictionary containing the access token.
"""
self.config = config
self.token = config.get("token", "")
def execute_action(self, action_name, **kwargs):
"""
Execute the specified action with given parameters.
Args:
action_name (str): Name of the action to execute.
**kwargs: Parameters for the action, including server_url.
Returns:
dict: Result of the action with status code and message.
Raises:
ValueError: If the action name is unknown.
"""
actions = {
"ntfy_send_message": self._send_message,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _send_message(self, server_url, message, topic, title=None, priority=None):
"""
Send a message to an ntfy topic on the specified server.
Args:
server_url (str): Base URL of the ntfy server (e.g., https://ntfy.sh).
message (str): The message text to send.
topic (str): The topic to send the message to.
title (str, optional): Title of the notification.
priority (int, optional): Priority of the notification (1-5).
Returns:
dict: Response with status code and a confirmation message.
Raises:
ValueError: If priority is not an integer between 1 and 5.
"""
url = f"{server_url.rstrip('/')}/{topic}"
headers = {}
if title:
headers["X-Title"] = title
if priority:
try:
priority = int(priority)
except (ValueError, TypeError):
raise ValueError("Priority must be convertible to an integer")
if priority < 1 or priority > 5:
raise ValueError("Priority must be an integer between 1 and 5")
headers["X-Priority"] = str(priority)
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):
"""
Provide metadata about available actions.
Returns:
list: List of dictionaries describing each action.
"""
return [
{
"name": "ntfy_send_message",
"description": "Send a notification to an ntfy topic",
"parameters": {
"type": "object",
"properties": {
"server_url": {
"type": "string",
"description": "Base URL of the ntfy server",
},
"message": {
"type": "string",
"description": "Text to send in the notification",
},
"topic": {
"type": "string",
"description": "Topic to send the notification to",
},
"title": {
"type": "string",
"description": "Title of the notification (optional)",
},
"priority": {
"type": "integer",
"description": "Priority of the notification (1-5, optional)",
},
},
"required": ["server_url", "message", "topic"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {"type": "string", "description": "Access token for authentication"},
}

View File

@@ -1,163 +0,0 @@
import psycopg2
from application.agents.tools.base import Tool
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
A tool for connecting to a PostgreSQL database using a connection string,
executing SQL queries, and retrieving schema information.
"""
def __init__(self, config):
self.config = config
self.connection_string = config.get("token", "")
def execute_action(self, action_name, **kwargs):
actions = {
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = [desc[0] for desc in cur.description] if cur.description else []
results = []
rows = cur.fetchall()
for row in rows:
results.append(dict(zip(column_names, row)))
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
cur.close()
return {
"status_code": 200,
"message": "SQL query executed successfully.",
"response_data": response_data,
}
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute("""
SELECT
table_name,
column_name,
data_type,
column_default,
is_nullable
FROM
information_schema.columns
WHERE
table_schema = 'public'
ORDER BY
table_name,
ordinal_position;
""")
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
cur.close()
return {
"status_code": 200,
"message": "Database schema retrieved successfully.",
"schema": schema_data,
}
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def get_actions_metadata(self):
return [
{
"name": "postgres_execute_sql",
"description": "Execute an SQL query against the PostgreSQL database and return the results. Use this tool to interact with the database, e.g., retrieve specific data or perform updates. Only SELECT queries will return data, other queries will return execution status.",
"parameters": {
"type": "object",
"properties": {
"sql_query": {
"type": "string",
"description": "The SQL query to execute.",
},
},
"required": ["sql_query"],
"additionalProperties": False,
},
},
{
"name": "postgres_get_schema",
"description": "Retrieve the schema of the PostgreSQL database, including tables and their columns. Use this to understand the database structure before executing queries. db_name is 'default' if not provided.",
"parameters": {
"type": "object",
"properties": {
"db_name": {
"type": "string",
"description": "The name of the database to retrieve the schema for.",
},
},
"required": ["db_name"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {
"type": "string",
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
},
}

View File

@@ -1,83 +0,0 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
class ReadWebpageTool(Tool):
"""
Read Webpage (browser)
A tool to fetch the HTML content of a URL and convert it to Markdown.
"""
def __init__(self, config=None):
"""
Initializes the tool.
:param config: Optional configuration dictionary. Not used by this tool.
"""
self.config = config
def execute_action(self, action_name: str, **kwargs) -> str:
"""
Executes the specified action. For this tool, the only action is 'read_webpage'.
:param action_name: The name of the action to execute. Should be 'read_webpage'.
:param kwargs: Keyword arguments, must include 'url'.
:return: The Markdown content of the webpage or an error message.
"""
if action_name != "read_webpage":
return f"Error: Unknown action '{action_name}'. This tool only supports 'read_webpage'."
url = kwargs.get("url")
if not url:
return "Error: URL parameter is missing."
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
html_content = response.text
#soup = BeautifulSoup(html_content, 'html.parser')
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
return markdown_content
except requests.exceptions.RequestException as e:
return f"Error fetching URL {url}: {e}"
except Exception as e:
return f"Error processing URL {url}: {e}"
def get_actions_metadata(self):
"""
Returns metadata for the actions supported by this tool.
"""
return [
{
"name": "read_webpage",
"description": "Fetches the HTML content of a given URL and returns it as clean Markdown text. Input must be a valid URL.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The fully qualified URL of the webpage to read (e.g., 'https://www.example.com').",
}
},
"required": ["url"],
"additionalProperties": False,
},
}
]
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
This tool does not require any specific configuration.
"""
return {}

View File

@@ -1,86 +0,0 @@
import requests
from application.agents.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.token = config.get("token", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
return [
{
"name": "telegram_send_message",
"description": "Send a notification to Telegram chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the notification to",
},
},
"required": ["text"],
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
"description": "Send an image to the Telegram chat",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the image to",
},
},
"required": ["image_url"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,61 +0,0 @@
import json
import logging
logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type):
self.llm_type = llm_type
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
}
def parse_args(self, call):
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing OpenAI LLM call: {e}")
return None, None, None
return tool_id, action_name, call_args
def _parse_google_llm(self, call):
try:
call_args = call.arguments
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing Google LLM call: {e}")
return None, None, None
return tool_id, action_name, call_args

View File

@@ -1,49 +0,0 @@
import importlib
import inspect
import os
import pkgutil
from application.agents.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
self.tools = {}
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__))
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
metadata = []
for tool in self.tools.values():
metadata.extend(tool.get_actions_metadata())
return metadata

View File

@@ -1,7 +0,0 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -1,19 +0,0 @@
from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.stream import StreamResource
answer = Blueprint("answer", __name__)
api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
init_answer_routes()

View File

@@ -0,0 +1,619 @@
import asyncio
import datetime
import json
import logging
import os
import sys
import traceback
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, current_app, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
user_logs_collection = db["user_logs"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_NAME == "openai":
gpt_model = "gpt-3.5-turbo"
elif settings.LLM_NAME == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_NAME == "groq":
gpt_model = "llama3-8b-8192"
if settings.MODEL_NAME: # in case there is particular model name configured
gpt_model = settings.MODEL_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
async def async_generate(chain, question, chat_history):
result = await chain.arun({"question": question, "chat_history": chat_history})
return result
def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = {}
try:
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
finally:
loop.close()
result["answer"] = answer
return result
def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key})
# # Raise custom exception if the API key is not found
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
if "retriever" in source_doc:
data["retriever"] = source_doc["retriever"]
else:
data["source"] = {}
return data
def get_retriever(source_id: str):
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
def is_azure_configured():
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"sources": source_log_docs,
}
}
},
)
else:
# create new conversation
# generate summary
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{
"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"sources": source_log_docs,
}
],
}
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == "default":
prompt = chat_combine_template
elif prompt_id == "creative":
prompt = chat_combine_creative
elif prompt_id == "strict":
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
):
try:
response_full = ""
source_log_docs = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
if "text" in source:
source["text"] = source["text"][:100].strip() + "..."
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
traceback.print_exc()
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
"error_exception": str(e),
}
)
yield f"data: {data}\n\n"
return
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
current_app.logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
message = e.args[0]
status_code = 400
# Custom exceptions with two arguments, index 1 as status code
if len(e.args) >= 2:
status_code = e.args[1]
return Response(
error_stream_generate(message),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
prompt = get_prompt(prompt_id)
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
current_app.logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
docs = retriever.search()
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
current_app.logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)

View File

@@ -1,122 +0,0 @@
import logging
import traceback
from flask import make_response, request
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/api/answer")
class AnswerResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
answer_model = answer_ns.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
agent = processor.create_agent()
retriever = processor.create_retriever()
stream = self.complete_stream(
question=data["question"],
agent=agent,
retriever=retriever,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
}
if structured_info:
result.update(structured_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response(result, 200)

View File

@@ -1,265 +0,0 @@
import datetime
import json
import logging
from typing import Any, Dict, Generator, List, Optional
from flask import Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
logger = logging.getLogger(__name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.user_logs_collection = db["user_logs"]
self.gpt_model = get_gpt_model()
self.conversation_service = ConversationService()
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
if missing_fields := check_required_fields(data, required_fields):
return missing_fields
return None
def complete_stream(
self,
question: str,
agent: Any,
retriever: Any,
conversation_id: Optional[str],
user_api_key: Optional[str],
decoded_token: Dict[str, Any],
isNoneDoc: bool = False,
index: Optional[int] = None,
should_save_conversation: bool = True,
attachment_ids: Optional[List[str]] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
Args:
question: The user's question
agent: The agent instance
retriever: The retriever instance
conversation_id: Existing conversation ID
user_api_key: User's API key if any
decoded_token: Decoded JWT token
isNoneDoc: Flag for document-less responses
index: Index of message to update
should_save_conversation: Whether to persist the conversation
attachment_ids: List of attachment IDs
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
Yields:
Server-sent event strings
"""
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
for line in agent.gen(query=question, retriever=retriever):
if "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if truncated_sources:
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
self.gpt_model,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if is_structured:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
# End of stream
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
conversation_id = ""
response_full = ""
source_log_docs = []
tool_calls = []
thought = ""
stream_ended = False
is_structured = False
schema_info = None
for line in stream:
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "id":
conversation_id = event["id"]
elif event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "structured_answer":
response_full = event["answer"]
is_structured = True
schema_info = event.get("schema")
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"]
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"

View File

@@ -1,117 +0,0 @@
import logging
import traceback
from flask import request, Response
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/stream")
class StreamResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
stream_model = answer_ns.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data, "index" in data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
agent = processor.create_agent()
retriever = processor.create_retriever()
return Response(
self.complete_stream(
question=data["question"],
agent=agent,
retriever=retriever,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
),
mimetype="text/event-stream",
)
except ValueError as e:
message = "Malformed request body"
logger.error(
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate("Unknown error occurred"),
status=400,
mimetype="text/event-stream",
)

View File

@@ -1,180 +0,0 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from bson import ObjectId
logger = logging.getLogger(__name__)
class ConversationService:
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.conversations_collection = db["conversations"]
self.agents_collection = db["agents"]
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with proper access control"""
if not conversation_id or not user_id:
return None
try:
conversation = self.conversations_collection.find_one(
{
"_id": ObjectId(conversation_id),
"$or": [{"user": user_id}, {"shared_with": user_id}],
}
)
if not conversation:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
conversation["_id"] = str(conversation["_id"])
return conversation
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
def save_conversation(
self,
conversation_id: Optional[str],
question: str,
response: str,
thought: str,
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
gpt_model: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
if conversation_id is not None and index is not None:
# Update existing conversation with new query
result = self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": sources,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
return conversation_id
elif conversation_id:
# Append new message to existing conversation
result = self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id), "user": user_id},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
return conversation_id
else:
# Create new conversation
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the user query",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(
model=gpt_model, messages=messages_summary, max_tokens=30
)
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
agent = self.agents_collection.find_one({"key": api_key})
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)

View File

@@ -1,353 +0,0 @@
import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import get_gpt_model, limit_chat_history
logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""
Get a prompt by preset name or MongoDB ID
"""
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
with open(file_path, "r") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
self.source = {}
self.all_sources = []
self.attachments = []
self.history = []
self.agent_config = {}
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.conversation_service = ConversationService()
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
)
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
self.attachments = self._get_attachments_content(
attachment_ids, self.initial_user_id
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
return None, False, None
try:
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found")
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
}
)
def _configure_retriever(self):
"""Configure the retriever based on request data"""
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
}
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_agent(self):
"""Create and return the configured agent"""
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.agent_config["user_api_key"],
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chat_history=self.history,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
)
def create_retriever(self):
"""Create and return the configured retriever"""
return RetrieverCreator.create_retriever(
self.retriever_config["retriever_name"],
source=self.source,
chat_history=self.history,
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
token_limit=self.retriever_config["token_limit"],
gpt_model=self.gpt_model,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)

View File

@@ -1,695 +0,0 @@
import base64
import datetime
import json
import uuid
from bson.objectid import ObjectId
from flask import (
Blueprint,
current_app,
jsonify,
make_response,
request
)
from flask_restx import fields, Namespace, Resource
from application.api.user.tasks import (
ingest_connector_task,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.utils import (
check_required_fields
)
from application.parser.connectors.connector_creator import ConnectorCreator
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
sessions_collection = db["connector_sessions"]
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
@connectors_ns.route("/api/connectors/upload")
class UploadConnector(Resource):
@api.expect(
api.model(
"ConnectorUploadModel",
{
"user": fields.String(required=True, description="User ID"),
"source": fields.String(
required=True, description="Source type (google_drive, github, etc.)"
),
"name": fields.String(required=True, description="Job name"),
"data": fields.String(required=True, description="Configuration data"),
"repo_url": fields.String(description="GitHub repository URL"),
},
)
)
@api.doc(
description="Uploads connector source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = json.loads(data["data"])
source_data = None
sync_frequency = config.get("sync_frequency", "never")
if data["source"] == "github":
source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]:
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
return make_response(jsonify({
"success": False,
"error": f"Missing session_token in {data['source']} configuration"
}), 400)
file_ids = config.get("file_ids", [])
if isinstance(file_ids, str):
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
elif not isinstance(file_ids, list):
file_ids = []
folder_ids = config.get("folder_ids", [])
if isinstance(folder_ids, str):
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
elif not isinstance(folder_ids, list):
folder_ids = []
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
sync_frequency=sync_frequency
)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
task = ingest_connector_task.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
sync_frequency=sync_frequency
)
except Exception as err:
current_app.logger.error(
f"Error uploading connector source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@connectors_ns.route("/api/connectors/task_status")
class ConnectorTaskStatus(Resource):
task_status_model = api.model(
"ConnectorTaskStatusModel",
{"task_id": fields.String(required=True, description="Task ID")},
)
@api.expect(task_status_model)
@api.doc(description="Get connector task status")
def get(self):
task_id = request.args.get("task_id")
if not task_id:
return make_response(
jsonify({"success": False, "message": "Task ID is required"}), 400
)
try:
from application.celery_init import celery
task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
@connectors_ns.route("/api/connectors/sources")
class ConnectorSources(Resource):
@api.doc(description="Get connector sources")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
sources = sources_collection.find({"user": user, "type": "connector:file"}).sort("date", -1)
connector_sources = []
for source in sources:
connector_sources.append({
"id": str(source["_id"]),
"name": source.get("name"),
"date": source.get("date"),
"type": source.get("type"),
"source": source.get("source"),
"tokens": source.get("tokens", ""),
"retriever": source.get("retriever", "classic"),
"syncFrequency": source.get("sync_frequency", ""),
})
except Exception as err:
current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(connector_sources), 200)
@connectors_ns.route("/api/connectors/delete")
class DeleteConnectorSource(Resource):
@api.doc(
description="Delete a connector source",
params={"source_id": "The source ID to delete"},
)
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "source_id is required"}), 400
)
try:
result = sources_collection.delete_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
except Exception as err:
current_app.logger.error(
f"Error deleting connector source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@connectors_ns.route("/api/connectors/auth")
class ConnectorAuth(Resource):
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
def get(self):
try:
provider = request.args.get('provider') or request.args.get('source')
if not provider:
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
if not ConnectorCreator.is_supported(provider):
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
state_dict = {
"provider": provider,
"object_id": str(result.inserted_id)
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
auth = ConnectorCreator.create_auth(provider)
authorization_url = auth.get_authorization_url(state=state)
return make_response(jsonify({
"success": True,
"authorization_url": authorization_url,
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/callback")
class ConnectorsCallback(Resource):
@api.doc(description="Handle OAuth callback for external connectors")
def get(self):
"""Handle OAuth callback for external connectors"""
try:
from application.parser.connectors.connector_creator import ConnectorCreator
from flask import request, redirect
authorization_code = request.args.get('code')
state = request.args.get('state')
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
try:
auth = ConnectorCreator.create_auth(provider)
token_info = auth.exchange_code_for_tokens(authorization_code)
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
@connectors_ns.route("/api/connectors/refresh")
class ConnectorRefresh(Resource):
@api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
@api.doc(description="Refresh connector access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
refresh_token = data.get('refresh_token')
if not provider or not refresh_token:
return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
auth = ConnectorCreator.create_auth(provider)
token_info = auth.refresh_access_token(refresh_token)
return make_response(jsonify({"success": True, "token_info": token_info}), 200)
except Exception as e:
current_app.logger.error(f"Error refreshing token for connector: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
files = []
for doc in documents[:limit]:
metadata = doc.extra_info
modified_time = metadata.get('modified_time')
if modified_time:
date_part = modified_time.split('T')[0]
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
formatted_time = f"{date_part} {time_part}"
else:
formatted_time = None
files.append({
'id': doc.doc_id,
'name': metadata.get('file_name', 'Unknown File'),
'type': metadata.get('mime_type', 'unknown'),
'size': metadata.get('size', None),
'modifiedTime': formatted_time,
'isFolder': metadata.get('is_folder', False)
})
next_token = getattr(loader, 'next_page_token', None)
has_more = bool(next_token)
return make_response(jsonify({
"success": True,
"files": files,
"total": len(files),
"next_page_token": next_token,
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
class ConnectorValidateSession(Resource):
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
@api.doc(description="Validate connector session token and return user info and access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session or "token_info" not in session:
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
auth = ConnectorCreator.create_auth(provider)
is_expired = auth.is_token_expired(token_info)
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
)
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
if is_expired:
return make_response(jsonify({
"success": False,
"expired": True,
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/disconnect")
class ConnectorDisconnect(Resource):
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
@api.doc(description="Disconnect a connector session")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider:
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
if session_token:
sessions_collection.delete_one({"session_token": session_token})
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/sync")
class ConnectorSync(Resource):
@api.expect(
api.model(
"ConnectorSyncModel",
{
"source_id": fields.String(required=True, description="Source ID to sync"),
"session_token": fields.String(required=True, description="Authentication token")
},
)
)
@api.doc(description="Sync connector source to check for modifications")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
data = request.get_json()
source_id = data.get('source_id')
session_token = data.get('session_token')
if not all([source_id, session_token]):
return make_response(
jsonify({
"success": False,
"error": "source_id and session_token are required"
}),
400
)
source = sources_collection.find_one({"_id": ObjectId(source_id)})
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
404
)
if source.get('user') != decoded_token.get('sub'):
return make_response(
jsonify({
"success": False,
"error": "Unauthorized access to source"
}),
403
)
remote_data = {}
try:
if source.get('remote_data'):
remote_data = json.loads(source.get('remote_data'))
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
return make_response(
jsonify({
"success": False,
"error": "Source provider not found in remote_data"
}),
400
)
# Extract configuration from remote_data
file_ids = remote_data.get('file_ids', [])
folder_ids = remote_data.get('folder_ids', [])
recursive = remote_data.get('recursive', True)
# Start the sync task
task = ingest_connector_task.delay(
job_name=source.get('name'),
user=decoded_token.get('sub'),
source_type=source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=source_id,
sync_frequency=source.get('sync_frequency', 'never')
)
return make_response(
jsonify({
"success": True,
"task_id": task.id
}),
200
)
except Exception as err:
current_app.logger.error(
f"Error syncing connector source: {err}",
exc_info=True
)
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
400
)
@connectors_ns.route("/api/connectors/callback-status")
class ConnectorCallbackStatus(Resource):
@api.doc(description="Return HTML page with connector authentication status")
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
.success {{ color: #4CAF50; }}
.error {{ color: #F44336; }}
.cancelled {{ color: #FF9800; }}
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
setTimeout(() => window.close(), 3000);
}} else if (status === "cancelled" || status === "error") {{
setTimeout(() => window.close(), 3000);
}}
}};
</script>
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})

View File

@@ -1,18 +1,14 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
@@ -38,52 +34,37 @@ def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = request.form["user"]
user = secure_filename(request.form["user"])
if "name" not in request.form:
return {"status": "no name"}
job_name = request.form["name"]
tokens = request.form["tokens"]
retriever = request.form["retriever"]
id = request.form["id"]
type = request.form["type"]
job_name = secure_filename(request.form["name"])
tokens = secure_filename(request.form["tokens"])
retriever = secure_filename(request.form["retriever"])
id = secure_filename(request.form["id"])
type = secure_filename(request.form["type"])
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
except Exception:
logger.error("Error parsing directory_structure")
directory_structure = {}
else:
directory_structure = {}
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
save_dir = os.path.join(current_dir, "indexes", str(id))
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
logger.error("No file_faiss part")
print("No file part")
return {"status": "no file"}
file_faiss = request.files["file_faiss"]
if file_faiss.filename == "":
return {"status": "no file name"}
if "file_pkl" not in request.files:
logger.error("No file_pkl part")
print("No file part")
return {"status": "no file"}
file_pkl = request.files["file_pkl"]
if file_pkl.filename == "":
return {"status": "no file name"}
# saves index files
# Save index files to storage
faiss_storage_path = f"{index_base_path}/index.faiss"
pkl_storage_path = f"{index_base_path}/index.pkl"
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_faiss.save(os.path.join(save_dir, "index.faiss"))
file_pkl.save(os.path.join(save_dir, "index.pkl"))
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
@@ -101,8 +82,6 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
)
@@ -120,8 +99,6 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
return {"status": "ok"}

View File

@@ -1,5 +0,0 @@
"""User API module - provides all user-related API endpoints"""
from .routes import user
__all__ = ["user"]

View File

@@ -1,7 +0,0 @@
"""Agents module."""
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]

View File

@@ -1,910 +0,0 @@
"""Agent management routes."""
import datetime
import json
import uuid
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
db,
ensure_user_doc,
handle_image_upload,
resolve_tool_details,
storage,
users_collection,
)
from application.utils import (
check_required_fields,
generate_image_url,
validate_required_fields,
)
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
if not (agent_id := request.args.get("id")):
return {"success": False, "message": "ID required"}, 400
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
)
if not agent:
return {"status": "Not found"}, 404
data = {
"id": str(agent["_id"]),
"name": agent["name"],
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
and (source_doc := db.dereference(agent.get("source")))
else ""
),
"sources": [
(
str(db.dereference(source_ref)["_id"])
if isinstance(source_ref, DBRef) and db.dereference(source_ref)
else source_ref
)
for source_ref in agent.get("sources", [])
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"last_used_at": agent.get("lastUsedAt", ""),
"key": (
f"{agent['key'][:4]}...{agent['key'][-4:]}"
if "key" in agent
else ""
),
"pinned": agent.get("pinned", False),
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
return {"success": False}, 400
@agents_ns.route("/get_agents")
class GetAgents(Resource):
@api.doc(description="Retrieve agents for the user")
def get(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
try:
user_doc = ensure_user_doc(user)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
agents = agents_collection.find({"user": user})
list_agents = [
{
"id": str(agent["_id"]),
"name": agent["name"],
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(source_doc["_id"])
if isinstance(agent.get("source"), DBRef)
and (source_doc := db.dereference(agent.get("source")))
else (
agent.get("source", "")
if agent.get("source") == "default"
else ""
)
),
"sources": [
(
source_ref
if source_ref == "default"
else str(db.dereference(source_ref)["_id"])
)
for source_ref in agent.get("sources", [])
if source_ref == "default"
or (
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"last_used_at": agent.get("lastUsedAt", ""),
"key": (
f"{agent['key'][:4]}...{agent['key'][-4:]}"
if "key" in agent
else ""
),
"pinned": str(agent["_id"]) in pinned_ids,
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_agents), 200)
@agents_ns.route("/create_agent")
class CreateAgent(Resource):
create_agent_model = api.model(
"CreateAgentModel",
{
"name": fields.String(required=True, description="Name of the agent"),
"description": fields.String(
required=True, description="Description of the agent"
),
"image": fields.Raw(
required=False, description="Image file upload", type="file"
),
"source": fields.String(
required=False, description="Source ID (legacy single source)"
),
"sources": fields.List(
fields.String,
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@api.expect(create_agent_model)
@api.doc(description="Create a new agent")
def post(self):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
user = decoded_token.get("sub")
if request.content_type == "application/json":
data = request.get_json()
else:
data = request.form.to_dict()
if "tools" in data:
try:
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "sources" in data:
try:
data["sources"] = json.loads(data["sources"])
except json.JSONDecodeError:
data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
return make_response(
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
400,
)
if data.get("status") not in ["draft", "published"]:
return make_response(
jsonify(
{
"success": False,
"message": "Status must be either 'draft' or 'published'",
}
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
if not data.get("source") and not data.get("sources"):
return make_response(
jsonify(
{
"success": False,
"message": "Either 'source' or 'sources' field is required for published agents",
}
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
return missing_fields
if invalid_fields:
return invalid_fields
image_url, error = handle_image_upload(request, "", user, storage)
if error:
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id, "key": key}), 201)
@agents_ns.route("/update_agent/<string:agent_id>")
class UpdateAgent(Resource):
update_agent_model = api.model(
"UpdateAgentModel",
{
"name": fields.String(required=True, description="New name of the agent"),
"description": fields.String(
required=True, description="New description of the agent"
),
"image": fields.String(
required=False, description="New image URL or identifier"
),
"source": fields.String(
required=False, description="Source ID (legacy single source)"
),
"sources": fields.List(
fields.String,
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@api.expect(update_agent_model)
@api.doc(description="Update an existing agent")
def put(self, agent_id):
if not (decoded_token := request.decoded_token):
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
if not ObjectId.is_valid(agent_id):
return make_response(
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
)
oid = ObjectId(agent_id)
try:
if request.content_type and "application/json" in request.content_type:
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
for field in json_fields:
if field in data and data[field]:
try:
data[field] = json.loads(data[field])
except json.JSONDecodeError:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid JSON format for field: {field}",
}
),
400,
)
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Invalid request data"}), 400
)
try:
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
except Exception as err:
current_app.logger.error(
f"Error finding agent {agent_id}: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Database error finding agent"}),
500,
)
if not existing_agent:
return make_response(
jsonify(
{"success": False, "message": "Agent not found or not authorized"}
),
404,
)
image_url, error = handle_image_upload(
request, existing_agent.get("image", ""), user, storage
)
if error:
current_app.logger.error(
f"Image upload error for agent {agent_id}: {error}"
)
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
update_fields = {}
allowed_fields = [
"name",
"description",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"agent_type",
"status",
"json_schema",
]
for field in allowed_fields:
if field not in data:
continue
if field == "status":
new_status = data.get("status")
if new_status not in ["draft", "published"]:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid status value. Must be 'draft' or 'published'",
}
),
400,
)
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
if source_id == "default":
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID format: {source_id}",
}
),
400,
)
else:
update_fields[field] = ""
elif field == "sources":
sources_list = data.get("sources", [])
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default":
valid_sources.append("default")
elif ObjectId.is_valid(source_id):
valid_sources.append(DBRef("sources", ObjectId(source_id)))
else:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID in list: {source_id}",
}
),
400,
)
update_fields[field] = valid_sources
else:
update_fields[field] = []
elif field == "chunks":
chunks_value = data.get("chunks")
if chunks_value == "" or chunks_value is None:
update_fields[field] = "2"
else:
try:
chunks_int = int(chunks_value)
if chunks_int < 0:
return make_response(
jsonify(
{
"success": False,
"message": "Chunks value must be a non-negative integer",
}
),
400,
)
update_fields[field] = str(chunks_int)
except (ValueError, TypeError):
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid chunks value: {chunks_value}",
}
),
400,
)
elif field == "tools":
tools_list = data.get("tools", [])
if isinstance(tools_list, list):
update_fields[field] = tools_list
else:
return make_response(
jsonify(
{
"success": False,
"message": "Tools must be a list",
}
),
400,
)
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
if not value or not str(value).strip():
return make_response(
jsonify(
{
"success": False,
"message": f"Field '{field}' cannot be empty",
}
),
400,
)
update_fields[field] = value
if image_url:
update_fields["image"] = image_url
if not update_fields:
return make_response(
jsonify(
{
"success": False,
"message": "No valid update data provided",
}
),
400,
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
if final_status == "published":
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
try:
result = agents_collection.update_one(
{"_id": oid, "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"success": False,
"message": "Agent not found or update failed",
}
),
404,
)
if result.modified_count == 0 and result.matched_count == 1:
return make_response(
jsonify(
{
"success": True,
"message": "No changes detected",
"id": agent_id,
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error updating agent {agent_id}: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Database error during update"}),
500,
)
response_data = {
"success": True,
"id": agent_id,
"message": "Agent updated successfully",
}
if newly_generated_key:
response_data["key"] = newly_generated_key
return make_response(jsonify(response_data), 200)
@agents_ns.route("/delete_agent")
class DeleteAgent(Resource):
@api.doc(params={"id": "ID of the agent"}, description="Delete an agent by ID")
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
deleted_agent = agents_collection.find_one_and_delete(
{"_id": ObjectId(agent_id), "user": user}
)
if not deleted_agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": deleted_id}), 200)
@agents_ns.route("/pinned_agents")
class PinnedAgents(Resource):
@api.doc(description="Get pinned agents for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
user_doc = ensure_user_doc(user_id)
pinned_ids = user_doc.get("agent_preferences", {}).get("pinned", [])
if not pinned_ids:
return make_response(jsonify([]), 200)
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
pinned_agents_cursor = agents_collection.find(
{"_id": {"$in": pinned_object_ids}}
)
pinned_agents = list(pinned_agents_cursor)
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
# Clean up any stale pinned IDs
stale_ids = [
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
]
if stale_ids:
users_collection.update_one(
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
list_pinned_agents = [
{
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"source": (
str(db.dereference(agent["source"])["_id"])
if "source" in agent
and agent["source"]
and isinstance(agent["source"], DBRef)
and db.dereference(agent["source"]) is not None
else ""
),
"chunks": agent.get("chunks", ""),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"last_used_at": agent.get("lastUsedAt", ""),
"key": (
f"{agent['key'][:4]}...{agent['key'][-4:]}"
if "key" in agent
else ""
),
"pinned": True,
}
for agent in pinned_agents
if "source" in agent or "retriever" in agent
]
except Exception as err:
current_app.logger.error(f"Error retrieving pinned agents: {err}")
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_pinned_agents), 200)
@agents_ns.route("/pin_agent")
class PinAgent(Resource):
@api.doc(params={"id": "ID of the agent"}, description="Pin or unpin an agent")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
user_doc = ensure_user_doc(user_id)
pinned_list = user_doc.get("agent_preferences", {}).get("pinned", [])
if agent_id in pinned_list:
users_collection.update_one(
{"user_id": user_id},
{"$pull": {"agent_preferences.pinned": agent_id}},
)
action = "unpinned"
else:
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.pinned": agent_id}},
)
action = "pinned"
except Exception as err:
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
return make_response(
jsonify({"success": False, "message": "Server error"}), 500
)
return make_response(jsonify({"success": True, "action": action}), 200)
@agents_ns.route("/remove_shared_agent")
class RemoveSharedAgent(Resource):
@api.doc(
params={"id": "ID of the shared agent"},
description="Remove a shared agent from the current user's shared list",
)
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "shared_publicly": True}
)
if not agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
{
"$pull": {
"agent_preferences.shared_with_me": agent_id,
"agent_preferences.pinned": agent_id,
}
},
)
return make_response(jsonify({"success": True, "action": "removed"}), 200)
except Exception as err:
current_app.logger.error(f"Error removing shared agent: {err}")
return make_response(
jsonify({"success": False, "message": "Server error"}), 500
)

View File

@@ -1,254 +0,0 @@
"""Agent management sharing functionality."""
import datetime
import secrets
from bson import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
db,
ensure_user_doc,
resolve_tool_details,
user_tools_collection,
users_collection,
)
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
@api.doc(
params={
"token": "Shared token of the agent",
},
description="Get a shared agent by token or ID",
)
def get(self):
shared_token = request.args.get("token")
if not shared_token:
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
query = {
"shared_publicly": True,
"shared_token": shared_token,
}
shared_agent = agents_collection.find_one(query)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"image": (
generate_image_url(shared_agent["image"])
if shared_agent.get("image")
else ""
),
"description": shared_agent.get("description", ""),
"source": (
str(source_doc["_id"])
if isinstance(shared_agent.get("source"), DBRef)
and (source_doc := db.dereference(shared_agent.get("source")))
else ""
),
"chunks": shared_agent.get("chunks", "0"),
"retriever": shared_agent.get("retriever", "classic"),
"prompt_id": shared_agent.get("prompt_id", "default"),
"tools": shared_agent.get("tools", []),
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"json_schema": shared_agent.get("json_schema"),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
if data["tools"]:
enriched_tools = []
for tool in data["tools"]:
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
owner_id = shared_agent.get("user")
if user_id != owner_id:
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/shared_agents")
class SharedAgents(Resource):
@api.doc(description="Get shared agents explicitly shared with the user")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
user_doc = ensure_user_doc(user_id)
shared_with_ids = user_doc.get("agent_preferences", {}).get(
"shared_with_me", []
)
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
shared_agents_cursor = agents_collection.find(
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
)
shared_agents = list(shared_agents_cursor)
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
if stale_ids:
users_collection.update_one(
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [
{
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"pinned": str(agent["_id"]) in pinned_ids,
"shared": agent.get("shared_publicly", False),
"shared_token": agent.get("shared_token", ""),
"shared_metadata": agent.get("shared_metadata", {}),
}
for agent in shared_agents
]
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/share_agent")
class ShareAgent(Resource):
@api.expect(
api.model(
"ShareAgentModel",
{
"id": fields.String(required=True, description="ID of the agent"),
"shared": fields.Boolean(
required=True, description="Share or unshare the agent"
),
"username": fields.String(
required=False, description="Name of the user"
),
},
)
)
@api.doc(description="Share or unshare an agent")
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(
jsonify({"success": False, "message": "Missing JSON body"}), 400
)
agent_id = data.get("id")
shared = data.get("shared")
username = data.get("username", "")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
if shared is None:
return make_response(
jsonify(
{
"success": False,
"message": "Shared parameter is required and must be true or false",
}
),
400,
)
try:
try:
agent_oid = ObjectId(agent_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid agent ID"}), 400
)
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(datetime.timezone.utc),
}
shared_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{
"$set": {
"shared_publicly": shared,
"shared_metadata": shared_metadata,
"shared_token": shared_token,
}
},
)
else:
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{"$set": {"shared_publicly": shared, "shared_token": None}},
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)

View File

@@ -1,119 +0,0 @@
"""Agent management webhook handlers."""
import secrets
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
agents_webhooks_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
@api.doc(
params={"id": "ID of the agent"},
description="Generate webhook URL for the agent",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": user}
)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
webhook_token = agent.get("incoming_webhook_token")
if not webhook_token:
webhook_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": ObjectId(agent_id), "user": user},
{"$set": {"incoming_webhook_token": webhook_token}},
)
base_url = settings.API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
except Exception as err:
current_app.logger.error(
f"Error generating webhook URL: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Error generating webhook URL"}),
400,
)
return make_response(
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
)
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
class AgentWebhookListener(Resource):
method_decorators = [require_agent]
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
if not payload:
current_app.logger.warning(
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
)
current_app.logger.info(
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
)
try:
task = process_agent_webhook.delay(
agent_id=agent_id_str,
payload=payload,
)
current_app.logger.info(
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
except Exception as err:
current_app.logger.error(
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "message": "Error processing webhook"}), 500
)
@api.doc(
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
)
def post(self, webhook_token, agent, agent_id_str):
payload = request.get_json()
if payload is None:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid or missing JSON data in request body",
}
),
400,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
@api.doc(
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
)
def get(self, webhook_token, agent, agent_id_str):
payload = request.args.to_dict(flat=True)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")

View File

@@ -1,5 +0,0 @@
"""Analytics module."""
from .routes import analytics_ns
__all__ = ["analytics_ns"]

View File

@@ -1,540 +0,0 @@
"""Analytics and reporting routes."""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
conversations_collection,
generate_date_range,
generate_hourly_range,
generate_minute_range,
token_usage_collection,
user_logs_collection,
)
analytics_ns = Namespace(
"analytics", description="Analytics and reporting operations", path="/api"
)
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
"GetMessageAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_message_analytics_model)
@api.doc(description="Get message analytics based on filter option")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
try:
match_stage = {
"$match": {
"user": user,
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{
"$match": {
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
{
"$group": {
"_id": {
"$dateToString": {
"format": group_format,
"date": "$queries.timestamp",
}
},
"count": {"$sum": 1},
}
},
{"$sort": {"_id": 1}},
]
message_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_messages = {interval: 0 for interval in intervals}
for entry in message_data:
daily_messages[entry["_id"]] = entry["count"]
except Exception as err:
current_app.logger.error(
f"Error getting message analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "messages": daily_messages}), 200
)
@analytics_ns.route("/get_token_analytics")
class GetTokenAnalytics(Resource):
get_token_analytics_model = api.model(
"GetTokenAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_token_analytics_model)
@api.doc(description="Get token analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
group_stage = {
"$group": {
"_id": {
"minute": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
group_stage = {
"$group": {
"_id": {
"hour": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
group_stage = {
"$group": {
"_id": {
"day": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
try:
match_stage = {
"$match": {
"user_id": user,
"timestamp": {"$gte": start_date, "$lte": end_date},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
token_usage_data = token_usage_collection.aggregate(
[
match_stage,
group_stage,
{"$sort": {"_id": 1}},
]
)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in token_usage_data:
if filter_option == "last_hour":
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
elif filter_option == "last_24_hour":
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
else:
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
except Exception as err:
current_app.logger.error(
f"Error getting token analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "token_usage": daily_token_usage}), 200
)
@analytics_ns.route("/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
get_feedback_analytics_model = api.model(
"GetFeedbackAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_feedback_analytics_model)
@api.doc(description="Get feedback analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
try:
match_stage = {
"$match": {
"queries.feedback_timestamp": {
"$gte": start_date,
"$lte": end_date,
},
"queries.feedback": {"$exists": True},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{"$match": {"queries.feedback": {"$exists": True}}},
{
"$group": {
"_id": {"time": date_field, "feedback": "$queries.feedback"},
"count": {"$sum": 1},
}
},
{
"$group": {
"_id": "$_id.time",
"positive": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "LIKE"]},
"$count",
0,
]
}
},
"negative": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "DISLIKE"]},
"$count",
0,
]
}
},
}
},
{"$sort": {"_id": 1}},
]
feedback_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_feedback = {
interval: {"positive": 0, "negative": 0} for interval in intervals
}
for entry in feedback_data:
daily_feedback[entry["_id"]] = {
"positive": entry["positive"],
"negative": entry["negative"],
}
except Exception as err:
current_app.logger.error(
f"Error getting feedback analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "feedback": daily_feedback}), 200
)
@analytics_ns.route("/get_user_logs")
class GetUserLogs(Resource):
get_user_logs_model = api.model(
"GetUserLogsModel",
{
"page": fields.Integer(
required=False,
description="Page number for pagination",
default=1,
),
"api_key_id": fields.String(required=False, description="API Key ID"),
"page_size": fields.Integer(
required=False,
description="Number of logs per page",
default=10,
),
},
)
@api.expect(get_user_logs_model)
@api.doc(description="Get user logs with pagination")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
page_size = int(data.get("page_size", 10))
skip = (page - 1) * page_size
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
query = {"user": user}
if api_key:
query = {"api_key": api_key}
items_cursor = (
user_logs_collection.find(query)
.sort("timestamp", -1)
.skip(skip)
.limit(page_size + 1)
)
items = list(items_cursor)
results = [
{
"id": str(item.get("_id")),
"action": item.get("action"),
"level": item.get("level"),
"user": item.get("user"),
"question": item.get("question"),
"sources": item.get("sources"),
"retriever_params": item.get("retriever_params"),
"timestamp": item.get("timestamp"),
}
for item in items[:page_size]
]
has_more = len(items) > page_size
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more,
}
),
200,
)

View File

@@ -1,5 +0,0 @@
"""Attachments module."""
from .routes import attachments_ns
__all__ = ["attachments_ns"]

View File

@@ -1,150 +0,0 @@
"""File attachments and media routes."""
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.core.settings import settings
from application.tts.google_tts import GoogleTTS
from application.utils import safe_filename
attachments_ns = Namespace(
"attachments", description="File attachments and media operations", path="/api"
)
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
api.model(
"AttachmentModel",
{
"file": fields.Raw(required=True, description="File to upload"),
"api_key": fields.String(
required=False, description="API key (optional)"
),
},
)
)
@api.doc(
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
)
def post(self):
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"status": "error", "message": "Missing file"}),
400,
)
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
file_info = {
"filename": original_filename,
"attachment_id": str(attachment_id),
"path": relative_path,
"metadata": metadata,
}
task = store_attachment.delay(file_info, user)
return make_response(
jsonify(
{
"success": True,
"task_id": task.id,
"message": "File uploaded successfully. Processing started.",
}
),
200,
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
try:
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
return response
except FileNotFoundError:
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(
jsonify({"success": False, "message": "Error retrieving image"}), 500
)
@attachments_ns.route("/tts")
class TextToSpeech(Resource):
tts_model = api.model(
"TextToSpeechModel",
{
"text": fields.String(
required=True, description="Text to be synthesized as audio"
),
},
)
@api.expect(tts_model)
@api.doc(description="Synthesize audio speech from text")
def post(self):
data = request.get_json()
text = data["text"]
try:
tts_instance = GoogleTTS()
audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response(
jsonify(
{
"success": True,
"audio_base64": audio_base64,
"lang": detected_language,
}
),
200,
)
except Exception as err:
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -1,222 +0,0 @@
"""
Shared utilities, database connections, and helper functions for user API routes.
"""
import datetime
import os
import uuid
from functools import wraps
from typing import Optional, Tuple
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, Response
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
try:
agents_collection.create_index(
[("shared", 1)],
name="shared_index",
background=True,
)
users_collection.create_index("user_id", unique=True)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def generate_minute_range(start_date, end_date):
"""Generate a dictionary with minute-level time ranges."""
return {
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
}
def generate_hourly_range(start_date, end_date):
"""Generate a dictionary with hourly time ranges."""
return {
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
}
def generate_date_range(start_date, end_date):
"""Generate a dictionary with daily date ranges."""
return {
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
for i in range((end_date - start_date).days + 1)
}
def ensure_user_doc(user_id):
"""
Ensure user document exists with proper agent preferences structure.
Args:
user_id: The user ID to ensure
Returns:
The user document
"""
default_prefs = {
"pinned": [],
"shared_with_me": [],
}
user_doc = users_collection.find_one_and_update(
{"user_id": user_id},
{"$setOnInsert": {"agent_preferences": default_prefs}},
upsert=True,
return_document=ReturnDocument.AFTER,
)
prefs = user_doc.get("agent_preferences", {})
updates = {}
if "pinned" not in prefs:
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
return user_doc
def resolve_tool_details(tool_ids):
"""
Resolve tool IDs to their details.
Args:
tool_ids: List of tool IDs
Returns:
List of tool details with id, name, and display_name
"""
tools = user_tools_collection.find(
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
}
for tool in tools
]
def get_vector_store(source_id):
"""
Get the Vector Store for a given source ID.
Args:
source_id (str): source id of the document
Returns:
Vector store instance
"""
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
return store
def handle_image_upload(
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
) -> Tuple[str, Optional[Response]]:
"""
Handle image file upload from request.
Args:
request: Flask request object
existing_url: Existing image URL (fallback)
user: User ID
storage: Storage instance
base_path: Base path for upload
Returns:
Tuple of (image_url, error_response)
"""
image_url = existing_url
if "image" in request.files:
file = request.files["image"]
if file.filename != "":
filename = secure_filename(file.filename)
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
try:
storage.save_file(file, upload_path, storage_class="STANDARD")
image_url = upload_path
except Exception as e:
current_app.logger.error(f"Error uploading image: {e}")
return None, make_response(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
return image_url, None
def require_agent(func):
"""
Decorator to require valid agent webhook token.
Args:
func: Function to decorate
Returns:
Wrapped function
"""
@wraps(func)
def wrapper(*args, **kwargs):
webhook_token = kwargs.get("webhook_token")
if not webhook_token:
return make_response(
jsonify({"success": False, "message": "Webhook token missing"}), 400
)
agent = agents_collection.find_one(
{"incoming_webhook_token": webhook_token}, {"_id": 1}
)
if not agent:
current_app.logger.warning(
f"Webhook attempt with invalid token: {webhook_token}"
)
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
kwargs["agent"] = agent
kwargs["agent_id_str"] = str(agent["_id"])
return func(*args, **kwargs)
return wrapper

View File

@@ -1,5 +0,0 @@
"""Conversation management module."""
from .routes import conversations_ns
__all__ = ["conversations_ns"]

View File

@@ -1,280 +0,0 @@
"""Conversation management routes."""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import attachments_collection, conversations_collection
from application.utils import check_required_fields
conversations_ns = Namespace(
"conversations", description="Conversation management operations", path="/api"
)
@conversations_ns.route("/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
description="Deletes a conversation by ID",
params={"id": "The ID of the conversation to delete"},
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
conversations_collection.delete_one(
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/delete_all_conversations")
class DeleteAllConversations(Resource):
@api.doc(
description="Deletes all conversations for a specific user",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
conversations_collection.delete_many({"user": user_id})
except Exception as err:
current_app.logger.error(
f"Error deleting all conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/get_conversations")
class GetConversations(Resource):
@api.doc(
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
conversations = (
conversations_collection.find(
{
"$or": [
{"api_key": {"$exists": False}},
{"agent_id": {"$exists": True}},
],
"user": decoded_token.get("sub"),
}
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{
"id": str(conversation["_id"]),
"name": conversation["name"],
"agent_id": conversation.get("agent_id", None),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
for conversation in conversations
]
except Exception as err:
current_app.logger.error(
f"Error retrieving conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_conversations), 200)
@conversations_ns.route("/get_single_conversation")
class GetSingleConversation(Resource):
@api.doc(
description="Retrieve a single conversation by ID",
params={"id": "The conversation ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
except Exception as err:
current_app.logger.error(
f"Error retrieving conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
data = {
"queries": queries,
"agent_id": conversation.get("agent_id"),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
return make_response(jsonify(data), 200)
@conversations_ns.route("/update_conversation_name")
class UpdateConversationName(Resource):
@api.expect(
api.model(
"UpdateConversationModel",
{
"id": fields.String(required=True, description="Conversation ID"),
"name": fields.String(
required=True, description="New name of the conversation"
),
},
)
)
@api.doc(
description="Updates the name of a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["id", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
conversations_collection.update_one(
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
{"$set": {"name": data["name"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating conversation name: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/feedback")
class SubmitFeedback(Resource):
@api.expect(
api.model(
"FeedbackModel",
{
"question": fields.String(
required=False, description="The user question"
),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index": fields.Integer(
required=True,
description="The question number in that particular conversation",
),
"conversation_id": fields.String(
required=True, description="id of the particular conversation"
),
"api_key": fields.String(description="Optional API key"),
},
)
)
@api.doc(
description="Submit feedback for a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["feedback", "conversation_id", "question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
if data["feedback"] is None:
# Remove feedback and feedback_timestamp if feedback is null
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
else:
# Set feedback and feedback_timestamp if feedback has a value
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$set": {
f"queries.{data['question_index']}.feedback": data[
"feedback"
],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
datetime.timezone.utc
),
}
},
)
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

View File

@@ -1,5 +0,0 @@
"""Prompts module."""
from .routes import prompts_ns
__all__ = ["prompts_ns"]

View File

@@ -1,191 +0,0 @@
"""Prompt management routes."""
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir, prompts_collection
from application.utils import check_required_fields
prompts_ns = Namespace(
"prompts", description="Prompt management operations", path="/api"
)
@prompts_ns.route("/create_prompt")
class CreatePrompt(Resource):
create_prompt_model = api.model(
"CreatePromptModel",
{
"content": fields.String(
required=True, description="Content of the prompt"
),
"name": fields.String(required=True, description="Name of the prompt"),
},
)
@api.expect(create_prompt_model)
@api.doc(description="Create a new prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["content", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user = decoded_token.get("sub")
try:
resp = prompts_collection.insert_one(
{
"name": data["name"],
"content": data["content"],
"user": user,
}
)
new_id = str(resp.inserted_id)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id}), 200)
@prompts_ns.route("/get_prompts")
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
prompts = prompts_collection.find({"user": user})
list_prompts = [
{"id": "default", "name": "default", "type": "public"},
{"id": "creative", "name": "creative", "type": "public"},
{"id": "strict", "name": "strict", "type": "public"},
]
for prompt in prompts:
list_prompts.append(
{
"id": str(prompt["_id"]),
"name": prompt["name"],
"type": "private",
}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_prompts), 200)
@prompts_ns.route("/get_single_prompt")
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
prompt_id = request.args.get("id")
if not prompt_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
if prompt_id == "default":
with open(
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
"r",
) as f:
chat_combine_template = f.read()
return make_response(jsonify({"content": chat_combine_template}), 200)
elif prompt_id == "creative":
with open(
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
"r",
) as f:
chat_reduce_creative = f.read()
return make_response(jsonify({"content": chat_reduce_creative}), 200)
elif prompt_id == "strict":
with open(
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
) as f:
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
prompt = prompts_collection.find_one(
{"_id": ObjectId(prompt_id), "user": user}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"content": prompt["content"]}), 200)
@prompts_ns.route("/delete_prompt")
class DeletePrompt(Resource):
delete_prompt_model = api.model(
"DeletePromptModel",
{"id": fields.String(required=True, description="Prompt ID to delete")},
)
@api.expect(delete_prompt_model)
@api.doc(description="Delete a prompt by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@prompts_ns.route("/update_prompt")
class UpdatePrompt(Resource):
update_prompt_model = api.model(
"UpdatePromptModel",
{
"id": fields.String(required=True, description="Prompt ID to update"),
"name": fields.String(required=True, description="New name of the prompt"),
"content": fields.String(
required=True, description="New content of the prompt"
),
},
)
@api.expect(update_prompt_model)
@api.doc(description="Update an existing prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "name", "content"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
prompts_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +0,0 @@
"""Sharing module."""
from .routes import sharing_ns
__all__ = ["sharing_ns"]

View File

@@ -1,301 +0,0 @@
"""Conversation sharing routes."""
import uuid
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, inputs, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
attachments_collection,
conversations_collection,
db,
shared_conversations_collections,
)
from application.utils import check_required_fields
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)
@sharing_ns.route("/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
"ShareConversationModel",
{
"conversation_id": fields.String(
required=True, description="Conversation ID"
),
"user": fields.String(description="User ID (optional)"),
"prompt_id": fields.String(description="Prompt ID (optional)"),
"chunks": fields.Integer(description="Chunks count (optional)"),
},
)
@api.expect(share_conversation_model)
@api.doc(description="Share a conversation")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
if is_promptable is None:
return make_response(
jsonify({"success": False, "message": "isPromptable is required"}), 400
)
conversation_id = data["conversation_id"]
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
)
if conversation is None:
return make_response(
jsonify(
{
"status": "error",
"message": "Conversation does not exist",
}
),
404,
)
current_n_queries = len(conversation["queries"])
explicit_binary = Binary.from_uuid(
uuid.uuid4(), UuidRepresentation.STANDARD
)
if is_promptable:
prompt_id = data.get("prompt_id", "default")
chunks = data.get("chunks", "2")
name = conversation["name"] + "(shared)"
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
if pre_existing_api_document:
api_uuid = pre_existing_api_document["key"]
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": DBRef(
"conversations", ObjectId(conversation_id)
),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": {
"$ref": "conversations",
"$id": ObjectId(conversation_id),
},
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
else:
api_uuid = str(uuid.uuid4())
new_api_key_data["key"] = api_uuid
new_api_key_data["name"] = name
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
agents_collection.insert_one(new_api_key_data)
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": {
"$ref": "conversations",
"$id": ObjectId(conversation_id),
},
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": DBRef(
"conversations", ObjectId(conversation_id)
),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": {
"$ref": "conversations",
"$id": ObjectId(conversation_id),
},
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
return make_response(
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
except Exception as err:
current_app.logger.error(
f"Error sharing conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
@sharing_ns.route("/shared_conversation/<string:identifier>")
class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
try:
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
if (
shared
and "conversation_id" in shared
and isinstance(shared["conversation_id"], DBRef)
):
conversation_ref = shared["conversation_id"]
conversation = db.dereference(conversation_ref)
if conversation is None:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
for query in conversation_queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
else:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
date = conversation["_id"].generation_time.isoformat()
res = {
"success": True,
"queries": conversation_queries,
"title": conversation["name"],
"timestamp": date,
}
if shared["isPromptable"] and "api_key" in shared:
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
except Exception as err:
current_app.logger.error(
f"Error getting shared conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)

View File

@@ -1,7 +0,0 @@
"""Sources module."""
from .chunks import sources_chunks_ns
from .routes import sources_ns
from .upload import sources_upload_ns
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]

View File

@@ -1,278 +0,0 @@
"""Source document management chunk management."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import get_vector_store, sources_collection
from application.utils import check_required_fields, num_tokens_from_string
sources_chunks_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_chunks_ns.route("/get_chunks")
class GetChunks(Resource):
@api.doc(
description="Retrieves chunks from a document, optionally filtered by file path and search term",
params={
"id": "The document ID",
"page": "Page number for pagination",
"per_page": "Number of chunks per page",
"path": "Optional: Filter chunks by relative file path",
"search": "Optional: Search term to filter chunks by title or content",
},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
page = int(request.args.get("page", 1))
per_page = int(request.args.get("per_page", 10))
path = request.args.get("path")
search_term = request.args.get("search", "").strip().lower()
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
# Filter by path if provided
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
if not chunk_source or not chunk_source.endswith(path):
continue
# Filter by search term if provided
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
if not (text_match or title_match):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
start = (page - 1) * per_page
end = start + per_page
paginated_chunks = chunks[start:end]
return make_response(
jsonify(
{
"page": page,
"per_page": per_page,
"total": total_chunks,
"chunks": paginated_chunks,
"path": path if path else None,
"search": search_term if search_term else None,
}
),
200,
)
except Exception as e:
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/add_chunk")
class AddChunk(Resource):
@api.expect(
api.model(
"AddChunkModel",
{
"id": fields.String(required=True, description="Document ID"),
"text": fields.String(required=True, description="Text of the chunk"),
"metadata": fields.Raw(
required=False,
description="Metadata associated with the chunk",
),
},
)
)
@api.doc(
description="Adds a new chunk to the document",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "text"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
doc_id = data.get("id")
text = data.get("text")
metadata = data.get("metadata", {})
token_count = num_tokens_from_string(text)
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunk_id = store.add_chunk(text, metadata)
return make_response(
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
201,
)
except Exception as e:
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/delete_chunk")
class DeleteChunk(Resource):
@api.doc(
description="Deletes a specific chunk from the document.",
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
)
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
chunk_id = request.args.get("chunk_id")
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
deleted = store.delete_chunk(chunk_id)
if deleted:
return make_response(
jsonify({"message": "Chunk deleted successfully"}), 200
)
else:
return make_response(
jsonify({"message": "Chunk not found or could not be deleted"}),
404,
)
except Exception as e:
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/update_chunk")
class UpdateChunk(Resource):
@api.expect(
api.model(
"UpdateChunkModel",
{
"id": fields.String(required=True, description="Document ID"),
"chunk_id": fields.String(
required=True, description="Chunk ID to update"
),
"text": fields.String(
required=False, description="New text of the chunk"
),
"metadata": fields.Raw(
required=False,
description="Updated metadata associated with the chunk",
),
},
)
)
@api.doc(
description="Updates an existing chunk in the document.",
)
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "chunk_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
doc_id = data.get("id")
chunk_id = data.get("chunk_id")
text = data.get("text")
metadata = data.get("metadata")
if text is not None:
token_count = num_tokens_from_string(text)
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
if not existing_chunk:
return make_response(jsonify({"error": "Chunk not found"}), 404)
new_text = text if text is not None else existing_chunk["text"]
if metadata is not None:
new_metadata = existing_chunk["metadata"].copy()
new_metadata.update(metadata)
else:
new_metadata = existing_chunk["metadata"].copy()
if text is not None:
new_metadata["token_count"] = num_tokens_from_string(new_text)
try:
new_chunk_id = store.add_chunk(new_text, new_metadata)
deleted = store.delete_chunk(chunk_id)
if not deleted:
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
{
"message": "Chunk updated successfully",
"chunk_id": new_chunk_id,
"original_chunk_id": chunk_id,
}
),
200,
)
except Exception as add_error:
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
return make_response(
jsonify({"error": "Failed to update chunk - addition failed"}), 500
)
except Exception as e:
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)

View File

@@ -1,350 +0,0 @@
"""Source document management routes."""
import json
import math
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from werkzeug.utils import secure_filename
from application.api import api
from application.api.user.base import sources_collection
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
sources_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = [
{
"name": "Default",
"date": "default",
"model": settings.EMBEDDINGS_NAME,
"location": "remote",
"tokens": "",
"retriever": "classic",
}
]
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
data.append(
{
"id": str(index["_id"]),
"name": index.get("name"),
"date": index.get("date"),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
), # Add type field with default "file"
}
)
except Exception as err:
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(data), 200)
@sources_ns.route("/sources/paginated")
class PaginatedSources(Resource):
@api.doc(description="Get document with pagination, sorting and filtering")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# add .strip() to remove leading and trailing whitespaces
search_term = request.args.get(
"search", ""
).strip() # add search for filter documents
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i", # using case-insensitive search
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(
max(1, page), total_pages
) # add this to make sure page inbound is within the range
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
try:
documents = (
sources_collection.find(query)
.sort(sort_field, sort_order)
.skip(skip)
.limit(rows_per_page)
)
paginated_docs = []
for doc in documents:
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
response = {
"total": total_documents,
"totalPages": total_pages,
"currentPage": page,
"paginated": paginated_docs,
}
return make_response(jsonify(response), 200)
except Exception as err:
current_app.logger.error(
f"Error retrieving paginated sources: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/docs_check")
class CheckDocs(Resource):
check_docs_model = api.model(
"CheckDocsModel",
{"docs": fields.String(required=True, description="Document name")},
)
@api.expect(check_docs_model)
@api.doc(description="Check if document exists")
def post(self):
data = request.get_json()
required_fields = ["docs"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
vectorstore = "vectors/" + secure_filename(data["docs"])
if os.path.exists(vectorstore) or data["docs"] == "default":
return {"status": "exists"}, 200
except Exception as err:
current_app.logger.error(f"Error checking document: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"status": "not found"}), 404)
@sources_ns.route("/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
params={"path": "Comma-separated list of IDs"},
)
def get(self):
ids = request.args.get("path")
if not ids:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
result = sources_collection.delete_index(ids=ids)
if result:
return make_response(jsonify({"success": True}), 200)
except Exception as err:
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
description="Deletes old indexes and associated files",
params={"source_id": "The source ID to delete"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{str(doc['_id'])}"
if storage.file_exists(f"{index_path}/index.faiss"):
storage.delete_file(f"{index_path}/index.faiss")
if storage.file_exists(f"{index_path}/index.pkl"):
storage.delete_file(f"{index_path}/index.pkl")
else:
vectorstore = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
files = storage.list_files(file_path)
for f in files:
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
current_app.logger.error(
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/combine")
class RedirectToSources(Resource):
@api.doc(
description="Redirects /api/combine to /api/sources for backward compatibility"
)
def get(self):
return redirect("/api/sources", code=301)
@sources_ns.route("/manage_sync")
class ManageSync(Resource):
manage_sync_model = api.model(
"ManageSyncModel",
{
"source_id": fields.String(required=True, description="Source ID"),
"sync_frequency": fields.String(
required=True,
description="Sync frequency (never, daily, weekly, monthly)",
),
},
)
@api.expect(manage_sync_model)
@api.doc(description="Manage sync frequency for sources")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
sync_frequency = data["sync_frequency"]
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
return make_response(
jsonify({"success": False, "message": "Invalid frequency"}), 400
)
update_data = {"$set": {"sync_frequency": sync_frequency}}
try:
sources_collection.update_one(
{
"_id": ObjectId(source_id),
"user": user,
},
update_data,
)
except Exception as err:
current_app.logger.error(
f"Error updating sync frequency: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
description="Get the directory structure for a document",
params={"id": "The document ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "")
provider = None
remote_data = doc.get("remote_data")
try:
if isinstance(remote_data, str) and remote_data:
remote_data_obj = json.loads(remote_data)
provider = remote_data_obj.get("provider")
except Exception as e:
current_app.logger.warning(
f"Failed to parse remote_data for doc {doc_id}: {e}"
)
return make_response(
jsonify(
{
"success": True,
"directory_structure": directory_structure,
"base_path": base_path,
"provider": provider,
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)

View File

@@ -1,572 +0,0 @@
"""Source document management upload functionality."""
import json
import os
import tempfile
import zipfile
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields, safe_filename
sources_upload_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
api.model(
"UploadModel",
{
"user": fields.String(required=True, description="User ID"),
"name": fields.String(required=True, description="Job name"),
"file": fields.Raw(required=True, description="File(s) to upload"),
},
)
)
@api.doc(
description="Uploads a file to be vectorized and indexed",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
files = request.files.getlist("file")
required_fields = ["user", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields or not files or all(file.filename == "" for file in files):
return make_response(
jsonify(
{
"status": "error",
"message": "Missing required fields or files",
}
),
400,
)
user = decoded_token.get("sub")
job_name = request.form["name"]
# Create safe versions for filesystem operations
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
safe_file = safe_filename(original_filename)
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
if zipfile.is_zipfile(temp_file_path):
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if (
os.path.join(root, extracted_file)
== temp_file_path
):
continue
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
with open(
os.path.join(root, extracted_file), "rb"
) as f:
storage.save_file(f, storage_path)
except Exception as e:
current_app.logger.error(
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/remote")
class UploadRemote(Resource):
@api.expect(
api.model(
"RemoteUploadModel",
{
"user": fields.String(required=True, description="User ID"),
"source": fields.String(
required=True, description="Source of the data"
),
"name": fields.String(required=True, description="Job name"),
"data": fields.String(required=True, description="Data to process"),
"repo_url": fields.String(description="GitHub repository URL"),
},
)
)
@api.doc(
description="Uploads remote source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = json.loads(data["data"])
source_data = None
if data["source"] == "github":
source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]:
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
return make_response(
jsonify(
{
"success": False,
"error": f"Missing session_token in {data['source']} configuration",
}
),
400,
)
# Process file_ids
file_ids = config.get("file_ids", [])
if isinstance(file_ids, str):
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
elif not isinstance(file_ids, list):
file_ids = []
# Process folder_ids
folder_ids = config.get("folder_ids", [])
if isinstance(folder_ids, str):
folder_ids = [
id.strip() for id in folder_ids.split(",") if id.strip()
]
elif not isinstance(folder_ids, list):
folder_ids = []
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
)
return make_response(
jsonify({"success": True, "task_id": task.id}), 200
)
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
)
except Exception as err:
current_app.logger.error(
f"Error uploading remote source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/manage_source_files")
class ManageSourceFiles(Resource):
@api.expect(
api.model(
"ManageSourceFilesModel",
{
"source_id": fields.String(
required=True, description="Source ID to modify"
),
"operation": fields.String(
required=True,
description="Operation: 'add', 'remove', or 'remove_directory'",
),
"file_paths": fields.List(
fields.String,
required=False,
description="File paths to remove (for remove operation)",
),
"directory_path": fields.String(
required=False,
description="Directory path to remove (for remove_directory operation)",
),
"file": fields.Raw(
required=False, description="Files to add (for add operation)"
),
"parent_dir": fields.String(
required=False,
description="Parent directory path relative to source root",
),
},
)
)
@api.doc(
description="Add files, remove files, or remove directories from an existing source",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
operation = request.form.get("operation")
if not source_id or not operation:
return make_response(
jsonify(
{
"success": False,
"message": "source_id and operation are required",
}
),
400,
)
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify(
{
"success": False,
"message": "operation must be 'add', 'remove', or 'remove_directory'",
}
),
400,
)
try:
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not source:
return make_response(
jsonify(
{
"success": False,
"message": "Source not found or access denied",
}
),
404,
)
except Exception as err:
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
jsonify(
{"success": False, "message": "Invalid parent directory path"}
),
400,
)
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
return make_response(
jsonify(
{
"success": False,
"message": "No files provided for add operation",
}
),
400,
)
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
return make_response(
jsonify(
{
"success": False,
"message": "file_paths required for remove operation",
}
),
400,
)
try:
file_paths = (
json.loads(file_paths_str)
if isinstance(file_paths_str, str)
else file_paths_str
)
except Exception:
return make_response(
jsonify(
{"success": False, "message": "Invalid file_paths format"}
),
400,
)
# Remove files from storage and directory structure
removed_files = []
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
return make_response(
jsonify(
{
"success": False,
"message": "directory_path required for remove_directory operation",
}
),
400,
)
# Validate directory path (prevent path traversal)
if directory_path.startswith("/") or ".." in directory_path:
current_app.logger.warning(
f"Invalid directory path attempted for removal. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
)
return make_response(
jsonify(
{"success": False, "message": "Invalid directory path"}
),
400,
)
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
else source_file_path
)
if not storage.is_directory(full_directory_path):
current_app.logger.warning(
f"Directory not found or is not a directory for removal. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
return make_response(
jsonify(
{
"success": False,
"message": "Directory not found or is not a directory",
}
),
404,
)
success = storage.remove_directory(full_directory_path)
if not success:
current_app.logger.error(
f"Failed to remove directory from storage. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
return make_response(
jsonify(
{"success": False, "message": "Failed to remove directory"}
),
500,
)
current_app.logger.info(
f"Successfully removed directory. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
}
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
directory_path = request.form.get("directory_path", "")
error_context += f", directory_path={directory_path}"
elif operation == "remove":
file_paths_str = request.form.get("file_paths", "")
error_context += f", file_paths={file_paths_str}"
elif operation == "add":
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Operation failed"}), 500
)
@sources_upload_ns.route("/task_status")
class TaskStatus(Resource):
task_status_model = api.model(
"TaskStatusModel",
{"task_id": fields.String(required=True, description="Task ID")},
)
@api.expect(task_status_model)
@api.doc(description="Get celery job status")
def get(self):
task_id = request.args.get("task_id")
if not task_id:
return make_response(
jsonify({"success": False, "message": "Task ID is required"}), 400
)
try:
from application.celery_init import celery
task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta) # Convert to a string representation
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)

View File

@@ -1,20 +1,12 @@
from datetime import timedelta
from application.celery_init import celery
from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync_worker,
)
from application.worker import ingest_worker, remote_worker, sync_worker
@celery.task(bind=True)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
def ingest(self, directory, formats, name_job, filename, user):
resp = ingest_worker(self, directory, formats, name_job, filename, user)
return resp
@@ -24,66 +16,12 @@ def ingest_remote(self, source_data, job_name, user, loader):
return resp
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
return resp
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)
return resp
@celery.task(bind=True)
def process_agent_webhook(self, agent_id, payload):
resp = agent_webhook_worker(self, agent_id, payload)
return resp
@celery.task(bind=True)
def ingest_connector_task(
self,
job_name,
user,
source_type,
session_token=None,
file_ids=None,
folder_ids=None,
recursive=True,
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never",
):
from application.worker import ingest_connector
resp = ingest_connector(
self,
job_name,
user,
source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=retriever,
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency,
)
return resp
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
sender.add_periodic_task(
@@ -98,15 +36,3 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
resp = mcp_oauth(self, config, user)
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp

View File

@@ -1,6 +0,0 @@
"""Tools module."""
from .mcp import tools_mcp_ns
from .routes import tools_ns
__all__ = ["tools_ns", "tools_mcp_ns"]

View File

@@ -1,333 +0,0 @@
"""Tool management MCP server integration."""
import json
from email.quoprimime import unquote
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.cache import get_redis_instance
from application.security.encryption import encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
api.model(
"MCPServerTestModel",
{
"config": fields.Raw(
required=True, description="MCP server configuration to test"
),
},
)
)
@api.doc(description="Test MCP server connection with provided configuration")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Connection test failed: {str(e)}"}
),
500,
)
@tools_mcp_ns.route("/mcp_server/save")
class MCPServerSave(Resource):
@api.expect(
api.model(
"MCPServerSaveModel",
{
"id": fields.String(
required=False, description="Tool ID for updates (optional)"
),
"displayName": fields.String(
required=True, description="Display name for the MCP server"
),
"config": fields.Raw(
required=True, description="MCP server configuration"
),
"status": fields.Boolean(
required=False, default=True, description="Tool status"
),
},
)
)
@api.doc(description="Create or update MCP server with automatic tool discovery")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["displayName", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "oauth":
if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
raise Exception(
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
"customName": data["displayName"],
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
"config": storage_config,
"actions": transformed_actions,
"status": data.get("status", True),
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
result = user_tools_collection.insert_one(tool_data)
tool_id = str(result.inserted_id)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
),
500,
)
@tools_mcp_ns.route("/mcp_server/callback")
class MCPOAuthCallback(Resource):
@api.expect(
api.model(
"MCPServerCallbackModel",
{
"code": fields.String(required=True, description="Authorization code"),
"state": fields.String(required=True, description="State parameter"),
"error": fields.String(
required=False, description="Error message (if any)"
),
},
)
)
@api.doc(
description="Handle OAuth callback by providing the authorization code and state"
)
def get(self):
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
)
try:
redis_client = get_redis_instance()
if not redis_client:
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
return redirect(
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
)
else:
return redirect(
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
)
except Exception as e:
current_app.logger.error(
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
}
),
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)

View File

@@ -1,415 +0,0 @@
"""Tool management routes."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
tool_config = {}
tool_manager = ToolManager(config=tool_config)
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@tools_ns.route("/available_tools")
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
doc = tool_instance.__doc__.strip()
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
}
)
except Exception as err:
current_app.logger.error(
f"Error getting available tools: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
@tools_ns.route("/get_tools")
class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
tool["id"] = str(tool["_id"])
tool.pop("_id")
user_tools.append(tool)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
@tools_ns.route("/create_tool")
class CreateTool(Resource):
@api.expect(
api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"displayName": fields.String(
required=True, description="Display name for the tool"
),
"description": fields.String(
required=True, description="Tool description"
),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
"customName": fields.String(
required=False, description="Custom name for the tool"
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Create a new tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = [
"name",
"displayName",
"description",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id}), 200)
@tools_ns.route("/update_tool")
class UpdateTool(Resource):
@api.expect(
api.model(
"UpdateToolModel",
{
"id": fields.String(required=True, description="Tool ID"),
"name": fields.String(description="Name of the tool"),
"displayName": fields.String(description="Display name for the tool"),
"customName": fields.String(description="Custom name for the tool"),
"description": fields.String(description="Tool description"),
"config": fields.Raw(description="Configuration of the tool"),
"actions": fields.List(
fields.Raw, description="Actions the tool can perform"
),
"status": fields.Boolean(description="Status of the tool"),
},
)
)
@api.doc(description="Update a tool by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "customName" in data:
update_data["customName"] = data["customName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
if "actions" in data["config"]:
for action_name in list(data["config"]["actions"].keys()):
if not validate_function_name(action_name):
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
"param": "tools[].function.name",
}
),
400,
)
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": update_data},
)
except Exception as err:
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_config")
class UpdateToolConfig(Resource):
@api.expect(
api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
},
)
)
@api.doc(description="Update the configuration of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool config: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_actions")
class UpdateToolActions(Resource):
@api.expect(
api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
},
)
)
@api.doc(description="Update the actions of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "actions"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_status")
class UpdateToolStatus(Resource):
@api.expect(
api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Update the status of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "status"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool status: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/delete_tool")
class DeleteTool(Resource):
@api.expect(
api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
)
)
@api.doc(description="Delete a tool by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200

View File

@@ -1,37 +1,28 @@
import os
import platform
import uuid
import dotenv
from flask import Flask, jsonify, redirect, request
from jose import jwt
from application.auth import handle_auth
from flask import Flask, redirect, request
from application.api.answer.routes import answer
from application.api.internal.routes import internal
from application.api.user.routes import user
from application.celery_init import celery
from application.core.logging_config import setup_logging
setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.core.settings import settings
from application.extensions import api
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
setup_logging()
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -41,24 +32,6 @@ app.config.update(
celery.config_from_object("application.celeryconfig")
api.init_app(app)
if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECRET_KEY:
key_file = ".jwt_secret_key"
try:
with open(key_file, "r") as f:
settings.JWT_SECRET_KEY = f.read().strip()
except FileNotFoundError:
new_key = os.urandom(32).hex()
with open(key_file, "w") as f:
f.write(new_key)
settings.JWT_SECRET_KEY = new_key
except Exception as e:
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
SIMPLE_JWT_TOKEN = None
if settings.AUTH_TYPE == "simple_jwt":
payload = {"sub": "local"}
SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256")
print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}")
@app.route("/")
def home():
@@ -68,46 +41,11 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.route("/api/config")
def get_config():
response = {
"auth_type": settings.AUTH_TYPE,
"requires_auth": settings.AUTH_TYPE in ["simple_jwt", "session_jwt"],
}
return jsonify(response)
@app.route("/api/generate_token")
def generate_token():
if settings.AUTH_TYPE == "session_jwt":
new_user_id = str(uuid.uuid4())
token = jwt.encode(
{"sub": new_user_id}, settings.JWT_SECRET_KEY, algorithm="HS256"
)
return jsonify({"token": token})
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
@app.before_request
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None
elif "error" in decoded_token:
return jsonify(decoded_token), 401
else:
request.decoded_token = decoded_token
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
return response

View File

@@ -1,28 +0,0 @@
from jose import jwt
from application.core.settings import settings
def handle_auth(request, data={}):
if settings.AUTH_TYPE in ["simple_jwt", "session_jwt"]:
jwt_token = request.headers.get("Authorization")
if not jwt_token:
return None
jwt_token = jwt_token.replace("Bearer ", "")
try:
decoded_token = jwt.decode(
jwt_token,
settings.JWT_SECRET_KEY,
algorithms=["HS256"],
options={"verify_exp": False},
)
return decoded_token
except Exception as e:
return {
"message": f"Authentication error: {str(e)}",
"error": "invalid_token",
}
else:
return {"sub": "local"}

View File

@@ -1,117 +1,93 @@
import redis
import time
import json
import logging
import time
from threading import Lock
import redis
from application.core.settings import settings
from application.utils import get_hash
logger = logging.getLogger(__name__)
_redis_instance = None
_redis_creation_failed = False
_instance_lock = Lock()
def get_redis_instance():
global _redis_instance, _redis_creation_failed
if _redis_instance is None and not _redis_creation_failed:
global _redis_instance
if _redis_instance is None:
with _instance_lock:
if _redis_instance is None and not _redis_creation_failed:
if _redis_instance is None:
try:
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except ValueError as e:
logger.error(f"Invalid Redis URL: {e}")
_redis_creation_failed = True # Stop future attempts
_redis_instance = None
_redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
_redis_instance = None # Keep trying for connection errors
_redis_instance = None
return _redis_instance
def gen_cache_key(messages, model="docgpt", tools=None):
def gen_cache_key(*messages, model="docgpt"):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(messages)
tools_str = json.dumps(str(tools)) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
messages_str = json.dumps(list(messages), sort_keys=True)
combined = f"{model}_{messages_str}"
cache_key = get_hash(combined)
return cache_key
def gen_cache(func):
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
if tools is not None:
return func(self, model, messages, stream, tools, *args, **kwargs)
def wrapper(self, model, messages, *args, **kwargs):
try:
cache_key = gen_cache_key(messages, model, tools)
cache_key = gen_cache_key(*messages)
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
return cached_response.decode('utf-8')
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
result = func(self, model, messages, *args, **kwargs)
if redis_client:
try:
redis_client.set(cache_key, result, ex=1800)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return result
except ValueError as e:
logger.error(f"Cache key generation failed: {e}")
return func(self, model, messages, stream, tools, *args, **kwargs)
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
return cached_response.decode("utf-8")
except Exception as e:
logger.error(f"Error getting cached response: {e}", exc_info=True)
result = func(self, model, messages, stream, tools, *args, **kwargs)
if redis_client and isinstance(result, str):
try:
redis_client.set(cache_key, result, ex=1800)
except Exception as e:
logger.error(f"Error setting cache: {e}", exc_info=True)
return result
logger.error(e)
return "Error: No user message found in the conversation to generate a cache key."
return wrapper
def stream_cache(func):
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
if tools is not None:
yield from func(self, model, messages, stream, tools, *args, **kwargs)
return
def wrapper(self, model, messages, stream, *args, **kwargs):
cache_key = gen_cache_key(*messages)
logger.info(f"Stream cache key: {cache_key}")
try:
cache_key = gen_cache_key(messages, model, tools)
except ValueError as e:
logger.error(f"Cache key generation failed: {e}")
yield from func(self, model, messages, stream, tools, *args, **kwargs)
return
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
logger.info(f"Cache hit for stream key: {cache_key}")
cached_response = json.loads(cached_response.decode("utf-8"))
cached_response = json.loads(cached_response.decode('utf-8'))
for chunk in cached_response:
yield chunk
time.sleep(0.03) # Simulate streaming delay
time.sleep(0.03)
return
except Exception as e:
logger.error(f"Error getting cached stream: {e}", exc_info=True)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
result = func(self, model, messages, stream, *args, **kwargs)
stream_cache_data = []
for chunk in func(self, model, messages, stream, tools, *args, **kwargs):
for chunk in result:
stream_cache_data.append(chunk)
yield chunk
stream_cache_data.append(str(chunk))
if redis_client:
try:
redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
logger.info(f"Stream cache saved for key: {cache_key}")
except Exception as e:
logger.error(f"Error setting stream cache: {e}", exc_info=True)
return wrapper
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return wrapper

View File

@@ -2,22 +2,14 @@ from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
def make_celery(app_name=__name__):
celery = Celery(
app_name,
broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND,
)
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
celery.conf.update(settings)
return celery
@setup_logging.connect
def config_loggers(*args, **kwargs):
from application.core.logging_config import setup_logging
setup_logging()
celery = make_celery()

View File

@@ -1,55 +1,25 @@
import os
from pathlib import Path
from typing import Optional
import os
from pydantic_settings import BaseSettings
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
LLM_TOKEN_LIMITS: dict = {
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.5-flash": 1e6,
}
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -57,18 +27,12 @@ class Settings(BaseSettings):
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -101,29 +65,18 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
URL_STRATEGY: str = "backend" # backend or s3
JWT_SECRET_KEY: str = ""
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
ELEVENLABS_API_KEY: Optional[str] = None
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -0,0 +1,7 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -17,7 +17,7 @@ class AnthropicLLM(BaseLLM):
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]

View File

@@ -1,144 +1,29 @@
import logging
from abc import ABC, abstractmethod
from application.cache import gen_cache, stream_cache
from application.core.settings import settings
from application.usage import gen_token_usage, stream_token_usage
logger = logging.getLogger(__name__)
from application.cache import stream_cache, gen_cache
class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
):
self.decoded_token = decoded_token
def __init__(self):
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
self.fallback_model_name = settings.FALLBACK_LLM_NAME
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
self._fallback_llm = None
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM instance."""
if (
self._fallback_llm is None
and self.fallback_provider
and self.fallback_model_name
):
try:
from application.llm.llm_creator import LLMCreator
self._fallback_llm = LLMCreator.create_llm(
self.fallback_provider,
self.fallback_llm_api_key,
None,
self.decoded_token,
)
except Exception as e:
logger.error(
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
)
return self._fallback_llm
def _execute_with_fallback(
self, method_name: str, decorators: list, *args, **kwargs
):
"""
Unified method execution with fallback support.
Args:
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
decorators: List of decorators to apply
*args: Positional arguments
**kwargs: Keyword arguments
"""
def decorated_method():
method = getattr(self, method_name)
for decorator in decorators:
method = decorator(method)
return method(self, *args, **kwargs)
try:
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
raise
logger.warning(
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
)
fallback_method = getattr(
self.fallback_llm, method_name.replace("_raw_", "")
)
return fallback_method(*args, **kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
return self._execute_with_fallback(
"_raw_gen",
decorators,
model=model,
messages=messages,
stream=stream,
tools=tools,
*args,
**kwargs,
)
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._execute_with_fallback(
"_raw_gen_stream",
decorators,
model=model,
messages=messages,
stream=stream,
tools=tools,
*args,
**kwargs,
)
def _apply_decorator(self, method, decorators, *args, **kwargs):
for decorator in decorators:
method = decorator(method)
return method(self, *args, **kwargs)
@abstractmethod
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
def _raw_gen(self, model, messages, stream, *args, **kwargs):
pass
def gen(self, model, messages, stream=False, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
pass
def supports_tools(self):
return hasattr(self, "_supports_tools") and callable(
getattr(self, "_supports_tools")
)
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")
def supports_structured_output(self):
"""Check if the LLM supports structured output/JSON schema enforcement"""
return hasattr(self, "_supports_structured_output") and callable(
getattr(self, "_supports_structured_output")
)
def _supports_structured_output(self):
return False
def prepare_structured_output_format(self, json_schema):
"""Prepare structured output format specific to the LLM provider"""
_ = json_schema
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by this LLM for file uploads.
Returns:
list: List of supported MIME types
"""
return []
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)

View File

@@ -1,131 +1,34 @@
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
import json
import requests
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
self.api_key = api_key
self.user_api_key = user_api_key
self.endpoint = "https://llm.arc53.com"
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
response = requests.post(
f"{self.endpoint}/answer", json={"messages": messages, "max_new_tokens": 30}
)
response_clean = response.json()["a"].replace("###", "")
if role == "model":
role = "assistant"
return response_clean
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
response = requests.post(
f"{self.endpoint}/stream",
json={"messages": messages, "max_new_tokens": 256},
stream=True,
)
return cleaned_messages
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
for line in response:
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
def _supports_tools(self):
return True
for line in response.iter_lines():
if line:
data_str = line.decode("utf-8")
if data_str.startswith("data: "):
data = json.loads(data_str[6:])
yield data["a"]

View File

@@ -1,288 +1,21 @@
import json
import logging
from google import genai
from google.genai import types
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Google Gemini for file uploads.
Returns:
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments using Google AI's file API for more efficient handling.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with file references for Google AI API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
files = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
return prepared_messages
def _upload_file_to_google(self, attachment):
"""
Upload a file to Google AI and return the file URI.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Google AI file URI for the uploaded file.
"""
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_uri = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.upload(
file=local_path
).uri,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
elif isinstance(content, list):
for item in content:
if "text" in item:
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=item["function_call"]["args"],
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
name=item["function_response"]["name"],
response=item["function_response"]["response"],
)
)
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
def _clean_schema(self, schema_obj):
"""
Recursively remove unsupported fields from schema objects
and validate required properties.
"""
if not isinstance(schema_obj, dict):
return schema_obj
allowed_fields = {
"type",
"description",
"items",
"properties",
"required",
"enum",
"pattern",
"minimum",
"maximum",
"nullable",
"default",
}
cleaned = {}
for key, value in schema_obj.items():
if key not in allowed_fields:
continue
elif key == "type" and isinstance(value, str):
cleaned[key] = value.upper()
elif isinstance(value, dict):
cleaned[key] = self._clean_schema(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
for required_prop in cleaned["required"]:
if required_prop in properties_keys:
valid_required.append(required_prop)
if valid_required:
cleaned["required"] = valid_required
else:
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
"""Convert OpenAI format tools to Google AI format."""
genai_tools = []
for tool_data in tools_list:
if tool_data["type"] == "function":
function = tool_data["function"]
parameters = function["parameters"]
properties = parameters.get("properties", {})
if properties:
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
parameters={
"type": "OBJECT",
"properties": cleaned_properties,
"required": (
parameters["required"]
if "required" in parameters
else []
),
},
)
else:
genai_function = dict(
name=function["name"],
description=function["description"],
)
genai_tool = types.Tool(function_declarations=[genai_function])
genai_tools.append(genai_tool)
return genai_tools
return [
{
"role": "model" if message["role"] == "system" else message["role"],
"parts": [message["content"]],
}
for message in messages[1:]
]
def _raw_gen(
self,
@@ -290,39 +23,13 @@ class GoogleLLM(BaseLLM):
model,
messages,
stream=False,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
if tools:
return response
else:
return response.text
**kwargs
):
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
response = model.generate_content(self._clean_messages_google(messages))
return response.text
def _raw_gen_stream(
self,
@@ -330,135 +37,12 @@ class GoogleLLM(BaseLLM):
model,
messages,
stream=True,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
response = client.models.generate_content_stream(
model=model,
contents=messages,
config=config,
)
for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
yield chunk.text
def _supports_tools(self):
"""Return whether this LLM supports function calling."""
return True
def _supports_structured_output(self):
"""Return whether this LLM supports structured JSON output."""
return True
def prepare_structured_output_format(self, json_schema):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
type_map = {
"object": "OBJECT",
"array": "ARRAY",
"string": "STRING",
"integer": "INTEGER",
"number": "NUMBER",
"boolean": "BOOLEAN",
}
def convert(schema):
if not isinstance(schema, dict):
return schema
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
for key in [
"description",
"nullable",
"enum",
"minItems",
"maxItems",
"required",
"propertyOrdering",
]:
if key in schema:
result[key] = schema[key]
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
if format_value == "date":
result["format"] = "date-time"
elif format_value in ["enum", "date-time"]:
result["format"] = format_value
else:
result["format"] = format_value
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
if "items" in schema:
result["items"] = convert(schema["items"])
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
return result
try:
return convert(json_schema)
except Exception as e:
logging.error(
f"Error preparing structured output format for Google: {e}",
exc_info=True,
)
return None
**kwargs
):
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
response = model.generate_content(self._clean_messages_google(messages), stream=True)
for line in response:
if line.text is not None:
yield line.text

View File

@@ -1,32 +1,45 @@
from application.llm.base import BaseLLM
from openai import OpenAI
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,351 +0,0 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
from application.logging import build_stack_data
logger = logging.getLogger(__name__)
@dataclass
class ToolCall:
"""Represents a tool/function call from the LLM."""
id: str
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
"""Create ToolCall from dictionary."""
return cls(
id=data.get("id", ""),
name=data.get("name", ""),
arguments=data.get("arguments", {}),
index=data.get("index"),
)
@dataclass
class LLMResponse:
"""Represents a response from the LLM."""
content: str
tool_calls: List[ToolCall]
finish_reason: str
raw_response: Any
@property
def requires_tool_call(self) -> bool:
"""Check if the response requires tool calls."""
return bool(self.tool_calls) and self.finish_reason == "tool_calls"
class LLMHandler(ABC):
"""Abstract base class for LLM handlers."""
def __init__(self):
self.llm_calls = []
self.tool_calls = []
@abstractmethod
def parse_response(self, response: Any) -> LLMResponse:
"""Parse raw LLM response into standardized format."""
pass
@abstractmethod
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create a tool result message for the conversation history."""
pass
@abstractmethod
def _iterate_stream(self, response: Any) -> Generator:
"""Iterate through streaming response chunks."""
pass
def process_message_flow(
self,
agent,
initial_response,
tools_dict: Dict,
messages: List[Dict],
attachments: Optional[List] = None,
stream: bool = False,
) -> Union[str, Generator]:
"""
Main orchestration method for processing LLM message flow.
Args:
agent: The agent instance
initial_response: Initial LLM response
tools_dict: Dictionary of available tools
messages: Conversation history
attachments: Optional attachments
stream: Whether to use streaming
Returns:
Final response or generator for streaming
"""
messages = self.prepare_messages(agent, messages, attachments)
if stream:
return self.handle_streaming(agent, initial_response, tools_dict, messages)
else:
return self.handle_non_streaming(
agent, initial_response, tools_dict, messages
)
def prepare_messages(
self, agent, messages: List[Dict], attachments: Optional[List] = None
) -> List[Dict]:
"""
Prepare messages with attachments and provider-specific formatting.
Args:
agent: The agent instance
messages: Original messages
attachments: List of attachments
Returns:
Prepared messages list
"""
if not attachments:
return messages
logger.info(f"Preparing messages with {len(attachments)} attachments")
supported_types = agent.llm.get_supported_attachment_types()
supported_attachments = [
a for a in attachments if a.get("mime_type") in supported_types
]
unsupported_attachments = [
a for a in attachments if a.get("mime_type") not in supported_types
]
# Process supported attachments with the LLM's custom method
if supported_attachments:
logger.info(
f"Processing {len(supported_attachments)} supported attachments"
)
messages = agent.llm.prepare_messages_with_attachments(
messages, supported_attachments
)
# Process unsupported attachments with default method
if unsupported_attachments:
logger.info(
f"Processing {len(unsupported_attachments)} unsupported attachments"
)
messages = self._append_unsupported_attachments(
messages, unsupported_attachments
)
return messages
def _append_unsupported_attachments(
self, messages: List[Dict], attachments: List[Dict]
) -> List[Dict]:
"""
Default method to append unsupported attachment content to system prompt.
Args:
messages: Current messages
attachments: List of unsupported attachments
Returns:
Updated messages list
"""
prepared_messages = messages.copy()
attachment_texts = []
for attachment in attachments:
logger.info(f"Adding attachment {attachment.get('id')} to context")
if "content" in attachment:
attachment_texts.append(
f"Attached file content:\n\n{attachment['content']}"
)
if attachment_texts:
combined_text = "\n\n".join(attachment_texts)
system_msg = next(
(msg for msg in prepared_messages if msg.get("role") == "system"),
{"role": "system", "content": ""},
)
if system_msg not in prepared_messages:
prepared_messages.insert(0, system_msg)
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
"""
Execute tool calls and update conversation history.
Args:
agent: The agent instance
tool_calls: List of tool calls to execute
tools_dict: Available tools dictionary
messages: Current conversation history
Returns:
Updated messages list
"""
updated_messages = messages.copy()
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
while True:
try:
yield next(tool_executor_gen)
except StopIteration as e:
tool_response, call_id = e.value
break
updated_messages.append(
{
"role": "assistant",
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
error_call = ToolCall(
id=call.id, name=call.name, arguments=call.arguments
)
error_response = f"Error executing tool: {str(e)}"
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
"tool_name": tool_name,
"call_id": call.id,
"action_name": full_action_name,
"arguments": call.arguments,
"error": error_response,
"status": "error",
},
}
return updated_messages
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
) -> Generator:
"""
Handle non-streaming response flow.
Args:
agent: The agent instance
response: Current LLM response
tools_dict: Available tools dictionary
messages: Conversation history
Returns:
Final response after processing all tool calls
"""
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
while parsed.requires_tool_call:
tool_handler_gen = self.handle_tool_calls(
agent, parsed.tool_calls, tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
return parsed.content
def handle_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
) -> Generator:
"""
Handle streaming response flow.
Args:
agent: The agent instance
response: Current LLM response
tools_dict: Available tools dictionary
messages: Conversation history
Yields:
Streaming response chunks
"""
buffer = ""
tool_calls = {}
for chunk in self._iterate_stream(response):
if isinstance(chunk, str):
yield chunk
continue
parsed = self.parse_response(chunk)
if parsed.tool_calls:
for call in parsed.tool_calls:
if call.index not in tool_calls:
tool_calls[call.index] = call
else:
existing = tool_calls[call.index]
if call.id:
existing.id = call.id
if call.name:
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
)
while True:
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
break
tool_calls = {}
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
yield from self.handle_streaming(agent, response, tools_dict, messages)
return
if parsed.content:
buffer += parsed.content
yield buffer
buffer = ""
if parsed.finish_reason == "stop":
return

View File

@@ -1,78 +0,0 @@
import uuid
from typing import Any, Dict, Generator
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
class GoogleLLMHandler(LLMHandler):
"""Handler for Google's GenAI API."""
def parse_response(self, response: Any) -> LLMResponse:
"""Parse Google response into standardized format."""
if isinstance(response, str):
return LLMResponse(
content=response,
tool_calls=[],
finish_reason="stop",
raw_response=response,
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
content = " ".join(
part.text
for part in parts
if hasattr(part, "text") and part.text is not None
)
return LLMResponse(
content=content,
tool_calls=tool_calls,
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
else:
tool_calls = []
if hasattr(response, "function_call"):
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
)
)
return LLMResponse(
content=response.text if hasattr(response, "text") else "",
tool_calls=tool_calls,
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
return {
"role": "model",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
}
def _iterate_stream(self, response: Any) -> Generator:
"""Iterate through Google streaming response."""
for chunk in response:
yield chunk

View File

@@ -1,18 +0,0 @@
from application.llm.handlers.base import LLMHandler
from application.llm.handlers.google import GoogleLLMHandler
from application.llm.handlers.openai import OpenAILLMHandler
class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"default": OpenAILLMHandler,
}
@classmethod
def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
handler_class = cls.handlers.get(llm_type.lower())
if not handler_class:
handler_class = OpenAILLMHandler
return handler_class(*args, **kwargs)

View File

@@ -1,57 +0,0 @@
from typing import Any, Dict, Generator
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
class OpenAILLMHandler(LLMHandler):
"""Handler for OpenAI API."""
def parse_response(self, response: Any) -> LLMResponse:
"""Parse OpenAI response into standardized format."""
if isinstance(response, str):
return LLMResponse(
content=response,
tool_calls=[],
finish_reason="stop",
raw_response=response,
)
message = getattr(response, "message", None) or getattr(response, "delta", None)
tool_calls = []
if hasattr(message, "tool_calls"):
tool_calls = [
ToolCall(
id=getattr(tc, "id", ""),
name=getattr(tc.function, "name", ""),
arguments=getattr(tc.function, "arguments", ""),
index=getattr(tc, "index", None),
)
for tc in message.tool_calls or []
]
return LLMResponse(
content=getattr(message, "content", ""),
tool_calls=tool_calls,
finish_reason=getattr(response, "finish_reason", ""),
raw_response=response,
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create OpenAI-style tool message."""
return {
"role": "tool",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
"call_id": tool_call.id,
}
}
],
}
def _iterate_stream(self, response: Any) -> Generator:
"""Iterate through OpenAI streaming response."""
for chunk in response:
yield chunk

View File

@@ -2,7 +2,6 @@ from application.llm.base import BaseLLM
from application.core.settings import settings
import threading
class LlamaSingleton:
_instances = {}
_lock = threading.Lock() # Add a lock for thread synchronization
@@ -30,7 +29,7 @@ class LlamaCpp(BaseLLM):
self,
api_key=None,
user_api_key=None,
llm_name=settings.LLM_PATH,
llm_name=settings.MODEL_PATH,
*args,
**kwargs,
):
@@ -43,18 +42,14 @@ class LlamaCpp(BaseLLM):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = LlamaSingleton.query_model(
self.llama, prompt, max_tokens=150, echo=False
)
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
return result["choices"][0]["text"].split("### Answer \n")[-1]
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = LlamaSingleton.query_model(
self.llama, prompt, max_tokens=150, echo=False, stream=stream
)
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
for item in result:
for choice in item["choices"]:
yield choice["text"]
yield choice["text"]

View File

@@ -7,7 +7,6 @@ from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.novita import NovitaLLM
class LLMCreator:
@@ -21,15 +20,12 @@ class LLMCreator:
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"google": GoogleLLM
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
)
return llm_class(api_key, user_api_key, *args, **kwargs)

View File

@@ -1,32 +0,0 @@
from application.llm.base import BaseLLM
from openai import OpenAI
class NovitaLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,10 +1,6 @@
import base64
import json
import logging
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
from application.core.settings import settings
class OpenAILLM(BaseLLM):
@@ -13,101 +9,15 @@ class OpenAILLM(BaseLLM):
from openai import OpenAI
super().__init__(*args, **kwargs)
if (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
if settings.OPENAI_BASE_URL:
self.client = OpenAI(
api_key=api_key,
base_url=settings.OPENAI_BASE_URL
)
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
self.client = OpenAI(api_key=api_key)
self.api_key = api_key
self.user_api_key = user_api_key
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
content_parts.append(item)
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
content_parts.append(item)
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
content_parts.append(item)
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
self,
@@ -115,32 +25,14 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
if tools:
return response.choices[0]
else:
return response.choices[0].message.content
return response.choices[0].message.content
def _raw_gen_stream(
self,
@@ -148,276 +40,34 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
def _supports_tools(self):
return True
def _supports_structured_output(self):
return True
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
try:
def add_additional_properties_false(schema_obj):
if isinstance(schema_obj, dict):
schema_copy = schema_obj.copy()
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
prop_name: add_additional_properties_false(prop_schema)
for prop_name, prop_schema in value.items()
}
elif key == "items" and isinstance(value, dict):
schema_copy[key] = add_additional_properties_false(value)
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
value, list
):
schema_copy[key] = [
add_additional_properties_false(sub_schema)
for sub_schema in value
]
return schema_copy
return schema_obj
processed_schema = add_additional_properties_false(json_schema)
result = {
"type": "json_schema",
"json_schema": {
"name": processed_schema.get("name", "response"),
"description": processed_schema.get(
"description", "Structured response"
),
"schema": processed_schema,
"strict": True,
},
}
return result
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by OpenAI for file uploads.
Returns:
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments using OpenAI's file API for more efficient handling.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with file references for OpenAI API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith("image/"):
try:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
}
)
except Exception as e:
logging.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
# Handle PDFs using the file API
elif mime_type == "application/pdf":
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append(
{"type": "file", "file": {"file_id": file_id}}
)
except Exception as e:
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"File content:\n\n{attachment['content']}",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
def _upload_file_to_openai(self, attachment):
"""
Upload a file to OpenAI and return the file_id.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Expected keys:
- path: Path to the file
- id: Optional MongoDB ID for caching
Returns:
str: OpenAI file_id for the uploaded file.
"""
import logging
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_id = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.create(
file=open(local_path, "rb"), purpose="assistants"
).id,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
raise
class AzureOpenAILLM(OpenAILLM):
def __init__(self, api_key, user_api_key, *args, **kwargs):
super().__init__(api_key)
def __init__(
self, openai_api_key, openai_api_base, openai_api_version, deployment_name
):
super().__init__(openai_api_key)
self.api_base = (settings.OPENAI_API_BASE,)
self.api_version = (settings.OPENAI_API_VERSION,)
self.deployment_name = (settings.AZURE_DEPLOYMENT_NAME,)
from openai import AzureOpenAI
self.client = AzureOpenAI(
api_key=api_key,
api_key=openai_api_key,
api_version=settings.OPENAI_API_VERSION,
azure_endpoint=settings.OPENAI_API_BASE,
api_base=settings.OPENAI_API_BASE,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)

View File

@@ -76,7 +76,7 @@ class SagemakerAPILLM(BaseLLM):
self.endpoint = settings.SAGEMAKER_ENDPOINT
self.runtime = runtime
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -105,7 +105,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]["generated_text"], file=sys.stderr)
return result[0]["generated_text"][len(prompt) :]
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

@@ -1,161 +0,0 @@
import datetime
import functools
import inspect
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class LogContext:
def __init__(self, endpoint, activity_id, user, api_key, query):
self.endpoint = endpoint
self.activity_id = activity_id
self.user = user
self.api_key = api_key
self.query = query
self.stacks = []
def build_stack_data(
obj: Any,
include_attributes: List[str] = None,
exclude_attributes: List[str] = None,
custom_data: Dict = None,
) -> Dict:
if obj is None:
raise ValueError("The 'obj' parameter cannot be None")
data = {}
if include_attributes is None:
include_attributes = []
for name, value in inspect.getmembers(obj):
if (
not name.startswith("_")
and not inspect.ismethod(value)
and not inspect.isfunction(value)
):
include_attributes.append(name)
for attr_name in include_attributes:
if exclude_attributes and attr_name in exclude_attributes:
continue
try:
attr_value = getattr(obj, attr_name)
if attr_value is not None:
if isinstance(attr_value, (int, float, str, bool)):
data[attr_name] = attr_value
elif isinstance(attr_value, list):
if all(isinstance(item, dict) for item in attr_value):
data[attr_name] = attr_value
elif all(hasattr(item, "__dict__") for item in attr_value):
data[attr_name] = [item.__dict__ for item in attr_value]
else:
data[attr_name] = [str(item) for item in attr_value]
elif isinstance(attr_value, dict):
data[attr_name] = {k: str(v) for k, v in attr_value.items()}
except AttributeError as e:
logging.warning(f"AttributeError while accessing {attr_name}: {e}")
except AttributeError:
pass
if custom_data:
data.update(custom_data)
return data
def log_activity() -> Callable:
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
activity_id = str(uuid.uuid4())
data = build_stack_data(args[0])
endpoint = data.get("endpoint", "")
user = data.get("user", "local")
api_key = data.get("user_api_key", "")
query = kwargs.get("query", getattr(args[0], "query", ""))
context = LogContext(endpoint, activity_id, user, api_key, query)
kwargs["log_context"] = context
logging.info(
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
)
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
return wrapper
return decorator
def _consume_and_log(generator: Generator, context: "LogContext"):
try:
for item in generator:
yield item
except Exception as e:
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
context.stacks.append({"component": "error", "data": {"message": str(e)}})
_log_to_mongodb(
endpoint=context.endpoint,
activity_id=context.activity_id,
user=context.user,
api_key=context.api_key,
query=context.query,
stacks=context.stacks,
level="error",
)
raise
finally:
_log_to_mongodb(
endpoint=context.endpoint,
activity_id=context.activity_id,
user=context.user,
api_key=context.api_key,
query=context.query,
stacks=context.stacks,
level="info",
)
def _log_to_mongodb(
endpoint: str,
activity_id: str,
user: str,
api_key: str,
query: str,
stacks: List[Dict],
level: str,
) -> None:
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_logs_collection = db["stack_logs"]
log_entry = {
"endpoint": endpoint,
"id": activity_id,
"level": level,
"user": user,
"api_key": api_key,
"query": query,
"stacks": stacks,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
# clean up text fields to be no longer than 10000 characters
for key, value in log_entry.items():
if isinstance(value, str) and len(value) > 10000:
log_entry[key] = value[:10000]
user_logs_collection.insert_one(log_entry)
logging.debug(f"Logged activity to MongoDB: {activity_id}")
except Exception as e:
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)

View File

@@ -1,5 +1,5 @@
import re
from typing import List, Tuple
from typing import List, Tuple, Union
import logging
from application.parser.schema.base import Document
from application.utils import get_encoding
@@ -32,7 +32,16 @@ class Chunker:
header, body = "", text # No header, treat entire text as body
return header, body
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
combined_text = doc.text + " " + next_doc.text
combined_token_count = len(self.encoding.encode(combined_text))
new_doc = Document(
text=combined_text,
doc_id=doc.doc_id,
embedding=doc.embedding,
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
)
return new_doc
def split_document(self, doc: Document) -> List[Document]:
split_docs = []
@@ -73,11 +82,26 @@ class Chunker:
processed_docs.append(doc)
i += 1
elif token_count < self.min_tokens:
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
if i + 1 < len(documents):
next_doc = documents[i + 1]
next_tokens = self.encoding.encode(next_doc.text)
if token_count + len(next_tokens) <= self.max_tokens:
# Combine small documents
combined_doc = self.combine_documents(doc, next_doc)
processed_docs.append(combined_doc)
i += 2
else:
# Keep the small document as is if adding next_doc would exceed max_tokens
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# No next document to combine with; add the small document as is
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# Split large documents
processed_docs.extend(self.split_document(doc))

View File

@@ -1,18 +0,0 @@
"""
External knowledge base connectors for DocsGPT.
This module contains connectors for external knowledge bases and document storage systems
that require authentication and specialized handling, separate from simple web scrapers.
"""
from .base import BaseConnectorAuth, BaseConnectorLoader
from .connector_creator import ConnectorCreator
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
__all__ = [
'BaseConnectorAuth',
'BaseConnectorLoader',
'ConnectorCreator',
'GoogleDriveAuth',
'GoogleDriveLoader'
]

View File

@@ -1,129 +0,0 @@
"""
Base classes for external knowledge base connectors.
This module provides minimal abstract base classes that define the essential
interface for external knowledge base connectors.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from application.parser.schema.base import Document
class BaseConnectorAuth(ABC):
"""
Abstract base class for connector authentication.
Defines the minimal interface that all connector authentication
implementations must follow.
"""
@abstractmethod
def get_authorization_url(self, state: Optional[str] = None) -> str:
"""
Generate authorization URL for OAuth flows.
Args:
state: Optional state parameter for CSRF protection
Returns:
Authorization URL
"""
pass
@abstractmethod
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
"""
Exchange authorization code for access tokens.
Args:
authorization_code: Authorization code from OAuth callback
Returns:
Dictionary containing token information
"""
pass
@abstractmethod
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
"""
Refresh an expired access token.
Args:
refresh_token: Refresh token
Returns:
Dictionary containing refreshed token information
"""
pass
@abstractmethod
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
class BaseConnectorLoader(ABC):
"""
Abstract base class for connector loaders.
Defines the minimal interface that all connector loader
implementations must follow.
"""
@abstractmethod
def __init__(self, session_token: str):
"""
Initialize the connector loader.
Args:
session_token: Authentication session token
"""
pass
@abstractmethod
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
"""
Load documents from the external knowledge base.
Args:
inputs: Configuration dictionary containing:
- file_ids: Optional list of specific file IDs to load
- folder_ids: Optional list of folder IDs to browse/download
- limit: Maximum number of items to return
- list_only: If True, return metadata without content
- recursive: Whether to recursively process folders
Returns:
List of Document objects
"""
pass
@abstractmethod
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Download files/folders to a local directory.
Args:
local_dir: Local directory path to download files to
source_config: Configuration for what to download
Returns:
Dictionary containing download results:
- files_downloaded: Number of files downloaded
- directory_path: Path where files were downloaded
- empty_result: Whether no files were downloaded
- source_type: Type of connector
- config_used: Configuration that was used
- error: Error message if download failed (optional)
"""
pass

View File

@@ -1,81 +0,0 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
class ConnectorCreator:
"""
Factory class for creating external knowledge base connectors and auth providers.
These are different from remote loaders as they typically require
authentication and connect to external document storage systems.
"""
connectors = {
"google_drive": GoogleDriveLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
}
@classmethod
def create_connector(cls, connector_type, *args, **kwargs):
"""
Create a connector instance for the specified type.
Args:
connector_type: Type of connector to create (e.g., 'google_drive')
*args, **kwargs: Arguments to pass to the connector constructor
Returns:
Connector instance
Raises:
ValueError: If connector type is not supported
"""
connector_class = cls.connectors.get(connector_type.lower())
if not connector_class:
raise ValueError(f"No connector class found for type {connector_type}")
return connector_class(*args, **kwargs)
@classmethod
def create_auth(cls, connector_type):
"""
Create an auth provider instance for the specified connector type.
Args:
connector_type: Type of connector auth to create (e.g., 'google_drive')
Returns:
Auth provider instance
Raises:
ValueError: If connector type is not supported for auth
"""
auth_class = cls.auth_providers.get(connector_type.lower())
if not auth_class:
raise ValueError(f"No auth class found for type {connector_type}")
return auth_class()
@classmethod
def get_supported_connectors(cls):
"""
Get list of supported connector types.
Returns:
List of supported connector type strings
"""
return list(cls.connectors.keys())
@classmethod
def is_supported(cls, connector_type):
"""
Check if a connector type is supported.
Args:
connector_type: Type of connector to check
Returns:
True if supported, False otherwise
"""
return connector_type.lower() in cls.connectors

View File

@@ -1,10 +0,0 @@
"""
Google Drive connector for DocsGPT.
This module provides authentication and document loading capabilities for Google Drive.
"""
from .auth import GoogleDriveAuth
from .loader import GoogleDriveLoader
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']

Some files were not shown because too many files have changed in this diff Show More