Merge remote-tracking branch 'upstream/main'

This merge includes several important updates:

1. Error Handling: Added try-except blocks for file operations and network requests.
2. Logging Enhancements: Improved logging to capture more detailed information.
3. Code Refactoring: Created download_file and upload_index functions to avoid code repetition.
4. Configuration: Used constants for MIN_TOKENS, MAX_TOKENS, and RECURSION_DEPTH.
This commit is contained in:
DhruvKadam-git
2024-10-01 17:44:30 +05:30
397 changed files with 68676 additions and 0 deletions

138
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,138 @@
name: "🐛 Bug Report"
description: "Submit a bug report to help us improve"
title: "🐛 Bug Report: "
labels: ["type: bug"]
body:
- type: markdown
attributes:
value: We value your time and your efforts to submit this bug report is appreciated. 🙏
- type: textarea
id: description
validations:
required: true
attributes:
label: "📜 Description"
description: "A clear and concise description of what the bug is."
placeholder: "It bugs out when ..."
- type: textarea
id: steps-to-reproduce
validations:
required: true
attributes:
label: "👟 Reproduction steps"
description: "How do you trigger this bug? Please walk us through it step by step."
placeholder: "1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error"
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: "👍 Expected behavior"
description: "What did you think should happen?"
placeholder: "It should ..."
- type: textarea
id: actual-behavior
validations:
required: true
attributes:
label: "👎 Actual Behavior with Screenshots"
description: "What did actually happen? Add screenshots, if applicable."
placeholder: "It actually ..."
- type: dropdown
id: operating-system
attributes:
label: "💻 Operating system"
description: "What OS is your app running on?"
options:
- Linux
- MacOS
- Windows
- Something else
validations:
required: true
- type: dropdown
id: browsers
attributes:
label: What browsers are you seeing the problem on?
multiple: true
options:
- Firefox
- Chrome
- Safari
- Microsoft Edge
- Something else
- type: dropdown
id: dev-environment
validations:
required: true
attributes:
label: "🤖 What development environment are you experiencing this bug on?"
options:
- Docker
- Local dev server
- type: textarea
id: env-vars
validations:
required: false
attributes:
label: "🔒 Did you set the correct environment variables in the right path? List the environment variable names (not values please!)"
description: "Please refer to the [Project setup instructions](https://github.com/arc53/DocsGPT#quickstart) if you are unsure."
placeholder: "It actually ..."
- type: textarea
id: additional-context
validations:
required: false
attributes:
label: "📃 Provide any additional context for the Bug."
description: "Add any other context about the problem here."
placeholder: "It actually ..."
- type: textarea
id: logs
validations:
required: false
attributes:
label: 📖 Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
render: shell
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "👀 Have you spent some time to check if this bug has been raised before?"
options:
- label: "I checked and didn't find similar issue"
required: true
- type: dropdown
id: willing-to-submit-pr
attributes:
label: 🔗 Are you willing to submit PR?
description: This is absolutely not required, but we are happy to guide you in the contribution process.
options: # Added options key
- "Yes, I am willing to submit a PR!"
- "No"
validations:
required: false
- type: checkboxes
id: terms
attributes:
label: 🧑‍⚖️ Code of Conduct
description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/arc53/DocsGPT/blob/main/CODE_OF_CONDUCT.md)
options:
- label: I agree to follow this project's Code of Conduct
required: true

View File

@@ -0,0 +1,54 @@
name: 🚀 Feature
description: "Submit a proposal for a new feature"
title: "🚀 Feature: "
labels: [feature]
body:
- type: markdown
attributes:
value: We value your time and your efforts to submit this bug report is appreciated. 🙏
- type: textarea
id: feature-description
validations:
required: true
attributes:
label: "🔖 Feature description"
description: "A clear and concise description of what the feature is."
placeholder: "You should add ..."
- type: textarea
id: pitch
validations:
required: true
attributes:
label: "🎤 Why is this feature needed ?"
description: "Please explain why this feature should be implemented and how it would be used. Add examples, if applicable."
placeholder: "In my use-case, ..."
- type: textarea
id: solution
validations:
required: true
attributes:
label: "✌️ How do you aim to achieve this?"
description: "A clear and concise description of what you want to happen."
placeholder: "I want this feature to, ..."
- type: textarea
id: alternative
validations:
required: false
attributes:
label: "🔄️ Additional Information"
description: "A clear and concise description of any alternative solutions or additional solutions you've considered."
placeholder: "I tried, ..."
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "👀 Have you spent some time to check if this feature request has been raised before?"
options:
- label: "I checked and didn't find similar issue"
required: true
- type: dropdown
id: willing-to-submit-pr
attributes:
label: Are you willing to submit PR?
description: This is absolutely not required, but we are happy to guide you in the contribution process.
options:
- "Yes I am willing to submit a PR!"

5
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@@ -0,0 +1,5 @@
- **What kind of change does this PR introduce?** (Bug fix, feature, docs update, ...)
- **Why was this change needed?** (You can also link to an open issue here)
- **Other information**:

15
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,15 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/application" # Location of package manifests
schedule:
interval: "weekly"
- package-ecosystem: "npm" # See documentation for possible values
directory: "/frontend" # Location of package manifests
schedule:
interval: "weekly"

5
.github/holopin.yml vendored Normal file
View File

@@ -0,0 +1,5 @@
organization: arc53
defaultSticker: clqmdf0ed34290glbvqh0kzxd
stickers:
- id: clqmdf0ed34290glbvqh0kzxd
alias: festive

23
.github/labeler.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
repo:
- '*'
github:
- .github/**/*
application:
- application/**/*
docs:
- docs/**/*
extensions:
- extensions/**/*
frontend:
- frontend/**/*
scripts:
- scripts/**/*
tests:
- tests/**/*

47
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,47 @@
name: Build and push DocsGPT Docker image
on:
workflow_dispatch:
push:
branches:
- main
jobs:
deploy:
if: github.repository == 'arc53/DocsGPT'
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v4
with:
file: './application/Dockerfile'
platforms: linux/amd64
context: ./application
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt:latest
ghcr.io/${{ github.repository_owner }}/docsgpt:latest

48
.github/workflows/cife.yml vendored Normal file
View File

@@ -0,0 +1,48 @@
name: Build and push DocsGPT-FE Docker image
on:
workflow_dispatch:
push:
branches:
- main
jobs:
deploy:
if: github.repository == 'arc53/DocsGPT'
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to ghcr.io
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
# Runs a single command using the runners shell
- name: Build and push Docker images to docker.io and ghcr.io
uses: docker/build-push-action@v4
with:
file: './frontend/Dockerfile'
platforms: linux/amd64, linux/arm64
context: ./frontend
push: true
tags: |
${{ secrets.DOCKER_USERNAME }}/docsgpt-fe:latest
ghcr.io/${{ github.repository_owner }}/docsgpt-fe:latest

16
.github/workflows/labeler.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
# https://github.com/actions/labeler
name: Pull Request Labeler
on:
- pull_request_target
jobs:
triage:
if: github.repository == 'arc53/DocsGPT'
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v4
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
sync-labels: true

17
.github/workflows/lint.yml vendored Normal file
View File

@@ -0,0 +1,17 @@
name: Python linting
on:
push:
branches:
- '*'
pull_request:
types: [ opened, synchronize ]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Lint with Ruff
uses: chartboost/ruff-action@v1

30
.github/workflows/pytest.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Run python tests with pytest
on: [push, pull_request]
jobs:
pytest_and_coverage:
name: Run tests and count coverage
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov
cd application
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest and generate coverage report
run: |
python -m pytest --cov=application --cov-report=xml
- name: Upload coverage reports to Codecov
if: github.event_name == 'pull_request' && matrix.python-version == '3.11'
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

41
.github/workflows/sync_fork.yaml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: Upstream Sync
permissions:
contents: write
on:
schedule:
- cron: "0 0 * * *" # every hour
workflow_dispatch:
jobs:
sync_latest_from_upstream:
name: Sync latest commits from upstream repo
runs-on: ubuntu-latest
if: ${{ github.event.repository.fork }}
steps:
# Step 1: run a standard checkout action
- name: Checkout target repo
uses: actions/checkout@v3
# Step 2: run the sync action
- name: Sync upstream changes
id: sync
uses: aormsby/Fork-Sync-With-Upstream-action@v3.4
with:
# set your upstream repo and branch
upstream_sync_repo: arc53/DocsGPT
upstream_sync_branch: main
target_sync_branch: main
target_repo_token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, no need to set
# Set test_mode true to run tests instead of the true action!!
test_mode: false
- name: Sync check
if: failure()
run: |
echo "::error::由于权限不足,导致同步失败(这是预期的行为),请前往仓库首页手动执行[Sync fork]。"
echo "::error::Due to insufficient permissions, synchronization failed (as expected). Please go to the repository homepage and manually perform [Sync fork]."
exit 1

176
.gitignore vendored Normal file
View File

@@ -0,0 +1,176 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
*.next
# Distribution / packaging
.Python
build/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
**/*.ipynb
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.flaskenv
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
#pycharm
.idea/
# macOS
.DS_Store
#frontend
# Logs
frontend/logs
frontend/*.log
frontend/npm-debug.log*
frontend/yarn-debug.log*
frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
frontend/node_modules
frontend/dist
frontend/dist-ssr
frontend/*.local
# Editor directories and files
frontend/.vscode/*
frontend/!.vscode/extensions.json
frontend/.idea
frontend/.DS_Store
frontend/*.suo
frontend/*.ntvs*
frontend/*.njsproj
frontend/*.sln
frontend/*.sw?
application/vectors/
**/inputs
**/indexes
**/temp
**/yarn.lock
node_modules/
.vscode/settings.json
/models/
model/

2
.ruff.toml Normal file
View File

@@ -0,0 +1,2 @@
# Allow lines to be as long as 120 characters.
line-length = 120

16
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Docker Debug Frontend",
"request": "launch",
"type": "chrome",
"preLaunchTask": "docker-compose: debug:frontend",
"url": "http://127.0.0.1:5173",
"webRoot": "${workspaceFolder}/frontend",
"skipFiles": [
"<node_internals>/**"
]
}
]
}

21
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,21 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "docker-compose",
"label": "docker-compose: debug:frontend",
"dockerCompose": {
"up": {
"detached": true,
"services": [
"frontend"
],
"build": true
},
"files": [
"${workspaceFolder}/docker-compose.yaml"
]
}
}
]
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

124
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,124 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors and leaders pledge to make participation in our
community, a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive and a healthy community.
## Our Standards
Examples of behavior that contribute to a positive environment for our
community include:
## Demonstrating empathy and kindness towards other people
1. Being respectful and open to differing opinions, viewpoints, and experiences
2. Giving and gracefully accepting constructive feedback
3. Taking accountability and offering apologies to those who have been impacted by our errors,
while also gaining insights from the situation
4. Focusing on what is best not just for us as individuals but for the
community as a whole
Examples of unacceptable behavior include:
1. The use of sexualized language or imagery, and sexual attention or
advances of any kind
2. Trolling, insulting or derogatory comments, and personal or political attacks
3. Public or private harassment
4. Publishing other's private information, such as a physical or email
address, without their explicit permission
5. Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
contact@arc53.com.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to be respectful towards the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action that they deem in violation of this Code of Conduct:
### 1. Correction
* **Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community space.
* **Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
* **Community Impact**: A violation through a single incident or series
of actions.
* **Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
* **Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
* **Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
* **Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior,harassment of an
individual or aggression towards or disparagement of classes of individuals.
* **Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

128
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,128 @@
# Welcome to DocsGPT Contributing Guidelines
Thank you for choosing to contribute to DocsGPT! We are all very grateful!
# We accept different types of contributions
📣 **Discussions** - Engage in conversations, start new topics, or help answer questions.
🐞 **Issues** - This is where we keep track of tasks. It could be bugs,fixes or suggestions for new features.
🛠️ **Pull requests** - Suggest changes to our repository, either by working on existing issues or adding new features.
📚 **Wiki** - This is where our documentation resides.
## 🐞 Issues and Pull requests
- We value contributions in the form of discussions or suggestions. We recommend taking a look at existing issues and our [roadmap](https://github.com/orgs/arc53/projects/2).
- If you're interested in contributing code, here are some important things to know:
- We have a frontend built on React (Vite) and a backend in Python.
=======
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version that you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
### 👨‍💻 If you're interested in contributing code, here are some important things to know:
Tech Stack Overview:
- 🌐 Frontend: Built with React (Vite) ⚛️,
- 🖥 Backend: Developed in Python 🐍
### 🌐 If you are looking to contribute to frontend (⚛React, Vite):
- The current frontend is being migrated from [`/application`](https://github.com/arc53/DocsGPT/tree/main/application) to [`/frontend`](https://github.com/arc53/DocsGPT/tree/main/frontend) with a new design, so please contribute to the new one.
- Check out this [milestone](https://github.com/arc53/DocsGPT/milestone/1) and its issues.
- The updated Figma design can be found [here](https://www.figma.com/file/OXLtrl1EAy885to6S69554/DocsGPT?node-id=0%3A1&t=hjWVuxRg9yi5YkJ9-1).
Please try to follow the guidelines.
### 🖥 If you are looking to contribute to Backend (🐍 Python):
- Review our issues and contribute to [`/application`](https://github.com/arc53/DocsGPT/tree/main/application) or [`/scripts`](https://github.com/arc53/DocsGPT/tree/main/scripts) (please disregard old [`ingest_rst.py`](https://github.com/arc53/DocsGPT/blob/main/scripts/old/ingest_rst.py) [`ingest_rst_sphinx.py`](https://github.com/arc53/DocsGPT/blob/main/scripts/old/ingest_rst_sphinx.py) files; they will be deprecated soon).
- All new code should be covered with unit tests ([pytest](https://github.com/pytest-dev/pytest)). Please find tests under [`/tests`](https://github.com/arc53/DocsGPT/tree/main/tests) folder.
- Before submitting your Pull Request, ensure it can be queried after ingesting some test data.
### Testing
To run unit tests from the root of the repository, execute:
```
python -m pytest
```
## Workflow 📈
Here's a step-by-step guide on how to contribute to DocsGPT:
1. **Fork the Repository:**
- Click the "Fork" button at the top-right of this repository to create your fork.
2. **Clone the Forked Repository:**
- Clone the repository using:
``` shell
git clone https://github.com/<your-github-username>/DocsGPT.git
```
3. **Keep your Fork in Sync:**
- Before you make any changes, make sure that your fork is in sync to avoid merge conflicts using:
```shell
git remote add upstream https://github.com/arc53/DocsGPT.git
git pull upstream main
```
4. **Create and Switch to a New Branch:**
- Create a new branch for your contribution using:
```shell
git checkout -b your-branch-name
```
5. **Make Changes:**
- Make the required changes in your branch.
6. **Add Changes to the Staging Area:**
- Add your changes to the staging area using:
```shell
git add .
```
7. **Commit Your Changes:**
- Commit your changes with a descriptive commit message using:
```shell
git commit -m "Your descriptive commit message"
```
8. **Push Your Changes to the Remote Repository:**
- Push your branch with changes to your fork on GitHub using:
```shell
git push origin your-branch-name
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.
- Make necessary updates as suggested.
- Once your PR is approved, it will be merged into the main repository.
11. **Testing:**
- Before submitting a Pull Request, ensure your code passes all unit tests.
- To run unit tests from the root of the repository, execute:
```shell
python -m pytest
```
*Note: You should run the unit test only after making the changes to the backend code.*
12. **Questions and Collaboration:**
- Feel free to join our Discord. We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
Thank you for considering contributing to DocsGPT! 🙏
## Questions/collaboration
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
# Thank you so much for considering to contribute DocsGPT!🙏

37
HACKTOBERFEST.md Normal file
View File

@@ -0,0 +1,37 @@
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt and other prizes! 🎉**
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
All contributors with accepted PRs will receive a cool Holopin! 🤩 (Watch out for a reply in your PR to collect it).
### 🏆 Top 50 contributors will recieve a special T-shirt
### 🏆 [LLM Document analysis by LexEU competition](https://github.com/arc53/DocsGPT/blob/main/lexeu-competition.md):
A separate competition is available for those sumbit best new retrieval / workflow method that will analyze a Document using EU laws.
With 200$, 100$, 50$ prize for 1st, 2nd and 3rd place respectively.
You can find more information [here](https://github.com/arc53/DocsGPT/blob/main/lexeu-competition.md)
## 📜 Here's How to Contribute:
```text
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
🧩 API extention: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
Non-Code Contributions:
📚 Wiki: Improve our documentation, Create a guide or change existing documentation.
🖥️ Design: Improve the UI/UX or design a new feature.
📝 Blogging or Content Creation: Write articles or create videos to showcase DocsGPT or highlight your contributions!
```
### 📝 Guidelines for Pull Requests:
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
- Before contributing we highly advise that you check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
- Once you are finished with your contribution, please fill in this [form](https://airtable.com/appikMaJwdHhC1SDP/pagoblCJ9W29wf6Hf/form).
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typo) could earn you a stylish new t-shirt and other prizes as a token of our appreciation. 🎁 Join us, and let's code together! 🚀

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 arc53
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

204
README.md Normal file
View File

@@ -0,0 +1,204 @@
<h1 align="center">
DocsGPT 🦖
</h1>
<p align="center">
<strong>Open-Source Documentation Assistant</strong>
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is a cutting-edge open-source solution that streamlines the process of finding information in the project documentation. With its integration of the powerful <strong>GPT</strong> models, developers can easily ask questions about a project and receive accurate answers.
Say goodbye to time-consuming manual searches, and let <strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> help you quickly find the information you need. Try it out and see how it revolutionizes your project documentation experience. Contribute to its development and be a part of the future of AI-powered assistance.
</p>
<div align="center">
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Stars number](https://img.shields.io/github/stars/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Forks number](https://img.shields.io/github/forks/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE">![link to license file](https://img.shields.io/github/license/arc53/docsgpt)</a>
<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://twitter.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
</div>
### 🎃 [Hacktoberfest Prizes, Rules & Q&A](https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md) 🎃
### Our [Livestream to Dive into Hacktoberfest! Prizes, Rules & Q&A 🎉](https://www.youtube.com/watch?v=5QQaFFu9BC8) on 3rd of October
### Production Support / Help for Companies:
We're eager to provide personalized assistance when deploying your DocsGPT to a live environment.
- [Book Enterprise / teams Demo :wave:](https://cal.com/arc53/docsgpt-demo-b2b?date=2024-09-27&month=2024-09)
- [Send Email :email:](mailto:contact@arc53.com?subject=DocsGPT%20support%2Fsolutions)
![video-example-of-docs-gpt](https://d3dg1063dc54p9.cloudfront.net/videos/demov3.gif)
## Roadmap
You can find our roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
## Our Open-Source Models Optimized for DocsGPT:
| Name | Base Model | Requirements (or similar) |
| --------------------------------------------------------------------- | ----------- | ------------------------- |
| [Docsgpt-7b-mistral](https://huggingface.co/Arc53/docsgpt-7b-mistral) | Mistral-7b | 1xA10G gpu |
| [Docsgpt-14b](https://huggingface.co/Arc53/docsgpt-14b) | llama-2-14b | 2xA10 gpu's |
| [Docsgpt-40b-falcon](https://huggingface.co/Arc53/docsgpt-40b-falcon) | falcon-40b | 8xA10G gpu's |
If you don't have enough resources to run it, you can use bitsnbytes to quantize.
## End to End AI Framework for Information Retrieval
![Architecture chart](https://github.com/user-attachments/assets/fc6a7841-ddfc-45e6-b5a0-d05fe648cbe2)
## Useful Links
- :mag: :fire: [Cloud Version](https://app.docsgpt.cloud/)
- :speech_balloon: :tada: [Join our Discord](https://discord.gg/n5BX8dh8rU)
- :books: :sunglasses: [Guides](https://docs.docsgpt.cloud/)
- :couple: [Interested in contributing?](https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md)
- :file_folder: :rocket: [How to use any other documentation](https://docs.docsgpt.cloud/Guides/How-to-train-on-other-documentation)
- :house: :closed_lock_with_key: [How to host it locally (so all data will stay on-premises)](https://docs.docsgpt.cloud/Guides/How-to-use-different-LLM)
## Project Structure
- Application - Flask app (main application).
- Extensions - Chrome extension.
- Scripts - Script that creates similarity search index for other libraries.
- Frontend - Frontend uses <a href="https://vitejs.dev/">Vite</a> and <a href="https://react.dev/">React</a>.
## QuickStart
> [!Note]
> Make sure you have [Docker](https://docs.docker.com/engine/install/) installed
On Mac OS or Linux, write:
`./setup.sh`
It will install all the dependencies and allow you to download the local model, use OpenAI or use our LLM API.
Otherwise, refer to this Guide for Windows:
1. Download and open this repository with `git clone https://github.com/arc53/DocsGPT.git`
2. Create a `.env` file in your root directory and set the env variables and `VITE_API_STREAMING` to true or false, depending on whether you want streaming answers or not.
It should look like this inside:
```
LLM_NAME=[docsgpt or openai or others]
VITE_API_STREAMING=true
API_KEY=[if LLM_NAME is openai]
```
See optional environment variables in the [/.env-template](https://github.com/arc53/DocsGPT/blob/main/.env-template) and [/application/.env_sample](https://github.com/arc53/DocsGPT/blob/main/application/.env_sample) files.
3. Run [./run-with-docker-compose.sh](https://github.com/arc53/DocsGPT/blob/main/run-with-docker-compose.sh).
4. Navigate to http://localhost:5173/.
To stop, just run `Ctrl + C`.
## Development Environments
### Spin up Mongo and Redis
For development, only two containers are used from [docker-compose.yaml](https://github.com/arc53/DocsGPT/blob/main/docker-compose.yaml) (by deleting all services except for Redis and Mongo).
See file [docker-compose-dev.yaml](./docker-compose-dev.yaml).
Run
```
docker compose -f docker-compose-dev.yaml build
docker compose -f docker-compose-dev.yaml up -d
```
### Run the Backend
> [!Note]
> Make sure you have Python 3.10 or 3.11 installed.
1. Export required environment variables or prepare a `.env` file in the project folder:
- Copy [.env_sample](https://github.com/arc53/DocsGPT/blob/main/application/.env_sample) and create `.env`.
(check out [`application/core/settings.py`](application/core/settings.py) if you want to see more config options.)
2. (optional) Create a Python virtual environment:
You can follow the [Python official documentation](https://docs.python.org/3/tutorial/venv.html) for virtual environments.
a) On Mac OS and Linux
```commandline
python -m venv venv
. venv/bin/activate
```
b) On Windows
```commandline
python -m venv venv
venv/Scripts/activate
```
3. Download embedding model and save it in the `model/` folder:
You can use the script below, or download it manually from [here](https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip), unzip it and save it in the `model/` folder.
```commandline
wget https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip
unzip mpnet-base-v2.zip -d model
rm mpnet-base-v2.zip
```
4. Install dependencies for the backend:
```commandline
pip install -r application/requirements.txt
```
5. Run the app using `flask --app application/app.py run --host=0.0.0.0 --port=7091`.
6. Start worker with `celery -A application.app.celery worker -l INFO`.
### Start Frontend
> [!Note]
> Make sure you have Node version 16 or higher.
1. Navigate to the [/frontend](https://github.com/arc53/DocsGPT/tree/main/frontend) folder.
2. Install the required packages `husky` and `vite` (ignore if already installed).
```commandline
npm install husky -g
npm install vite -g
```
3. Install dependencies by running `npm install --include=dev`.
4. Run the app using `npm run dev`.
## Contributing
Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information about how to get involved. We welcome issues, questions, and pull requests.
## Code Of Conduct
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
## Many Thanks To Our Contributors⚡
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
<img src="https://contrib.rocks/image?repo=arc53/DocsGPT" alt="Contributors" />
</a>
## License
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
Built with [:bird: :link: LangChain](https://github.com/hwchase17/langchain)

BIN
Readme Logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

14
SECURITY.md Normal file
View File

@@ -0,0 +1,14 @@
# Security Policy
## Supported Versions
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on Github.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
security@arc53.com

11
application/.env_sample Normal file
View File

@@ -0,0 +1,11 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

88
application/Dockerfile Normal file
View File

@@ -0,0 +1,88 @@
# Builder Stage
FROM ubuntu:24.04 as builder
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
# Install necessary packages and Python
apt-get update && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.11 python3.11-distutils python3.11-venv && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.11 ]; then \
ln -s /usr/bin/python3.11 /usr/bin/python; \
else \
echo "Python 3.11 not found"; exit 1; \
fi
# Download and unzip the model
RUN wget https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip && \
unzip mpnet-base-v2.zip -d model && \
rm mpnet-base-v2.zip
# Install Rust
RUN wget -q -O - https://sh.rustup.rs | sh -s -- -y
# Clean up to reduce container size
RUN apt-get remove --purge -y wget unzip && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
# Copy requirements.txt
COPY requirements.txt .
# Setup Python virtual environment
RUN python3.11 -m venv /venv
# Activate virtual environment and install Python packages
ENV PATH="/venv/bin:$PATH"
# Install Python packages
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir tiktoken && \
pip install --no-cache-dir -r requirements.txt
# Final Stage
FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
# Install Python
apt-get update && apt-get install -y --no-install-recommends python3.11 && \
ln -s /usr/bin/python3.11 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Create a non-root user: `appuser` (Feel free to choose a name)
RUN groupadd -r appuser && \
useradd -r -g appuser -d /app -s /sbin/nologin -c "Docker image user" appuser
# Copy the virtual environment and model from the builder stage
COPY --from=builder /venv /venv
COPY --from=builder /model /app/model
# Copy your application code
COPY . /app/application
# Change the ownership of the /app directory to the appuser
RUN mkdir -p /app/application/inputs/local
RUN chown -R appuser:appuser /app
# Set environment variables
ENV FLASK_APP=app.py \
FLASK_DEBUG=true \
PATH="/venv/bin:$PATH"
# Expose the port the app runs on
EXPOSE 7091
# Switch to non-root user
USER appuser
# Start Gunicorn
CMD ["gunicorn", "-w", "2", "--timeout", "120", "--bind", "0.0.0.0:7091", "application.wsgi:app"]

0
application/__init__.py Normal file
View File

View File

View File

View File

@@ -0,0 +1,618 @@
import asyncio
import datetime
import json
import logging
import os
import sys
import traceback
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, current_app, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from pymongo import MongoClient
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
user_logs_collection = db["user_logs"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_NAME == "openai":
gpt_model = "gpt-3.5-turbo"
elif settings.LLM_NAME == "anthropic":
gpt_model = "claude-2"
if settings.MODEL_NAME: # in case there is particular model name configured
gpt_model = settings.MODEL_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
async def async_generate(chain, question, chat_history):
result = await chain.arun({"question": question, "chat_history": chat_history})
return result
def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = {}
try:
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
finally:
loop.close()
result["answer"] = answer
return result
def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key})
# # Raise custom exception if the API key is not found
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
if "retriever" in source_doc:
data["retriever"] = source_doc["retriever"]
else:
data["source"] = {}
return data
def get_retriever(source_id: str):
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
def is_azure_configured():
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"sources": source_log_docs,
}
}
},
)
else:
# create new conversation
# generate summary
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system",
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{
"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"sources": source_log_docs,
}
],
}
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == "default":
prompt = chat_combine_template
elif prompt_id == "creative":
prompt = chat_combine_creative
elif prompt_id == "strict":
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
):
try:
response_full = ""
source_log_docs = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
if "text" in source:
source["text"] = source["text"][:100].strip() + "..."
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
"error_exception": str(e),
}
)
yield f"data: {data}\n\n"
return
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"selectedDocs": fields.String(
required=False, description="Selected documents"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
if "selectedDocs" in data and data["selectedDocs"] is None:
chunks = 0
else:
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
current_app.logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
message = e.args[0]
status_code = 400
# Custom exceptions with two arguments, index 1 as status code
if len(e.args) >= 2:
status_code = e.args[1]
return Response(
error_stream_generate(message),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
prompt = get_prompt(prompt_id)
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
current_app.logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
docs = retriever.search()
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
current_app.logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)

View File

View File

@@ -0,0 +1,104 @@
import os
import datetime
from flask import Blueprint, request, send_from_directory
from pymongo import MongoClient
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
from application.core.settings import settings
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
internal = Blueprint("internal", __name__)
@internal.route("/api/download", methods=["get"])
def download_file():
user = secure_filename(request.args.get("user"))
job_name = secure_filename(request.args.get("name"))
filename = secure_filename(request.args.get("file"))
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
return send_from_directory(save_dir, filename, as_attachment=True)
@internal.route("/api/upload_index", methods=["POST"])
def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = secure_filename(request.form["user"])
if "name" not in request.form:
return {"status": "no name"}
job_name = secure_filename(request.form["name"])
tokens = secure_filename(request.form["tokens"])
retriever = secure_filename(request.form["retriever"])
id = secure_filename(request.form["id"])
type = secure_filename(request.form["type"])
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
save_dir = os.path.join(current_dir, "indexes", str(id))
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
print("No file part")
return {"status": "no file"}
file_faiss = request.files["file_faiss"]
if file_faiss.filename == "":
return {"status": "no file name"}
if "file_pkl" not in request.files:
print("No file part")
return {"status": "no file"}
file_pkl = request.files["file_pkl"]
if file_pkl.filename == "":
return {"status": "no file name"}
# saves index files
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_faiss.save(os.path.join(save_dir, "index.faiss"))
file_pkl.save(os.path.join(save_dir, "index.pkl"))
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
sources_collection.update_one(
{"_id": ObjectId(id)},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
}
},
)
else:
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
}
)
return {"status": "ok"}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
from datetime import timedelta
from application.celery_init import celery
from application.worker import ingest_worker, remote_worker, sync_worker
@celery.task(bind=True)
def ingest(self, directory, formats, name_job, filename, user):
resp = ingest_worker(self, directory, formats, name_job, filename, user)
return resp
@celery.task(bind=True)
def ingest_remote(self, source_data, job_name, user, loader):
resp = remote_worker(self, source_data, job_name, user, loader)
return resp
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
return resp
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
sender.add_periodic_task(
timedelta(days=1),
schedule_syncs.s("daily"),
)
sender.add_periodic_task(
timedelta(weeks=1),
schedule_syncs.s("weekly"),
)
sender.add_periodic_task(
timedelta(days=30),
schedule_syncs.s("monthly"),
)

53
application/app.py Normal file
View File

@@ -0,0 +1,53 @@
import platform
import dotenv
from flask import Flask, redirect, request
from application.api.answer.routes import answer
from application.api.internal.routes import internal
from application.api.user.routes import user
from application.celery_init import celery
from application.core.logging_config import setup_logging
from application.core.settings import settings
from application.extensions import api
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
setup_logging()
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
CELERY_RESULT_BACKEND=settings.CELERY_RESULT_BACKEND,
MONGO_URI=settings.MONGO_URI,
)
celery.config_from_object("application.celeryconfig")
api.init_app(app)
@app.route("/")
def home():
if request.remote_addr in ("0.0.0.0", "127.0.0.1", "localhost", "172.18.0.1"):
return redirect("http://localhost:5173")
else:
return "Welcome to DocsGPT Backend!"
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
return response
if __name__ == "__main__":
app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)

View File

@@ -0,0 +1,15 @@
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
def make_celery(app_name=__name__):
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
celery.conf.update(settings)
return celery
@setup_logging.connect
def config_loggers(*args, **kwargs):
from application.core.logging_config import setup_logging
setup_logging()
celery = make_celery()

View File

@@ -0,0 +1,8 @@
import os
broker_url = os.getenv("CELERY_BROKER_URL")
result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']

View File

View File

@@ -0,0 +1,22 @@
from logging.config import dictConfig
def setup_logging():
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"formatter": "default",
}
},
'root': {
'level': 'INFO',
'handlers': ['console'],
},
})

View File

@@ -0,0 +1,76 @@
from pathlib import Path
from typing import Optional
import os
from pydantic_settings import BaseSettings
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
ELASTIC_USERNAME: Optional[str] = None # username for elasticsearch
ELASTIC_PASSWORD: Optional[str] = None # password for elasticsearch
ELASTIC_URL: Optional[str] = None # url for elasticsearch
ELASTIC_INDEX: Optional[str] = "docsgpt" # index name for elasticsearch
# SageMaker config
SAGEMAKER_ENDPOINT: Optional[str] = None # SageMaker endpoint name
SAGEMAKER_REGION: Optional[str] = None # SageMaker region name
SAGEMAKER_ACCESS_KEY: Optional[str] = None # SageMaker access key
SAGEMAKER_SECRET_KEY: Optional[str] = None # SageMaker secret key
# prem ai project id
PREMAI_PROJECT_ID: Optional[str] = None
# Qdrant vectorstore config
QDRANT_COLLECTION_NAME: Optional[str] = "docsgpt"
QDRANT_LOCATION: Optional[str] = None
QDRANT_URL: Optional[str] = None
QDRANT_PORT: Optional[int] = 6333
QDRANT_GRPC_PORT: int = 6334
QDRANT_PREFER_GRPC: bool = False
QDRANT_HTTPS: Optional[bool] = None
QDRANT_API_KEY: Optional[str] = None
QDRANT_PREFIX: Optional[str] = None
QDRANT_TIMEOUT: Optional[float] = None
QDRANT_HOST: Optional[str] = None
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

15
application/error.py Normal file
View File

@@ -0,0 +1,15 @@
from flask import jsonify
from werkzeug.http import HTTP_STATUS_CODES
def response_error(code_status, message=None):
payload = {'error': HTTP_STATUS_CODES.get(code_status, "something went wrong")}
if message:
payload['message'] = message
response = jsonify(payload)
response.status_code = code_status
return response
def bad_request(status_code=400, message=''):
return response_error(code_status=status_code, message=message)

View File

@@ -0,0 +1,7 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

BIN
application/index.faiss Normal file

Binary file not shown.

BIN
application/index.pkl Normal file

Binary file not shown.

View File

View File

@@ -0,0 +1,50 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
class AnthropicLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
super().__init__(*args, **kwargs)
self.api_key = (
api_key or settings.ANTHROPIC_API_KEY
) # If not provided, use a default from settings
self.user_api_key = user_api_key
self.anthropic = Anthropic(api_key=self.api_key)
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
if stream:
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
completion = self.anthropic.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
stream=stream,
prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
)
return completion.completion
def _raw_gen_stream(
self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
stream_response = self.anthropic.completions.create(
model=model,
prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
max_tokens_to_sample=max_tokens,
stream=True,
)
for completion in stream_response:
yield completion.completion

28
application/llm/base.py Normal file
View File

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from application.usage import gen_token_usage, stream_token_usage
class BaseLLM(ABC):
def __init__(self):
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
def _apply_decorator(self, method, decorator, *args, **kwargs):
return decorator(method, *args, **kwargs)
@abstractmethod
def _raw_gen(self, model, messages, stream, *args, **kwargs):
pass
def gen(self, model, messages, stream=False, *args, **kwargs):
return self._apply_decorator(self._raw_gen, gen_token_usage)(
self, model=model, messages=messages, stream=stream, *args, **kwargs
)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
pass
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
self, model=model, messages=messages, stream=stream, *args, **kwargs
)

View File

@@ -0,0 +1,44 @@
from application.llm.base import BaseLLM
import json
import requests
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
self.endpoint = "https://llm.docsgpt.co.uk"
def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
response = requests.post(
f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
)
response_clean = response.json()["a"].replace("###", "")
return response_clean
def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
# send prompt to endpoint /stream
response = requests.post(
f"{self.endpoint}/stream",
json={"prompt": prompt, "max_new_tokens": 256},
stream=True,
)
for line in response.iter_lines():
if line:
# data = json.loads(line)
data_str = line.decode("utf-8")
if data_str.startswith("data: "):
data = json.loads(data_str[6:])
yield data["a"]

View File

@@ -0,0 +1,68 @@
from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name="Arc53/DocsGPT-7B",
q=False,
*args,
**kwargs,
):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
)
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
llm_name, quantization_config=bnb_config
)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=2000,
device_map="auto",
eos_token_id=tokenizer.eos_token_id,
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = hf(prompt)
return result.content
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

View File

@@ -0,0 +1,55 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
import threading
class LlamaSingleton:
_instances = {}
_lock = threading.Lock() # Add a lock for thread synchronization
@classmethod
def get_instance(cls, llm_name):
if llm_name not in cls._instances:
try:
from llama_cpp import Llama
except ImportError:
raise ImportError(
"Please install llama_cpp using pip install llama-cpp-python"
)
cls._instances[llm_name] = Llama(model_path=llm_name, n_ctx=2048)
return cls._instances[llm_name]
@classmethod
def query_model(cls, llm, prompt, **kwargs):
with cls._lock:
return llm(prompt, **kwargs)
class LlamaCpp(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name=settings.MODEL_PATH,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
self.llama = LlamaSingleton.get_instance(llm_name)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
return result["choices"][0]["text"].split("### Answer \n")[-1]
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
for item in result:
for choice in item["choices"]:
yield choice["text"]

View File

@@ -0,0 +1,27 @@
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
class LLMCreator:
llms = {
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"huggingface": HuggingFaceLLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(api_key, user_api_key, *args, **kwargs)

73
application/llm/openai.py Normal file
View File

@@ -0,0 +1,73 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
if settings.OPENAI_BASE_URL:
self.client = OpenAI(
api_key=api_key,
base_url=settings.OPENAI_BASE_URL
)
else:
self.client = OpenAI(api_key=api_key)
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
class AzureOpenAILLM(OpenAILLM):
def __init__(
self, openai_api_key, openai_api_base, openai_api_version, deployment_name
):
super().__init__(openai_api_key)
self.api_base = (settings.OPENAI_API_BASE,)
self.api_version = (settings.OPENAI_API_VERSION,)
self.deployment_name = (settings.AZURE_DEPLOYMENT_NAME,)
from openai import AzureOpenAI
self.client = AzureOpenAI(
api_key=openai_api_key,
api_version=settings.OPENAI_API_VERSION,
api_base=settings.OPENAI_API_BASE,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)

38
application/llm/premai.py Normal file
View File

@@ -0,0 +1,38 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
class PremAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from premai import Prem
super().__init__(*args, **kwargs)
self.client = Prem(api_key=api_key)
self.api_key = api_key
self.user_api_key = user_api_key
self.project_id = settings.PREMAI_PROJECT_ID
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create(
model=model,
project_id=self.project_id,
messages=messages,
stream=stream,
**kwargs
)
return response.choices[0].message["content"]
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create(
model=model,
project_id=self.project_id,
messages=messages,
stream=stream,
**kwargs
)
for line in response:
if line.choices[0].delta["content"] is not None:
yield line.choices[0].delta["content"]

View File

@@ -0,0 +1,140 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
import json
import io
class LineIterator:
"""
A helper class for parsing the byte stream input.
The output of the model will be in the following format:
```
b'{"outputs": [" a"]}\n'
b'{"outputs": [" challenging"]}\n'
b'{"outputs": [" problem"]}\n'
...
```
While usually each PayloadPart event from the event stream will contain a byte array
with a full json, this is not guaranteed and some of the json objects may be split across
PayloadPart events. For example:
```
{'PayloadPart': {'Bytes': b'{"outputs": '}}
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
```
This class accounts for this by concatenating bytes written via the 'write' function
and then exposing a method which will return lines (ending with a '\n' character) within
the buffer via the 'scan_lines' function. It maintains the position of the last read
position to ensure that previous bytes are not exposed again.
"""
def __init__(self, stream):
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self):
return self
def __next__(self):
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
try:
chunk = next(self.byte_iterator)
except StopIteration:
if self.read_pos < self.buffer.getbuffer().nbytes:
continue
raise
if "PayloadPart" not in chunk:
print("Unknown event type:" + chunk)
continue
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
class SagemakerAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
import boto3
runtime = boto3.client(
"runtime.sagemaker",
aws_access_key_id="xxx",
aws_secret_access_key="xxx",
region_name="us-west-2",
)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
self.endpoint = settings.SAGEMAKER_ENDPOINT
self.runtime = runtime
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
# Construct payload for endpoint
payload = {
"inputs": prompt,
"stream": False,
"parameters": {
"do_sample": True,
"temperature": 0.1,
"max_new_tokens": 30,
"repetition_penalty": 1.03,
"stop": ["</s>", "###"],
},
}
body_bytes = json.dumps(payload).encode("utf-8")
# Invoke the endpoint
response = self.runtime.invoke_endpoint(
EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes
)
result = json.loads(response["Body"].read().decode())
import sys
print(result[0]["generated_text"], file=sys.stderr)
return result[0]["generated_text"][len(prompt) :]
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
# Construct payload for endpoint
payload = {
"inputs": prompt,
"stream": True,
"parameters": {
"do_sample": True,
"temperature": 0.1,
"max_new_tokens": 512,
"repetition_penalty": 1.03,
"stop": ["</s>", "###"],
},
}
body_bytes = json.dumps(payload).encode("utf-8")
# Invoke the endpoint
response = self.runtime.invoke_endpoint_with_response_stream(
EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes
)
# result = json.loads(response['Body'].read().decode())
event_stream = response["Body"]
start_json = b"{"
for line in LineIterator(event_stream):
if line != b"" and start_json in line:
# print(line)
data = json.loads(line[line.find(start_json) :].decode("utf-8"))
if data["token"]["text"] not in ["</s>", "###"]:
print(data["token"]["text"], end="")
yield data["token"]["text"]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,19 @@
"""Base reader class."""
from abc import abstractmethod
from typing import Any, List
from langchain.docstore.document import Document as LCDocument
from application.parser.schema.base import Document
class BaseReader:
"""Utilities for loading data from a directory."""
@abstractmethod
def load_data(self, *args: Any, **load_kwargs: Any) -> List[Document]:
"""Load data from the input directory."""
def load_langchain_documents(self, **load_kwargs: Any) -> List[LCDocument]:
"""Load data in LangChain document format."""
docs = self.load_data(**load_kwargs)
return [d.to_langchain_format() for d in docs]

View File

@@ -0,0 +1,38 @@
"""Base parser and config class."""
from abc import abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
class BaseParser:
"""Base class for all parsers."""
def __init__(self, parser_config: Optional[Dict] = None):
"""Init params."""
self._parser_config = parser_config
def init_parser(self) -> None:
"""Init parser and store it."""
parser_config = self._init_parser()
self._parser_config = parser_config
@property
def parser_config_set(self) -> bool:
"""Check if parser config is set."""
return self._parser_config is not None
@property
def parser_config(self) -> Dict:
"""Check if parser config is set."""
if self._parser_config is None:
raise ValueError("Parser config not set.")
return self._parser_config
@abstractmethod
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
@abstractmethod
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file."""

View File

@@ -0,0 +1,174 @@
"""Simple reader that reads files of different formats from a directory."""
import logging
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from application.parser.file.base import BaseReader
from application.parser.file.base_parser import BaseParser
from application.parser.file.docs_parser import DocxParser, PDFParser
from application.parser.file.epub_parser import EpubParser
from application.parser.file.html_parser import HTMLParser
from application.parser.file.markdown_parser import MarkdownParser
from application.parser.file.rst_parser import RstParser
from application.parser.file.tabular_parser import PandasCSVParser
from application.parser.schema.base import Document
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
".docx": DocxParser(),
".csv": PandasCSVParser(),
".epub": EpubParser(),
".md": MarkdownParser(),
".rst": RstParser(),
".html": HTMLParser(),
".mdx": MarkdownParser(),
}
class SimpleDirectoryReader(BaseReader):
"""Simple directory reader.
Can read files into separate documents, or concatenates
files into one document text.
Args:
input_dir (str): Path to the directory.
input_files (List): List of file paths to read (Optional; overrides input_dir)
exclude_hidden (bool): Whether to exclude hidden files (dotfiles).
errors (str): how encoding and decoding errors are to be handled,
see https://docs.python.org/3/library/functions.html#open
recursive (bool): Whether to recursively search in subdirectories.
False by default.
required_exts (Optional[List[str]]): List of required extensions.
Default is None.
file_extractor (Optional[Dict[str, BaseParser]]): A mapping of file
extension to a BaseParser class that specifies how to convert that file
to text. See DEFAULT_FILE_EXTRACTOR.
num_files_limit (Optional[int]): Maximum number of files to read.
Default is None.
file_metadata (Optional[Callable[str, Dict]]): A function that takes
in a filename and returns a Dict of metadata for the Document.
Default is None.
"""
def __init__(
self,
input_dir: Optional[str] = None,
input_files: Optional[List] = None,
exclude_hidden: bool = True,
errors: str = "ignore",
recursive: bool = True,
required_exts: Optional[List[str]] = None,
file_extractor: Optional[Dict[str, BaseParser]] = None,
num_files_limit: Optional[int] = None,
file_metadata: Optional[Callable[[str], Dict]] = None,
) -> None:
"""Initialize with parameters."""
super().__init__()
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
self.errors = errors
self.recursive = recursive
self.exclude_hidden = exclude_hidden
self.required_exts = required_exts
self.num_files_limit = num_files_limit
if input_files:
self.input_files = []
for path in input_files:
print(path)
input_file = Path(path)
self.input_files.append(input_file)
elif input_dir:
self.input_dir = Path(input_dir)
self.input_files = self._add_files(self.input_dir)
self.file_extractor = file_extractor or DEFAULT_FILE_EXTRACTOR
self.file_metadata = file_metadata
def _add_files(self, input_dir: Path) -> List[Path]:
"""Add files."""
input_files = sorted(input_dir.iterdir())
new_input_files = []
dirs_to_explore = []
for input_file in input_files:
if input_file.is_dir():
if self.recursive:
dirs_to_explore.append(input_file)
elif self.exclude_hidden and input_file.name.startswith("."):
continue
elif (
self.required_exts is not None
and input_file.suffix not in self.required_exts
):
continue
else:
new_input_files.append(input_file)
for dir_to_explore in dirs_to_explore:
sub_input_files = self._add_files(dir_to_explore)
new_input_files.extend(sub_input_files)
if self.num_files_limit is not None and self.num_files_limit > 0:
new_input_files = new_input_files[0: self.num_files_limit]
# print total number of files added
logging.debug(
f"> [SimpleDirectoryReader] Total files added: {len(new_input_files)}"
)
return new_input_files
def load_data(self, concatenate: bool = False) -> List[Document]:
"""Load data from the input directory.
Args:
concatenate (bool): whether to concatenate all files into one document.
If set to True, file metadata is ignored.
False by default.
Returns:
List[Document]: A list of documents.
"""
data: Union[str, List[str]] = ""
data_list: List[str] = []
metadata_list = []
for input_file in self.input_files:
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
else:
# do standard read
with open(input_file, "r", errors=self.errors) as f:
data = f.read()
# Prepare metadata for this file
if self.file_metadata is not None:
file_metadata = self.file_metadata(str(input_file))
else:
# Provide a default empty metadata
file_metadata = {'title': '', 'store': ''}
# TODO: Find a case with no metadata and check if breaks anything
if isinstance(data, List):
# Extend data_list with each item in the data list
data_list.extend([str(d) for d in data])
# For each item in the data list, add the file's metadata to metadata_list
metadata_list.extend([file_metadata for _ in data])
else:
# Add the single piece of data to data_list
data_list.append(str(data))
# Add the file's metadata to metadata_list
metadata_list.append(file_metadata)
if concatenate:
return [Document("\n".join(data_list))]
elif self.file_metadata is not None:
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
else:
return [Document(d) for d in data_list]

View File

@@ -0,0 +1,59 @@
"""Docs parser.
Contains parsers for docx, pdf files.
"""
from pathlib import Path
from typing import Dict
from application.parser.file.base_parser import BaseParser
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import PyPDF2
except ImportError:
raise ValueError("PyPDF2 is required to read PDF files.")
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PyPDF2.PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
return text
class DocxParser(BaseParser):
"""Docx parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import docx2txt
except ImportError:
raise ValueError("docx2txt is required to read Microsoft Word files.")
text = docx2txt.process(file)
return text

View File

@@ -0,0 +1,43 @@
"""Epub parser.
Contains parsers for epub files.
"""
from pathlib import Path
from typing import Dict
from application.parser.file.base_parser import BaseParser
class EpubParser(BaseParser):
"""Epub Parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import ebooklib
from ebooklib import epub
except ImportError:
raise ValueError("`EbookLib` is required to read Epub files.")
try:
import html2text
except ImportError:
raise ValueError("`html2text` is required to parse Epub files.")
text_list = []
book = epub.read_epub(file, options={"ignore_ncx": True})
# Iterate through all chapters.
for item in book.get_items():
# Chapters are typically located in epub documents items.
if item.get_type() == ebooklib.ITEM_DOCUMENT:
text_list.append(
html2text.html2text(item.get_content().decode("utf-8"))
)
text = "\n".join(text_list)
return text

View File

@@ -0,0 +1,24 @@
"""HTML parser.
Contains parser for html files.
"""
from pathlib import Path
from typing import Dict, Union
from application.parser.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
from langchain_community.document_loaders import BSHTMLLoader
loader = BSHTMLLoader(file)
data = loader.load()
return data

View File

@@ -0,0 +1,145 @@
"""Markdown parser.
Contains parser for md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import tiktoken
from application.parser.file.base_parser import BaseParser
class MarkdownParser(BaseParser):
"""Markdown parser.
Extract text from markdown files.
Returns dictionary with keys as headers and values as the text between headers.
"""
def __init__(
self,
*args: Any,
remove_hyperlinks: bool = True,
remove_images: bool = True,
max_tokens: int = 2048,
# remove_tables: bool = True,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._max_tokens = max_tokens
# self._remove_tables = remove_tables
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str],
current_text: str):
"""Append to tups chunk."""
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
if num_tokens > self._max_tokens:
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
for chunk in chunks:
tups.append((current_header, chunk))
else:
tups.append((current_header, current_text))
return tups
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
markdown_tups: List[Tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")
current_header = None
current_text = ""
for line in lines:
header_match = re.match(r"^#+\s", line)
if header_match:
if current_header is not None:
if current_text == "" or None:
continue
markdown_tups = self.tups_chunk_append(markdown_tups, current_header, current_text)
current_header = line
current_text = ""
else:
current_text += line + "\n"
markdown_tups = self.tups_chunk_append(markdown_tups, current_header, current_text)
if current_header is not None:
# pass linting, assert keys are defined
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
else:
markdown_tups = [
(key, re.sub("\n", "", value)) for key, value in markdown_tups
]
return markdown_tups
def remove_images(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"!{1}\[\[(.*)\]\]"
content = re.sub(pattern, "", content)
return content
# def remove_tables(self, content: str) -> List[List[str]]:
# """Convert markdown tables to nested lists."""
# table_rows_pattern = r"((\r?\n){2}|^)([^\r\n]*\|[^\r\n]*(\r?\n)?)+(?=(\r?\n){2}|$)"
# table_cells_pattern = r"([^\|\r\n]*)\|"
#
# table_rows = re.findall(table_rows_pattern, content, re.MULTILINE)
# table_lists = []
# for row in table_rows:
# cells = re.findall(table_cells_pattern, row[2])
# cells = [cell.strip() for cell in cells if cell.strip()]
# table_lists.append(cells)
# return str(table_lists)
def remove_hyperlinks(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"\[(.*?)\]\((.*?)\)"
content = re.sub(pattern, r"\1", content)
return content
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r") as f:
content = f.read()
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
# if self._remove_tables:
# content = self.remove_tables(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file(
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results

View File

@@ -0,0 +1,51 @@
from urllib.parse import urlparse
from openapi_parser import parse
try:
from application.parser.file.base_parser import BaseParser
except ModuleNotFoundError:
from base_parser import BaseParser
class OpenAPI3Parser(BaseParser):
def init_parser(self) -> None:
return super().init_parser()
def get_base_urls(self, urls):
base_urls = []
for i in urls:
parsed_url = urlparse(i)
base_url = parsed_url.scheme + "://" + parsed_url.netloc
if base_url not in base_urls:
base_urls.append(base_url)
return base_urls
def get_info_from_paths(self, path):
info = ""
if path.operations:
for operation in path.operations:
info += (
f"\n{operation.method.value}="
f"{operation.responses[0].description}"
)
return info
def parse_file(self, file_path):
data = parse(file_path)
results = ""
base_urls = self.get_base_urls(link.url for link in data.servers)
base_urls = ",".join([base_url for base_url in base_urls])
results += f"Base URL:{base_urls}\n"
i = 1
for path in data.paths:
info = self.get_info_from_paths(path)
results += (
f"Path{i}: {path.url}\n"
f"description: {path.description}\n"
f"parameters: {path.parameters}\nmethods: {info}\n"
)
i += 1
with open("results.txt", "w") as f:
f.write(results)
return results

View File

@@ -0,0 +1,173 @@
"""reStructuredText parser.
Contains parser for md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from application.parser.file.base_parser import BaseParser
class RstParser(BaseParser):
"""reStructuredText parser.
Extract text from .rst files.
Returns dictionary with keys as headers and values as the text between headers.
"""
def __init__(
self,
*args: Any,
remove_hyperlinks: bool = True,
remove_images: bool = True,
remove_table_excess: bool = True,
remove_interpreters: bool = True,
remove_directives: bool = True,
remove_whitespaces_excess: bool = True,
# Be careful with remove_characters_excess, might cause data loss
remove_characters_excess: bool = True,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._remove_table_excess = remove_table_excess
self._remove_interpreters = remove_interpreters
self._remove_directives = remove_directives
self._remove_whitespaces_excess = remove_whitespaces_excess
self._remove_characters_excess = remove_characters_excess
def rst_to_tups(self, rst_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a reStructuredText file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
rst_tups: List[Tuple[Optional[str], str]] = []
lines = rst_text.split("\n")
current_header = None
current_text = ""
for i, line in enumerate(lines):
header_match = re.match(r"^[^\S\n]*[-=]+[^\S\n]*$", line)
if header_match and i > 0 and (
len(lines[i - 1].strip()) == len(header_match.group().strip()) or lines[i - 2] == lines[i - 2]):
if current_header is not None:
if current_text == "" or None:
continue
# removes the next heading from current Document
if current_text.endswith(lines[i - 1] + "\n"):
current_text = current_text[:len(current_text) - len(lines[i - 1] + "\n")]
rst_tups.append((current_header, current_text))
current_header = lines[i - 1]
current_text = ""
else:
current_text += line + "\n"
rst_tups.append((current_header, current_text))
# TODO: Format for rst
#
# if current_header is not None:
# # pass linting, assert keys are defined
# rst_tups = [
# (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
# for key, value in rst_tups
# ]
# else:
# rst_tups = [
# (key, re.sub("\n", "", value)) for key, value in rst_tups
# ]
if current_header is None:
rst_tups = [
(key, re.sub("\n", "", value)) for key, value in rst_tups
]
return rst_tups
def remove_images(self, content: str) -> str:
pattern = r"\.\. image:: (.*)"
content = re.sub(pattern, "", content)
return content
def remove_hyperlinks(self, content: str) -> str:
pattern = r"`(.*?) <(.*?)>`_"
content = re.sub(pattern, r"\1", content)
return content
def remove_directives(self, content: str) -> str:
"""Removes reStructuredText Directives"""
pattern = r"`\.\.([^:]+)::"
content = re.sub(pattern, "", content)
return content
def remove_interpreters(self, content: str) -> str:
"""Removes reStructuredText Interpreted Text Roles"""
pattern = r":(\w+):"
content = re.sub(pattern, "", content)
return content
def remove_table_excess(self, content: str) -> str:
"""Pattern to remove grid table separators"""
pattern = r"^\+[-]+\+[-]+\+$"
content = re.sub(pattern, "", content, flags=re.MULTILINE)
return content
def remove_whitespaces_excess(self, content: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
"""Pattern to match 2 or more consecutive whitespaces"""
pattern = r"\s{2,}"
content = [(key, re.sub(pattern, " ", value)) for key, value in content]
return content
def remove_characters_excess(self, content: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
"""Pattern to match 2 or more consecutive characters"""
pattern = r"(\S)\1{2,}"
content = [(key, re.sub(pattern, r"\1\1\1", value, flags=re.MULTILINE)) for key, value in content]
return content
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r") as f:
content = f.read()
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
if self._remove_table_excess:
content = self.remove_table_excess(content)
if self._remove_directives:
content = self.remove_directives(content)
if self._remove_interpreters:
content = self.remove_interpreters(content)
rst_tups = self.rst_to_tups(content)
if self._remove_whitespaces_excess:
rst_tups = self.remove_whitespaces_excess(rst_tups)
if self._remove_characters_excess:
rst_tups = self.remove_characters_excess(rst_tups)
return rst_tups
def parse_file(
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results

View File

@@ -0,0 +1,115 @@
"""Tabular parser.
Contains parsers for tabular data files.
"""
from pathlib import Path
from typing import Any, Dict, List, Union
from application.parser.file.base_parser import BaseParser
class CSVParser(BaseParser):
"""CSV parser.
Args:
concat_rows (bool): whether to concatenate all rows into one document.
If set to False, a Document will be created for each row.
True by default.
"""
def __init__(self, *args: Any, concat_rows: bool = True, **kwargs: Any) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._concat_rows = concat_rows
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file.
Returns:
Union[str, List[str]]: a string or a List of strings.
"""
try:
import csv
except ImportError:
raise ValueError("csv module is required to read CSV files.")
text_list = []
with open(file, "r") as fp:
csv_reader = csv.reader(fp)
for row in csv_reader:
text_list.append(", ".join(row))
if self._concat_rows:
return "\n".join(text_list)
else:
return text_list
class PandasCSVParser(BaseParser):
r"""Pandas-based CSV parser.
Parses CSVs using the separator detection from Pandas `read_csv`function.
If special parameters are required, use the `pandas_config` dict.
Args:
concat_rows (bool): whether to concatenate all rows into one document.
If set to False, a Document will be created for each row.
True by default.
col_joiner (str): Separator to use for joining cols per row.
Set to ", " by default.
row_joiner (str): Separator to use for joining each row.
Only used when `concat_rows=True`.
Set to "\n" by default.
pandas_config (dict): Options for the `pandas.read_csv` function call.
Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html
for more information.
Set to empty dict by default, this means pandas will try to figure
out the separators, table head, etc. on its own.
"""
def __init__(
self,
*args: Any,
concat_rows: bool = True,
col_joiner: str = ", ",
row_joiner: str = "\n",
pandas_config: dict = {},
**kwargs: Any
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._concat_rows = concat_rows
self._col_joiner = col_joiner
self._row_joiner = row_joiner
self._pandas_config = pandas_config
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file."""
try:
import pandas as pd
except ImportError:
raise ValueError("pandas module is required to read CSV files.")
df = pd.read_csv(file, **self._pandas_config)
text_list = df.apply(
lambda row: (self._col_joiner).join(row.astype(str).tolist()), axis=1
).tolist()
if self._concat_rows:
return (self._row_joiner).join(text_list)
else:
return text_list

View File

@@ -0,0 +1,66 @@
import os
import javalang
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.java'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, "r") as file:
java_code = file.read()
methods = {}
tree = javalang.parse.parse(java_code)
for _, node in tree.filter(javalang.tree.MethodDeclaration):
method_name = node.name
start_line = node.position.line - 1
end_line = start_line
brace_count = 0
for line in java_code.splitlines()[start_line:]:
end_line += 1
brace_count += line.count("{") - line.count("}")
if brace_count == 0:
break
method_source_code = "\n".join(java_code.splitlines()[start_line:end_line])
methods[method_name] = method_source_code
return methods
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = javalang.parse.parse(source_code)
for class_decl in tree.types:
class_name = class_decl.name
declarations = []
methods = []
for field_decl in class_decl.fields:
field_name = field_decl.declarators[0].name
field_type = field_decl.type.name
declarations.append(f"{field_type} {field_name}")
for method_decl in class_decl.methods:
methods.append(method_decl.name)
class_string = "Declarations: " + ", ".join(declarations) + "\n Method name: " + ", ".join(methods)
classes[class_name] = class_string
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict

View File

@@ -0,0 +1,70 @@
import os
import escodegen
import esprima
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.js'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
functions = {}
tree = esprima.parseScript(source_code)
for node in tree.body:
if node.type == 'FunctionDeclaration':
func_name = node.id.name if node.id else '<anonymous>'
functions[func_name] = escodegen.generate(node)
elif node.type == 'VariableDeclaration':
for declaration in node.declarations:
if declaration.init and declaration.init.type == 'FunctionExpression':
func_name = declaration.id.name if declaration.id else '<anonymous>'
functions[func_name] = escodegen.generate(declaration.init)
elif node.type == 'ClassDeclaration':
for subnode in node.body.body:
if subnode.type == 'MethodDefinition':
func_name = subnode.key.name
functions[func_name] = escodegen.generate(subnode.value)
elif subnode.type == 'VariableDeclaration':
for declaration in subnode.declarations:
if declaration.init and declaration.init.type == 'FunctionExpression':
func_name = declaration.id.name if declaration.id else '<anonymous>'
functions[func_name] = escodegen.generate(declaration.init)
return functions
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = esprima.parseScript(source_code)
for node in tree.body:
if node.type == 'ClassDeclaration':
class_name = node.id.name
function_names = []
for subnode in node.body.body:
if subnode.type == 'MethodDefinition':
function_names.append(subnode.key.name)
classes[class_name] = ", ".join(function_names)
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict

View File

@@ -0,0 +1,75 @@
import os
from retry import retry
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.embeddings import HuggingFaceInstructEmbeddings
# from langchain_community.embeddings import CohereEmbeddings
@retry(tries=10, delay=60)
def store_add_texts_with_retry(store, i, id):
# add source_id to the metadata
i.metadata["source_id"] = str(id)
store.add_texts([i.page_content], metadatas=[i.metadata])
# store_pine.add_texts([i.page_content], metadatas=[i.metadata])
def call_openai_api(docs, folder_name, id, task_status):
# Function to create a vector store from the documents and save it to disk
if not os.path.exists(f"{folder_name}"):
os.makedirs(f"{folder_name}")
from tqdm import tqdm
c1 = 0
if settings.VECTOR_STORE == "faiss":
docs_init = [docs[0]]
docs.pop(0)
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=f"{folder_name}",
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=str(id),
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
store.delete_index()
# Uncomment for MPNet embeddings
# model_name = "sentence-transformers/all-mpnet-base-v2"
# hf = HuggingFaceEmbeddings(model_name=model_name)
# store = FAISS.from_documents(docs_test, hf)
s1 = len(docs)
for i in tqdm(
docs,
desc="Embedding 🦖",
unit="docs",
total=len(docs),
bar_format="{l_bar}{bar}| Time Left: {remaining}",
):
try:
task_status.update_state(
state="PROGRESS", meta={"current": int((c1 / s1) * 100)}
)
store_add_texts_with_retry(store, i, id)
except Exception as e:
print(e)
print("Error on ", i)
print("Saving progress")
print(f"stopped at {c1} out of {len(docs)}")
store.save_local(f"{folder_name}")
break
c1 += 1
if settings.VECTOR_STORE == "faiss":
store.save_local(f"{folder_name}")

View File

@@ -0,0 +1,121 @@
import ast
import os
from pathlib import Path
import tiktoken
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.py'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
functions = {}
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_name = node.name
func_def = ast.get_source_segment(source_code, node)
functions[func_name] = func_def
return functions
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_name = node.name
function_names = []
for subnode in ast.walk(node):
if isinstance(subnode, ast.FunctionDef):
function_names.append(subnode.name)
classes[class_name] = ", ".join(function_names)
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict
def parse_functions(functions_dict, formats, dir):
c1 = len(functions_dict)
for i, (source, functions) in enumerate(functions_dict.items(), start=1):
print(f"Processing file {i}/{c1}")
source_w = source.replace(dir + "/", "").replace("." + formats, ".md")
subfolders = "/".join(source_w.split("/")[:-1])
Path(f"outputs/{subfolders}").mkdir(parents=True, exist_ok=True)
for j, (name, function) in enumerate(functions.items(), start=1):
print(f"Processing function {j}/{len(functions)}")
prompt = PromptTemplate(
input_variables=["code"],
template="Code: \n{code}, \nDocumentation: ",
)
llm = OpenAI(temperature=0)
response = llm(prompt.format(code=function))
mode = "a" if Path(f"outputs/{source_w}").exists() else "w"
with open(f"outputs/{source_w}", mode) as f:
f.write(
f"\n\n# Function name: {name} \n\nFunction: \n```\n{function}\n```, \nDocumentation: \n{response}")
def parse_classes(classes_dict, formats, dir):
c1 = len(classes_dict)
for i, (source, classes) in enumerate(classes_dict.items()):
print(f"Processing file {i + 1}/{c1}")
source_w = source.replace(dir + "/", "").replace("." + formats, ".md")
subfolders = "/".join(source_w.split("/")[:-1])
Path(f"outputs/{subfolders}").mkdir(parents=True, exist_ok=True)
for name, function_names in classes.items():
print(f"Processing Class {i + 1}/{c1}")
prompt = PromptTemplate(
input_variables=["class_name", "functions_names"],
template="Class name: {class_name} \nFunctions: {functions_names}, \nDocumentation: ",
)
llm = OpenAI(temperature=0)
response = llm(prompt.format(class_name=name, functions_names=function_names))
with open(f"outputs/{source_w}", "a" if Path(f"outputs/{source_w}").exists() else "w") as f:
f.write(f"\n\n# Class name: {name} \n\nFunctions: \n{function_names}, \nDocumentation: \n{response}")
def transform_to_docs(functions_dict, classes_dict, formats, dir):
docs_content = ''.join([str(key) + str(value) for key, value in functions_dict.items()])
docs_content += ''.join([str(key) + str(value) for key, value in classes_dict.items()])
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(docs_content))
total_price = ((num_tokens / 1000) * 0.02)
print(f"Number of Tokens = {num_tokens:,d}")
print(f"Approx Cost = ${total_price:,.2f}")
user_input = input("Price Okay? (Y/N)\n").lower()
if user_input == "y" or user_input == "":
if not Path("outputs").exists():
Path("outputs").mkdir()
parse_functions(functions_dict, formats, dir)
parse_classes(classes_dict, formats, dir)
print("All done!")
else:
print("The API was not called. No money was spent.")

View File

@@ -0,0 +1,19 @@
"""Base reader class."""
from abc import abstractmethod
from typing import Any, List
from langchain.docstore.document import Document as LCDocument
from application.parser.schema.base import Document
class BaseRemote:
"""Utilities for loading data from a directory."""
@abstractmethod
def load_data(self, *args: Any, **load_kwargs: Any) -> List[Document]:
"""Load data from the input directory."""
def load_langchain_documents(self, **load_kwargs: Any) -> List[LCDocument]:
"""Load data in LangChain document format."""
docs = self.load_data(**load_kwargs)
return [d.to_langchain_format() for d in docs]

View File

@@ -0,0 +1,59 @@
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
class CrawlerLoader(BaseRemote):
def __init__(self, limit=10):
from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader # Initialize the document loader
self.limit = limit # Set the limit for the number of pages to scrape
def load_data(self, inputs):
url = inputs
# Check if the input is a list and if it is, use the first element
if isinstance(url, list) and url:
url = url[0]
# Check if the URL scheme is provided, if not, assume http
if not urlparse(url).scheme:
url = "http://" + url
visited_urls = set() # Keep track of URLs that have been visited
base_url = urlparse(url).scheme + "://" + urlparse(url).hostname # Extract the base URL
urls_to_visit = [url] # List of URLs to be visited, starting with the initial URL
loaded_content = [] # Store the loaded content from each URL
# Continue crawling until there are no more URLs to visit
while urls_to_visit:
current_url = urls_to_visit.pop(0) # Get the next URL to visit
visited_urls.add(current_url) # Mark the URL as visited
# Try to load and process the content from the current URL
try:
response = requests.get(current_url) # Fetch the content of the current URL
response.raise_for_status() # Raise an exception for HTTP errors
loader = self.loader([current_url]) # Initialize the document loader for the current URL
loaded_content.extend(loader.load()) # Load the content and add it to the loaded_content list
except Exception as e:
# Print an error message if loading or processing fails and continue with the next URL
print(f"Error processing URL {current_url}: {e}")
continue
# Parse the HTML content to extract all links
soup = BeautifulSoup(response.text, 'html.parser')
all_links = [
urljoin(current_url, a['href'])
for a in soup.find_all('a', href=True)
if base_url in urljoin(current_url, a['href']) # Ensure links are from the same domain
]
# Add new links to the list of URLs to visit if they haven't been visited yet
urls_to_visit.extend([link for link in all_links if link not in visited_urls])
urls_to_visit = list(set(urls_to_visit)) # Remove duplicate URLs
# Stop crawling if the limit of pages to scrape is reached
if self.limit is not None and len(visited_urls) >= self.limit:
break
return loaded_content # Return the loaded content from all visited URLs

View File

@@ -0,0 +1,26 @@
from application.parser.remote.base import BaseRemote
from langchain_community.document_loaders import RedditPostsLoader
class RedditPostsLoaderRemote(BaseRemote):
def load_data(self, inputs):
data = eval(inputs)
client_id = data.get("client_id")
client_secret = data.get("client_secret")
user_agent = data.get("user_agent")
categories = data.get("categories", ["new", "hot"])
mode = data.get("mode", "subreddit")
search_queries = data.get("search_queries")
number_posts = data.get("number_posts", 10)
self.loader = RedditPostsLoader(
client_id=client_id,
client_secret=client_secret,
user_agent=user_agent,
categories=categories,
mode=mode,
search_queries=search_queries,
number_posts=number_posts,
)
documents = self.loader.load()
print(f"Loaded {len(documents)} documents from Reddit")
return documents

View File

@@ -0,0 +1,20 @@
from application.parser.remote.sitemap_loader import SitemapLoader
from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
from application.parser.remote.reddit_loader import RedditPostsLoaderRemote
class RemoteCreator:
loaders = {
"url": WebLoader,
"sitemap": SitemapLoader,
"crawler": CrawlerLoader,
"reddit": RedditPostsLoaderRemote,
}
@classmethod
def create_loader(cls, type, *args, **kwargs):
loader_class = cls.loaders.get(type.lower())
if not loader_class:
raise ValueError(f"No LLM class found for type {type}")
return loader_class(*args, **kwargs)

View File

@@ -0,0 +1,81 @@
import requests
import re # Import regular expression library
import xml.etree.ElementTree as ET
from application.parser.remote.base import BaseRemote
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader
self.limit = limit # Adding limit to control the number of URLs to process
def load_data(self, inputs):
sitemap_url= inputs
# Check if the input is a list and if it is, use the first element
if isinstance(sitemap_url, list) and sitemap_url:
url = sitemap_url[0]
urls = self._extract_urls(sitemap_url)
if not urls:
print(f"No URLs found in the sitemap: {sitemap_url}")
return []
# Load content of extracted URLs
documents = []
processed_urls = 0 # Counter for processed URLs
for url in urls:
if self.limit is not None and processed_urls >= self.limit:
break # Stop processing if the limit is reached
try:
loader = self.loader([url])
documents.extend(loader.load())
processed_urls += 1 # Increment the counter after processing each URL
except Exception as e:
print(f"Error processing URL {url}: {e}")
continue
return documents
def _extract_urls(self, sitemap_url):
try:
response = requests.get(sitemap_url)
response.raise_for_status() # Raise an exception for HTTP errors
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
# Determine if this is a sitemap or a URL
if self._is_sitemap(response):
# It's a sitemap, so parse it and extract URLs
return self._parse_sitemap(response.content)
else:
# It's not a sitemap, return the URL itself
return [sitemap_url]
def _is_sitemap(self, response):
content_type = response.headers.get('Content-Type', '')
if 'xml' in content_type or response.url.endswith('.xml'):
return True
if '<sitemapindex' in response.text or '<urlset' in response.text:
return True
return False
def _parse_sitemap(self, sitemap_content):
# Remove namespaces
sitemap_content = re.sub(' xmlns="[^"]+"', '', sitemap_content.decode('utf-8'), count=1)
root = ET.fromstring(sitemap_content)
urls = []
for loc in root.findall('.//url/loc'):
urls.append(loc.text)
# Check for nested sitemaps
for sitemap in root.findall('.//sitemap/loc'):
nested_sitemap_url = sitemap.text
urls.extend(self._extract_urls(nested_sitemap_url))
return urls

View File

@@ -0,0 +1,11 @@
from langchain.document_loader import TelegramChatApiLoader
from application.parser.remote.base import BaseRemote
class TelegramChatApiRemote(BaseRemote):
def _init_parser(self, *args, **load_kwargs):
self.loader = TelegramChatApiLoader(**load_kwargs)
return {}
def parse_file(self, *args, **load_kwargs):
return

View File

@@ -0,0 +1,32 @@
from application.parser.remote.base import BaseRemote
from langchain_community.document_loaders import WebBaseLoader
headers = {
"User-Agent": "Mozilla/5.0",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*"
";q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"Referer": "https://www.google.com/",
"DNT": "1",
"Connection": "keep-alive",
"Upgrade-Insecure-Requests": "1",
}
class WebLoader(BaseRemote):
def __init__(self):
self.loader = WebBaseLoader
def load_data(self, inputs):
urls = inputs
if isinstance(urls, str):
urls = [urls]
documents = []
for url in urls:
try:
loader = self.loader([url], header_template=headers)
documents.extend(loader.load())
except Exception as e:
print(f"Error processing URL {url}: {e}")
continue
return documents

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,34 @@
"""Base schema for readers."""
from dataclasses import dataclass
from langchain.docstore.document import Document as LCDocument
from application.parser.schema.schema import BaseDocument
@dataclass
class Document(BaseDocument):
"""Generic interface for a data document.
This document connects to data sources.
"""
def __post_init__(self) -> None:
"""Post init."""
if self.text is None:
raise ValueError("text field not set.")
@classmethod
def get_type(cls) -> str:
"""Get Document type."""
return "Document"
def to_langchain_format(self) -> LCDocument:
"""Convert struct to LangChain document format."""
metadata = self.extra_info or {}
return LCDocument(page_content=self.text, metadata=metadata)
@classmethod
def from_langchain_format(cls, doc: LCDocument) -> "Document":
"""Convert struct from LangChain document format."""
return cls(text=doc.page_content, extra_info=doc.metadata)

View File

@@ -0,0 +1,64 @@
"""Base schema for data structures."""
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from dataclasses_json import DataClassJsonMixin
@dataclass
class BaseDocument(DataClassJsonMixin):
"""Base document.
Generic abstract interfaces that captures both index structs
as well as documents.
"""
# TODO: consolidate fields from Document/IndexStruct into base class
text: Optional[str] = None
doc_id: Optional[str] = None
embedding: Optional[List[float]] = None
# extra fields
extra_info: Optional[Dict[str, Any]] = None
@classmethod
@abstractmethod
def get_type(cls) -> str:
"""Get Document type."""
def get_text(self) -> str:
"""Get text."""
if self.text is None:
raise ValueError("text field not set.")
return self.text
def get_doc_id(self) -> str:
"""Get doc_id."""
if self.doc_id is None:
raise ValueError("doc_id not set.")
return self.doc_id
@property
def is_doc_id_none(self) -> bool:
"""Check if doc_id is None."""
return self.doc_id is None
def get_embedding(self) -> List[float]:
"""Get embedding.
Errors if embedding is None.
"""
if self.embedding is None:
raise ValueError("embedding not set.")
return self.embedding
@property
def extra_info_str(self) -> Optional[str]:
"""Extra info string."""
if self.extra_info is None:
return None
return "\n".join([f"{k}: {str(v)}" for k, v in self.extra_info.items()])

View File

@@ -0,0 +1,79 @@
import re
from math import ceil
from typing import List
import tiktoken
from application.parser.schema.base import Document
def separate_header_and_body(text):
header_pattern = r"^(.*?\n){3}"
match = re.match(header_pattern, text)
header = match.group(0)
body = text[len(header):]
return header, body
def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
docs = []
current_group = None
for doc in documents:
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
# Check if current group is empty or if the document can be added based on token count and matching metadata
if (current_group is None or
(len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
doc_len < min_tokens and
current_group.extra_info == doc.extra_info)):
if current_group is None:
current_group = doc # Use the document directly to retain its metadata
else:
current_group.text += " " + doc.text # Append text to the current group
else:
docs.append(current_group)
current_group = doc # Start a new group with the current document
if current_group is not None:
docs.append(current_group)
return docs
def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
docs = []
for doc in documents:
token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
if token_length <= max_tokens:
docs.append(doc)
else:
header, body = separate_header_and_body(doc.text)
if len(tiktoken.get_encoding("cl100k_base").encode(header)) > max_tokens:
body = doc.text
header = ""
num_body_parts = ceil(token_length / max_tokens)
part_length = ceil(len(body) / num_body_parts)
body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
for i, body_part in enumerate(body_parts):
new_doc = Document(text=header + body_part.strip(),
doc_id=f"{doc.doc_id}-{i}",
embedding=doc.embedding,
extra_info=doc.extra_info)
docs.append(new_doc)
return docs
def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True):
if not token_check:
return documents
print("Grouping small documents")
try:
documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
except Exception:
print("Grouping failed, try running without token_check")
print("Separating large documents")
try:
documents = split_documents(documents=documents, max_tokens=max_tokens)
except Exception:
print("Grouping failed, try running without token_check")
return documents

View File

@@ -0,0 +1,9 @@
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
Use the following pieces of context to help answer the users question. If its not relevant to the question, provide friendly responses.
You have access to chat history, and can use it to help answer the question.
When using code examples, use the following format:
```(language)
(code)
```
----------------
{summaries}

View File

@@ -0,0 +1,9 @@
You are a helpful AI assistant, DocsGPT, specializing in document assistance, designed to offer detailed and informative responses.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
You effectively utilize chat history, ensuring relevant and tailored responses.
If a question doesn't align with your context, you provide friendly and helpful replies.
----------------
{summaries}

View File

@@ -0,0 +1,13 @@
You are an AI Assistant, DocsGPT, adept at offering document assistance.
Your expertise lies in providing answer on top of provided context.
You can leverage the chat history if needed.
Answer the question based on the context below.
Keep the answer concise. Respond "Irrelevant context" if not sure about the answer.
If question is not related to the context, respond "Irrelevant context".
When using code examples, use the following format:
```(language)
(code)
```
----------------
Context:
{summaries}

View File

@@ -0,0 +1,3 @@
Use the following pieces of context to help answer the users question. If its not relevant to the question, respond with "-"
----------------
{context}

View File

@@ -0,0 +1,86 @@
anthropic==0.34.2
boto3==1.34.153
beautifulsoup4==4.12.3
celery==5.3.6
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==6.2.6
ebooklib==0.18
elastic-transport==8.15.0
elasticsearch==8.15.1
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
Flask==3.0.3
faiss-cpu==1.8.0.post1
flask-restx==1.3.0
gunicorn==23.0.0
html2text==2024.2.26
javalang==0.13.0
jinja2==3.1.4
jiter==0.5.0
jmespath==1.0.1
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-spec==0.2.4
jsonschema-specifications==2023.7.1
kombu==5.4.2
langchain==0.3.0
langchain-community==0.3.0
langchain-core==0.3.2
langchain-openai==0.2.0
langchain-text-splitters==0.3.0
langsmith==0.1.125
lazy-object-proxy==1.10.0
lxml==5.3.0
markupsafe==2.1.5
marshmallow==3.22.0
mpmath==1.3.0
multidict==6.1.0
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
openai==1.46.1
openapi-schema-validator==0.6.2
openapi-spec-validator==0.6.0
openapi3-parser==1.1.18
orjson==3.10.7
packaging==24.1
pandas==2.2.3
pathable==0.4.3
pillow==10.4.0
portalocker==2.10.1
prance==23.6.21.0
primp==0.6.2
prompt-toolkit==3.0.47
protobuf==5.28.2
py==1.11.0
pydantic==2.9.2
pydantic-core==2.23.4
pydantic-settings==2.4.0
pymongo==4.8.0
pypdf2==3.0.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
qdrant-client==1.11.0
redis==5.0.1
referencing==0.30.2
regex==2024.9.11
requests==2.32.3
retry==0.9.2
sentence-transformers==3.0.1
tiktoken==0.7.0
tokenizers==0.19.1
torch==2.4.1
tqdm==4.66.5
transformers==4.44.2
typing-extensions==4.12.2
typing-inspect==0.9.0
tzdata==2024.2
urllib3==2.2.3
vine==5.1.0
wcwidth==0.2.13
werkzeug==3.0.4
yarl==1.11.1

View File

View File

@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
class BaseRetriever(ABC):
def __init__(self):
pass
@abstractmethod
def gen(self, *args, **kwargs):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass
@abstractmethod
def get_params(self):
pass

View File

@@ -0,0 +1,115 @@
import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch
class BraveRetSearch(BaseRetriever):
def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
self.question = question
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
if token_limit
< settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
self.user_api_key = user_api_key
def _get_data(self):
if self.chunks == 0:
docs = []
else:
search = BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY,
search_kwargs={"count": int(self.chunks)},
)
results = search.run(self.question)
results = json.loads(results)
docs = []
for i in results:
try:
title = i["title"]
link = i["link"]
snippet = i["snippet"]
docs.append({"text": snippet, "title": title, "link": link})
except IndexError:
pass
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
"source": self.source,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
}

View File

@@ -0,0 +1,118 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
class ClassicRAG(BaseRetriever):
def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
self.question = question
self.vectorstore = source['active_docs'] if 'active_docs' in source else None
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
if token_limit
< settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
self.user_api_key = user_api_key
def _get_data(self):
if self.chunks == 0:
docs = []
else:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
print(docs_temp)
docs = [
{
"title": i.metadata.get(
"title", i.metadata.get("post_title", i.page_content)
).split("/")[-1],
"text": i.page_content,
"source": (
i.metadata.get("source")
if i.metadata.get("source")
else "local"
),
}
for i in docs_temp
]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
"source": self.vectorstore,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
}

View File

@@ -0,0 +1,132 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
class DuckDuckSearch(BaseRetriever):
def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
):
self.question = question
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
if token_limit
< settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
self.user_api_key = user_api_key
def _parse_lang_string(self, input_string):
result = []
current_item = ""
inside_brackets = False
for char in input_string:
if char == "[":
inside_brackets = True
elif char == "]":
inside_brackets = False
result.append(current_item)
current_item = ""
elif inside_brackets:
current_item += char
if inside_brackets:
result.append(current_item)
return result
def _get_data(self):
if self.chunks == 0:
docs = []
else:
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
results = search.run(self.question)
results = self._parse_lang_string(results)
docs = []
for i in results:
try:
text = i.split("title:")[0]
title = i.split("title:")[1].split("link:")[0]
link = i.split("link:")[1]
docs.append({"text": text, "title": title, "link": link})
except IndexError:
pass
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
"source": self.source,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
}

View File

@@ -0,0 +1,20 @@
from application.retriever.classic_rag import ClassicRAG
from application.retriever.duckduck_search import DuckDuckSearch
from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
retrievers = {
'classic': ClassicRAG,
'duckduck_search': DuckDuckSearch,
'brave_search': BraveRetSearch,
'default': ClassicRAG
}
@classmethod
def create_retriever(cls, type, *args, **kwargs):
retiever_class = cls.retrievers.get(type.lower())
if not retiever_class:
raise ValueError(f"No retievers class found for type {type}")
return retiever_class(*args, **kwargs)

49
application/usage.py Normal file
View File

@@ -0,0 +1,49 @@
import sys
from pymongo import MongoClient
from datetime import datetime
from application.core.settings import settings
from application.utils import num_tokens_from_string
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
usage_collection = db["token_usage"]
def update_token_usage(user_api_key, token_usage):
if "pytest" in sys.modules:
return
usage_data = {
"api_key": user_api_key,
"prompt_tokens": token_usage["prompt_tokens"],
"generated_tokens": token_usage["generated_tokens"],
"timestamp": datetime.now(),
}
usage_collection.insert_one(usage_data)
def gen_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
return wrapper
def stream_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
result = func(self, model, messages, stream, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.user_api_key, self.token_usage)
return wrapper

41
application/utils.py Normal file
View File

@@ -0,0 +1,41 @@
import tiktoken
from flask import jsonify, make_response
_encoding = None
def get_encoding():
global _encoding
if _encoding is None:
_encoding = tiktoken.get_encoding("cl100k_base")
return _encoding
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
num_tokens = len(encoding.encode(string))
return num_tokens
def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
docs_content += doc.page_content
tokens = num_tokens_from_string(docs_content)
return tokens
def check_required_fields(data, required_fields):
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Missing fields: {', '.join(missing_fields)}",
}
),
400,
)
return None

View File

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
import os
from sentence_transformers import SentenceTransformer
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
elif isinstance(text, list):
return self.embed_documents(text)
else:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
embeddings_name, *args, **kwargs
)
return EmbeddingsSingleton._instances[embeddings_name]
@staticmethod
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
}
if embeddings_name in embeddings_factory:
return embeddings_factory[embeddings_name](*args, **kwargs)
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass
def is_azure_configured(self):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name="./model/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance

View File

@@ -0,0 +1,8 @@
class Document(str):
"""Class for storing a piece of text and associated metadata."""
def __new__(cls, page_content: str, metadata: dict):
instance = super().__new__(cls, page_content)
instance.page_content = page_content
instance.metadata = metadata
return instance

Some files were not shown because too many files have changed in this diff Show More