Compare commits

...

280 Commits

Author SHA1 Message Date
Alex
ddd5704c49 fix: mini docstring stuff 2026-04-26 00:56:31 +01:00
Alex
d54e6d8b34 fix: test 2026-04-26 00:36:35 +01:00
Alex
2806825959 feat: simplified model structure 2026-04-26 00:20:37 +01:00
Alex
ef976eeb06 feat: make version check periodic 2026-04-25 14:57:37 +01:00
Alex
9c8ae9d540 feat: redbeat 2026-04-25 14:38:24 +01:00
Alex
7ca33b2b72 feat: OTEL 2026-04-25 13:38:03 +01:00
Manish Madan
d1b9798f62 Merge pull request #2422 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel/plugin-transform-flow-strip-types-7.27.1
chore(deps): bump @babel/plugin-transform-flow-strip-types from 7.24.6 to 7.27.1 in /extensions/react-widget
2026-04-24 05:13:00 +05:30
dependabot[bot]
ddc3adf3ab chore(deps): bump @babel/plugin-transform-flow-strip-types
Bumps [@babel/plugin-transform-flow-strip-types](https://github.com/babel/babel/tree/HEAD/packages/babel-plugin-transform-flow-strip-types) from 7.24.6 to 7.27.1.
- [Release notes](https://github.com/babel/babel/releases)
- [Changelog](https://github.com/babel/babel/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel/commits/v7.27.1/packages/babel-plugin-transform-flow-strip-types)

---
updated-dependencies:
- dependency-name: "@babel/plugin-transform-flow-strip-types"
  dependency-version: 7.27.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-23 23:41:46 +00:00
Manish Madan
a4991d01ac Merge pull request #2421 from arc53/dependabot/npm_and_yarn/extensions/react-widget/typescript-6.0.3
chore(deps-dev): bump typescript from 5.9.3 to 6.0.3 in /extensions/react-widget
2026-04-24 05:09:58 +05:30
ManishMadan2882
87fd1bd359 chore(deps-dev): bump @typescript-eslint/{eslint-plugin,parser} to ^8.58.0 for TS 6 support 2026-04-24 05:08:17 +05:30
dependabot[bot]
c71e986d34 chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.9.3 to 6.0.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.9.3...v6.0.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 6.0.3
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-23 23:18:29 +00:00
Manish Madan
a2a06c569e Merge pull request #2419 from arc53/dependabot/npm_and_yarn/extensions/react-widget/prettier-3.8.3
chore(deps-dev): bump prettier from 3.8.1 to 3.8.3 in /extensions/react-widget
2026-04-24 04:34:37 +05:30
Alex
c5f00a1d1b fix: remove old extension in external repo's 2026-04-23 23:27:38 +01:00
Alex
2a15bb0102 chore: delete old extensions 2026-04-23 22:55:03 +01:00
Alex
c06888bc86 feat: asgi and search service (#2424)
* feat: asgi and search service

* feat: asgi and mcp tool server

* fix: asgi issues

* fix: mini cors hardening
2026-04-23 12:21:39 +01:00
Alex
d4b1c1fd81 chore: 0.17.0 version 2026-04-21 16:16:11 +01:00
Alex
2de84acf81 fix: mini callout 2026-04-21 16:14:08 +01:00
Alex
2702750861 docs: upgrading guide 2026-04-21 15:04:17 +01:00
Alex
2b5f20d0ec fix: safer version 2026-04-21 14:22:32 +01:00
Alex
619b41dc5b fix: better version fetch 2026-04-21 14:07:26 +01:00
Alex
76d8f49ccb feat: security version check 2026-04-21 09:16:52 +01:00
dependabot[bot]
65460b0c03 chore(deps-dev): bump prettier in /extensions/react-widget
Bumps [prettier](https://github.com/prettier/prettier) from 3.8.1 to 3.8.3.
- [Release notes](https://github.com/prettier/prettier/releases)
- [Changelog](https://github.com/prettier/prettier/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prettier/prettier/compare/3.8.1...3.8.3)

---
updated-dependencies:
- dependency-name: prettier
  dependency-version: 3.8.3
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 22:35:52 +00:00
Manish Madan
9fe96fb50f Merge pull request #2397 from arc53/dependabot/npm_and_yarn/extensions/react-widget/multi-0193e73c84
chore(deps): bump react and @types/react in /extensions/react-widget
2026-04-21 01:35:29 +05:30
Alex
08822c3379 feat: lazy pymongo 2026-04-20 15:58:02 +01:00
ManishMadan2882
68ca8ff9ea (chore) update react-dom 2026-04-20 19:30:14 +05:30
dependabot[bot]
81be3cdccc chore(deps): bump react and @types/react in /extensions/react-widget
Bumps [react](https://github.com/facebook/react/tree/HEAD/packages/react) and [@types/react](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react). These dependencies needed to be updated together.

Updates `react` from 18.3.1 to 19.2.5
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.2.5/packages/react)

Updates `@types/react` from 18.3.3 to 19.2.14
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react)

---
updated-dependencies:
- dependency-name: react
  dependency-version: 19.2.5
  dependency-type: direct:production
  update-type: version-update:semver-major
- dependency-name: "@types/react"
  dependency-version: 19.2.14
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 13:20:46 +00:00
Manish Madan
3ceabed8ad Merge pull request #2406 from arc53/dependabot/npm_and_yarn/extensions/react-widget/npm_and_yarn-2a73d1bbcf
chore(deps): bump dompurify from 3.3.3 to 3.4.0 in /extensions/react-widget in the npm_and_yarn group across 1 directory
2026-04-20 14:45:32 +05:30
dependabot[bot]
422a4b139e chore(deps): bump dompurify
Bumps the npm_and_yarn group with 1 update in the /extensions/react-widget directory: [dompurify](https://github.com/cure53/DOMPurify).


Updates `dompurify` from 3.3.3 to 3.4.0
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.3.3...3.4.0)

---
updated-dependencies:
- dependency-name: dompurify
  dependency-version: 3.4.0
  dependency-type: direct:production
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 08:58:53 +00:00
Manish Madan
e85935eed0 Merge pull request #2402 from arc53/dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.309.0
chore(deps): bump flow-bin from 0.306.0 to 0.309.0 in /extensions/react-widget
2026-04-20 14:27:30 +05:30
dependabot[bot]
6a69b8aca0 chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.306.0 to 0.309.0.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.309.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 08:55:07 +00:00
Manish Madan
33c2cc9660 Merge pull request #2400 from arc53/dependabot/npm_and_yarn/extensions/react-widget/styled-components-6.4.0
chore(deps): bump styled-components from 6.3.12 to 6.4.0 in /extensions/react-widget
2026-04-20 13:30:30 +05:30
dependabot[bot]
175d4d5a68 chore(deps): bump styled-components in /extensions/react-widget
Bumps [styled-components](https://github.com/styled-components/styled-components) from 6.3.12 to 6.4.0.
- [Release notes](https://github.com/styled-components/styled-components/releases)
- [Commits](https://github.com/styled-components/styled-components/compare/styled-components@6.3.12...styled-components@6.4.0)

---
updated-dependencies:
- dependency-name: styled-components
  dependency-version: 6.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 07:59:02 +00:00
Manish Madan
6c3ead1071 Merge pull request #2413 from arc53/dependabot/npm_and_yarn/docs/npm_and_yarn-a087653e68
chore(deps): bump the npm_and_yarn group across 1 directory with 2 updates
2026-04-20 00:53:14 +05:30
dependabot[bot]
d23f88f825 chore(deps): bump the npm_and_yarn group across 1 directory with 2 updates
Bumps the npm_and_yarn group with 2 updates in the /docs directory: [@xmldom/xmldom](https://github.com/xmldom/xmldom) and [lodash-es](https://github.com/lodash/lodash).


Updates `@xmldom/xmldom` from 0.9.8 to 0.9.9
- [Release notes](https://github.com/xmldom/xmldom/releases)
- [Changelog](https://github.com/xmldom/xmldom/blob/master/CHANGELOG.md)
- [Commits](https://github.com/xmldom/xmldom/compare/0.9.8...0.9.9)

Updates `lodash-es` from 4.17.23 to 4.18.1
- [Release notes](https://github.com/lodash/lodash/releases)
- [Commits](https://github.com/lodash/lodash/compare/4.17.23...4.18.1)

---
updated-dependencies:
- dependency-name: "@xmldom/xmldom"
  dependency-version: 0.9.9
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: lodash-es
  dependency-version: 4.18.1
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 17:55:14 +00:00
Manish Madan
da1df515f7 Merge pull request #2411 from arc53/dependabot/npm_and_yarn/docs/npm_and_yarn-bfabbafa3d
chore(deps): bump the npm_and_yarn group across 3 directories with 10 updates
2026-04-19 23:23:34 +05:30
dependabot[bot]
671a9d75ad chore(deps): bump the npm_and_yarn group across 3 directories with 10 updates
Bumps the npm_and_yarn group with 6 updates in the /docs directory:

| Package | From | To |
| --- | --- | --- |
| [next](https://github.com/vercel/next.js) | `15.5.9` | `15.5.15` |
| [brace-expansion](https://github.com/juliangruber/brace-expansion) | `5.0.2` | `5.0.5` |
| [dompurify](https://github.com/cure53/DOMPurify) | `3.3.1` | `3.4.0` |
| [markdown-it](https://github.com/markdown-it/markdown-it) | `14.1.0` | `14.1.1` |
| [minimatch](https://github.com/isaacs/minimatch) | `10.2.1` | `10.2.5` |
| [yaml](https://github.com/eemeli/yaml) | `1.10.2` | `1.10.3` |

Bumps the npm_and_yarn group with 2 updates in the /extensions/chrome directory: [yaml](https://github.com/eemeli/yaml) and [picomatch](https://github.com/micromatch/picomatch).
Bumps the npm_and_yarn group with 4 updates in the /extensions/web-widget directory: [brace-expansion](https://github.com/juliangruber/brace-expansion), [minimatch](https://github.com/isaacs/minimatch), [yaml](https://github.com/eemeli/yaml) and [picomatch](https://github.com/micromatch/picomatch).


Updates `next` from 15.5.9 to 15.5.15
- [Release notes](https://github.com/vercel/next.js/releases)
- [Changelog](https://github.com/vercel/next.js/blob/canary/release.js)
- [Commits](https://github.com/vercel/next.js/compare/v15.5.9...v15.5.15)

Updates `brace-expansion` from 5.0.2 to 5.0.5
- [Release notes](https://github.com/juliangruber/brace-expansion/releases)
- [Commits](https://github.com/juliangruber/brace-expansion/compare/v5.0.2...v5.0.5)

Updates `dompurify` from 3.3.1 to 3.4.0
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.3.1...3.4.0)

Updates `markdown-it` from 14.1.0 to 14.1.1
- [Changelog](https://github.com/markdown-it/markdown-it/blob/master/CHANGELOG.md)
- [Commits](https://github.com/markdown-it/markdown-it/compare/14.1.0...14.1.1)

Updates `minimatch` from 10.2.1 to 10.2.5
- [Changelog](https://github.com/isaacs/minimatch/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/minimatch/compare/v10.2.1...v10.2.5)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `diff` from 5.2.0 to 5.2.2
- [Changelog](https://github.com/kpdecker/jsdiff/blob/master/release-notes.md)
- [Commits](https://github.com/kpdecker/jsdiff/compare/v5.2.0...v5.2.2)

Updates `glob` from 10.3.12 to 10.5.0
- [Changelog](https://github.com/isaacs/node-glob/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/node-glob/compare/v10.3.12...v10.5.0)

Updates `tar` from 6.2.1 to 7.5.11
- [Release notes](https://github.com/isaacs/node-tar/releases)
- [Changelog](https://github.com/isaacs/node-tar/blob/main/CHANGELOG.md)
- [Commits](https://github.com/isaacs/node-tar/compare/v6.2.1...v7.5.11)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `brace-expansion` from 2.0.1 to 2.0.2
- [Release notes](https://github.com/juliangruber/brace-expansion/releases)
- [Commits](https://github.com/juliangruber/brace-expansion/compare/v5.0.2...v5.0.5)

Updates `glob` from 10.3.12 to 10.5.0
- [Changelog](https://github.com/isaacs/node-glob/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/node-glob/compare/v10.3.12...v10.5.0)

Updates `minimatch` from 9.0.4 to 9.0.9
- [Changelog](https://github.com/isaacs/minimatch/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/minimatch/compare/v10.2.1...v10.2.5)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `brace-expansion` from 1.1.11 to 1.1.14
- [Release notes](https://github.com/juliangruber/brace-expansion/releases)
- [Commits](https://github.com/juliangruber/brace-expansion/compare/v5.0.2...v5.0.5)

Updates `minimatch` from 3.1.2 to 3.1.5
- [Changelog](https://github.com/isaacs/minimatch/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/minimatch/compare/v10.2.1...v10.2.5)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

Updates `brace-expansion` from 1.1.11 to 1.1.14
- [Release notes](https://github.com/juliangruber/brace-expansion/releases)
- [Commits](https://github.com/juliangruber/brace-expansion/compare/v5.0.2...v5.0.5)

Updates `minimatch` from 3.1.2 to 3.1.5
- [Changelog](https://github.com/isaacs/minimatch/blob/main/changelog.md)
- [Commits](https://github.com/isaacs/minimatch/compare/v10.2.1...v10.2.5)

Updates `picomatch` from 2.3.1 to 2.3.2
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/2.3.1...2.3.2)

Updates `yaml` from 1.10.2 to 1.10.3
- [Release notes](https://github.com/eemeli/yaml/releases)
- [Commits](https://github.com/eemeli/yaml/compare/v1.10.2...v1.10.3)

---
updated-dependencies:
- dependency-name: next
  dependency-version: 15.5.15
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: brace-expansion
  dependency-version: 5.0.5
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: dompurify
  dependency-version: 3.4.0
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: markdown-it
  dependency-version: 14.1.1
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: minimatch
  dependency-version: 10.2.5
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: diff
  dependency-version: 5.2.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: glob
  dependency-version: 10.5.0
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: tar
  dependency-version: 7.5.11
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: brace-expansion
  dependency-version: 2.0.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: glob
  dependency-version: 10.5.0
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: minimatch
  dependency-version: 9.0.9
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: brace-expansion
  dependency-version: 1.1.14
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: minimatch
  dependency-version: 3.1.5
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: brace-expansion
  dependency-version: 1.1.14
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: minimatch
  dependency-version: 3.1.5
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 2.3.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: yaml
  dependency-version: 1.10.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 17:47:08 +00:00
Manish Madan
1c829667ff Merge pull request #2326 from arc53/dependabot/npm_and_yarn/extensions/react-widget/class-variance-authority-0.7.1
chore(deps): bump class-variance-authority from 0.7.0 to 0.7.1 in /extensions/react-widget
2026-04-19 23:13:02 +05:30
Manish Madan
3ab0ebb16d Merge pull request #2398 from arc53/dependabot/pip/application/openai-2.31.0
chore(deps): bump openai from 2.30.0 to 2.31.0 in /application
2026-04-19 23:10:36 +05:30
dependabot[bot]
988c4a5a15 chore(deps): bump openai from 2.30.0 to 2.31.0 in /application
Bumps [openai](https://github.com/openai/openai-python) from 2.30.0 to 2.31.0.
- [Release notes](https://github.com/openai/openai-python/releases)
- [Changelog](https://github.com/openai/openai-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/openai/openai-python/compare/v2.30.0...v2.31.0)

---
updated-dependencies:
- dependency-name: openai
  dependency-version: 2.31.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 17:14:43 +00:00
Manish Madan
01db8b2c41 Merge pull request #2395 from arc53/dependabot/pip/application/google-genai-1.73.1
chore(deps): bump google-genai from 1.69.0 to 1.73.1 in /application
2026-04-19 22:43:14 +05:30
Alex
ef19da9516 fix: pg bouncer compatible 2026-04-19 17:53:22 +01:00
dependabot[bot]
cc1275c3f9 chore(deps): bump google-genai from 1.69.0 to 1.73.1 in /application
Bumps [google-genai](https://github.com/googleapis/python-genai) from 1.69.0 to 1.73.1.
- [Release notes](https://github.com/googleapis/python-genai/releases)
- [Changelog](https://github.com/googleapis/python-genai/blob/main/CHANGELOG.md)
- [Commits](https://github.com/googleapis/python-genai/compare/v1.69.0...v1.73.1)

---
updated-dependencies:
- dependency-name: google-genai
  dependency-version: 1.73.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 15:30:12 +00:00
Manish Madan
14c2f4890f Merge pull request #2394 from arc53/dependabot/pip/application/fastmcp-3.2.4
chore(deps): bump fastmcp from 3.2.0 to 3.2.4 in /application
2026-04-19 20:59:04 +05:30
dependabot[bot]
b3aec36aa2 chore(deps): bump fastmcp from 3.2.0 to 3.2.4 in /application
Bumps [fastmcp](https://github.com/PrefectHQ/fastmcp) from 3.2.0 to 3.2.4.
- [Release notes](https://github.com/PrefectHQ/fastmcp/releases)
- [Changelog](https://github.com/PrefectHQ/fastmcp/blob/main/docs/changelog.mdx)
- [Commits](https://github.com/PrefectHQ/fastmcp/compare/v3.2.0...v3.2.4)

---
updated-dependencies:
- dependency-name: fastmcp
  dependency-version: 3.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 15:13:45 +00:00
Manish Madan
50f62beaeb Merge pull request #2392 from arc53/dependabot/pip/application/elevenlabs-2.43.0
chore(deps): bump elevenlabs from 2.41.0 to 2.43.0 in /application
2026-04-19 20:42:32 +05:30
dependabot[bot]
423b4c6494 chore(deps): bump elevenlabs from 2.41.0 to 2.43.0 in /application
Bumps [elevenlabs](https://github.com/elevenlabs/elevenlabs-python) from 2.41.0 to 2.43.0.
- [Release notes](https://github.com/elevenlabs/elevenlabs-python/releases)
- [Commits](https://github.com/elevenlabs/elevenlabs-python/compare/v2.41.0...v2.43.0)

---
updated-dependencies:
- dependency-name: elevenlabs
  dependency-version: 2.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 14:49:32 +00:00
Manish Madan
54f615c59d Merge pull request #2407 from arc53/dependabot/npm_and_yarn/frontend/npm_and_yarn-f0540bac9f
chore(deps): bump the npm_and_yarn group across 1 directory with 4 updates
2026-04-19 20:17:07 +05:30
dependabot[bot]
223b3de66e chore(deps): bump the npm_and_yarn group across 1 directory with 4 updates
Bumps the npm_and_yarn group with 3 updates in the /frontend directory: [lodash](https://github.com/lodash/lodash), [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) and [dompurify](https://github.com/cure53/DOMPurify).


Updates `lodash` from 4.17.23 to 4.18.1
- [Release notes](https://github.com/lodash/lodash/releases)
- [Commits](https://github.com/lodash/lodash/compare/4.17.23...4.18.1)

Updates `vite` from 8.0.0 to 8.0.8
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/main/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v8.0.8/packages/vite)

Updates `picomatch` from 4.0.3 to 4.0.4
- [Release notes](https://github.com/micromatch/picomatch/releases)
- [Changelog](https://github.com/micromatch/picomatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/picomatch/compare/4.0.3...4.0.4)

Updates `dompurify` from 3.3.3 to 3.4.0
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.3.3...3.4.0)

---
updated-dependencies:
- dependency-name: lodash
  dependency-version: 4.18.1
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: vite
  dependency-version: 8.0.8
  dependency-type: direct:development
  dependency-group: npm_and_yarn
- dependency-name: picomatch
  dependency-version: 4.0.4
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: dompurify
  dependency-version: 3.4.0
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 14:06:01 +00:00
Manish Madan
4db9622ef5 Merge pull request #2403 from arc53/dependabot/npm_and_yarn/frontend/types/react-19.2.14
chore(deps-dev): bump @types/react from 19.2.2 to 19.2.14 in /frontend
2026-04-19 19:34:04 +05:30
dependabot[bot]
e8d1bbfb68 chore(deps-dev): bump @types/react from 19.2.2 to 19.2.14 in /frontend
Bumps [@types/react](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react) from 19.2.2 to 19.2.14.
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react)

---
updated-dependencies:
- dependency-name: "@types/react"
  dependency-version: 19.2.14
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-19 13:54:47 +00:00
Manish Madan
aff1345ae4 Merge pull request #2399 from arc53/dependabot/npm_and_yarn/frontend/react-router-dom-7.14.1
chore(deps): bump react-router-dom from 7.13.1 to 7.14.1 in /frontend
2026-04-19 04:45:57 +05:30
Alex
ee430aff1e fix: tests 2026-04-18 13:28:03 +01:00
Alex
81b6ee5daa Pg 4 (#2390)
* feat: postgres tests

* feat: mongo cutoff

* feat: mongo cutoff

* feat: adjust docs and compose files

* fix: mini code mongo removals

* fix: tests and k8s mongo stuff

* feat: test fixes

* fix: ruff

* fix: vale

* Potential fix for pull request finding 'CodeQL / Clear-text logging of sensitive information'

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fix: mini suggestions

* vale lint fix 2

* fix: codeql columns thing

* fix: test mongo

* fix: tests coverage

* feat: better tests 4

* feat: more tests

* feat: decent coverage

* fix: ruff fixes

* fix: remove mongo mock

* feat: enhance workflow engine and API routes; add document retrieval and source handling

* feat: e2e tests

* fix: mcp, mongo and more

* fix: mini codeql warning

* fix: agent chunk view

* fix: mini issues

* fix: more pg fixes

* feat: postgres prep on start

* feat: qa tests

* fix: mini improvements

* fix: tests

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Siddhant Rai <siddhant.rai.5686@gmail.com>
2026-04-18 13:13:57 +01:00
dependabot[bot]
ebb7938d1b chore(deps): bump react-router-dom from 7.13.1 to 7.14.1 in /frontend
Bumps [react-router-dom](https://github.com/remix-run/react-router/tree/HEAD/packages/react-router-dom) from 7.13.1 to 7.14.1.
- [Release notes](https://github.com/remix-run/react-router/releases)
- [Changelog](https://github.com/remix-run/react-router/blob/main/packages/react-router-dom/CHANGELOG.md)
- [Commits](https://github.com/remix-run/react-router/commits/react-router-dom@7.14.1/packages/react-router-dom)

---
updated-dependencies:
- dependency-name: react-router-dom
  dependency-version: 7.14.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-17 08:55:15 +00:00
Manish Madan
7f6360b4ff Merge pull request #2396 from arc53/dependabot/npm_and_yarn/frontend/lucide-react-1.8.0
chore(deps): bump lucide-react from 0.562.0 to 1.8.0 in /frontend
2026-04-17 14:23:19 +05:30
dependabot[bot]
c68f18a0ae chore(deps): bump lucide-react from 0.562.0 to 1.8.0 in /frontend
Bumps [lucide-react](https://github.com/lucide-icons/lucide/tree/HEAD/packages/lucide-react) from 0.562.0 to 1.8.0.
- [Release notes](https://github.com/lucide-icons/lucide/releases)
- [Commits](https://github.com/lucide-icons/lucide/commits/1.8.0/packages/lucide-react)

---
updated-dependencies:
- dependency-name: lucide-react
  dependency-version: 1.8.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-17 08:50:18 +00:00
Manish Madan
684b29e73c Merge pull request #2393 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-prettier-5.5.5
chore(deps-dev): bump eslint-plugin-prettier from 5.5.4 to 5.5.5 in /frontend
2026-04-17 14:18:50 +05:30
Alex
a1efea81d0 patch: sitemap and web loader 2026-04-15 23:39:44 +01:00
dependabot[bot]
9eb34262e0 chore(deps-dev): bump eslint-plugin-prettier in /frontend
Bumps [eslint-plugin-prettier](https://github.com/prettier/eslint-plugin-prettier) from 5.5.4 to 5.5.5.
- [Release notes](https://github.com/prettier/eslint-plugin-prettier/releases)
- [Changelog](https://github.com/prettier/eslint-plugin-prettier/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prettier/eslint-plugin-prettier/compare/v5.5.4...v5.5.5)

---
updated-dependencies:
- dependency-name: eslint-plugin-prettier
  dependency-version: 5.5.5
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 20:53:23 +00:00
Alex
951bdb8365 Merge pull request #2391 from arc53/codex/draft-threat-model-for-docsgpt
Add public threat model document (.github/THREAT_MODEL.md)
2026-04-15 18:38:55 +01:00
Alex
c18f85a050 docs: clarify tool-access boundary in prompt injection section 2026-04-15 18:31:42 +01:00
Manish Madan
5ecb174567 Merge pull request #2322 from arc53/dependabot/npm_and_yarn/extensions/react-widget/svgo-4.0.1
chore(deps-dev): bump svgo from 3.3.3 to 4.0.1 in /extensions/react-widget
2026-04-15 17:06:40 +05:30
dependabot[bot]
ed7212d016 chore(deps-dev): bump svgo in /extensions/react-widget
Bumps [svgo](https://github.com/svg/svgo) from 3.3.3 to 4.0.1.
- [Release notes](https://github.com/svg/svgo/releases)
- [Commits](https://github.com/svg/svgo/compare/v3.3.3...v4.0.1)

---
updated-dependencies:
- dependency-name: svgo
  dependency-version: 4.0.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 10:43:02 +00:00
Manish Madan
f82acdab5d Merge pull request #2384 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-n-17.24.0
chore(deps-dev): bump eslint-plugin-n from 17.23.1 to 17.24.0 in /frontend
2026-04-15 16:03:56 +05:30
dependabot[bot]
361aebc34c chore(deps-dev): bump eslint-plugin-n in /frontend
Bumps [eslint-plugin-n](https://github.com/eslint-community/eslint-plugin-n) from 17.23.1 to 17.24.0.
- [Release notes](https://github.com/eslint-community/eslint-plugin-n/releases)
- [Changelog](https://github.com/eslint-community/eslint-plugin-n/blob/master/CHANGELOG.md)
- [Commits](https://github.com/eslint-community/eslint-plugin-n/compare/v17.23.1...v17.24.0)

---
updated-dependencies:
- dependency-name: eslint-plugin-n
  dependency-version: 17.24.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 10:28:46 +00:00
Manish Madan
bf194c1a0f Merge pull request #2311 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel/preset-env-7.29.2
chore(deps-dev): bump @babel/preset-env from 7.24.6 to 7.29.2 in /extensions/react-widget
2026-04-15 13:16:41 +05:30
Manish Madan
54c396750b Merge pull request #2385 from arc53/dependabot/npm_and_yarn/frontend/multi-2a6546692b
chore(deps): bump react-dom and @types/react-dom in /frontend
2026-04-15 13:12:49 +05:30
dependabot[bot]
9adebfec69 chore(deps): bump react-dom and @types/react-dom in /frontend
Bumps [react-dom](https://github.com/facebook/react/tree/HEAD/packages/react-dom) and [@types/react-dom](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react-dom). These dependencies needed to be updated together.

Updates `react-dom` from 19.2.0 to 19.2.5
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.2.5/packages/react-dom)

Updates `@types/react-dom` from 19.2.2 to 19.2.3
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react-dom)

---
updated-dependencies:
- dependency-name: react-dom
  dependency-version: 19.2.5
  dependency-type: direct:production
  update-type: version-update:semver-patch
- dependency-name: "@types/react-dom"
  dependency-version: 19.2.3
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 07:41:54 +00:00
Manish Madan
92c321f163 Merge pull request #2386 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss/postcss-4.2.2
chore(deps-dev): bump @tailwindcss/postcss from 4.1.16 to 4.2.2 in /frontend
2026-04-15 13:10:30 +05:30
dependabot[bot]
e3d36b9e52 chore(deps-dev): bump @tailwindcss/postcss in /frontend
Bumps [@tailwindcss/postcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/@tailwindcss-postcss) from 4.1.16 to 4.2.2.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.2.2/packages/@tailwindcss-postcss)

---
updated-dependencies:
- dependency-name: "@tailwindcss/postcss"
  dependency-version: 4.2.2
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 07:39:25 +00:00
Manish Madan
8950e11208 Merge pull request #2388 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss-4.2.2
chore(deps-dev): bump tailwindcss from 4.2.1 to 4.2.2 in /frontend
2026-04-15 12:39:31 +05:30
dependabot[bot]
5de0132a65 chore(deps-dev): bump tailwindcss from 4.2.1 to 4.2.2 in /frontend
Bumps [tailwindcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/tailwindcss) from 4.2.1 to 4.2.2.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.2.2/packages/tailwindcss)

---
updated-dependencies:
- dependency-name: tailwindcss
  dependency-version: 4.2.2
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 07:06:26 +00:00
Manish Madan
b92ca91512 Merge pull request #2387 from arc53/dependabot/npm_and_yarn/frontend/mermaid-11.14.0
chore(deps): bump mermaid from 11.13.0 to 11.14.0 in /frontend
2026-04-15 12:34:28 +05:30
dependabot[bot]
8e0b2844a2 chore(deps): bump mermaid from 11.13.0 to 11.14.0 in /frontend
Bumps [mermaid](https://github.com/mermaid-js/mermaid) from 11.13.0 to 11.14.0.
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.13.0...mermaid@11.14.0)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.14.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 07:02:28 +00:00
dependabot[bot]
0969db5e30 chore(deps-dev): bump @babel/preset-env in /extensions/react-widget
Bumps [@babel/preset-env](https://github.com/babel/babel/tree/HEAD/packages/babel-preset-env) from 7.24.6 to 7.29.2.
- [Release notes](https://github.com/babel/babel/releases)
- [Changelog](https://github.com/babel/babel/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel/commits/v7.29.2/packages/babel-preset-env)

---
updated-dependencies:
- dependency-name: "@babel/preset-env"
  dependency-version: 7.29.2
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 05:30:32 +00:00
Manish Madan
af335a27e8 Merge pull request #2309 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel/core-7.29.0
chore(deps-dev): bump @babel/core from 7.24.6 to 7.29.0 in /extensions/react-widget
2026-04-15 10:59:07 +05:30
dependabot[bot]
0ae3139284 chore(deps-dev): bump @babel/core in /extensions/react-widget
Bumps [@babel/core](https://github.com/babel/babel/tree/HEAD/packages/babel-core) from 7.24.6 to 7.29.0.
- [Release notes](https://github.com/babel/babel/releases)
- [Changelog](https://github.com/babel/babel/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel/commits/v7.29.0/packages/babel-core)

---
updated-dependencies:
- dependency-name: "@babel/core"
  dependency-version: 7.29.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 23:05:51 +00:00
Manish Madan
7529ca3dd6 Merge pull request #2389 from arc53/dependabot/pip/application/pip-3344959f9f
chore(deps): bump cryptography from 46.0.6 to 46.0.7 in /application in the pip group across 1 directory
2026-04-15 04:32:07 +05:30
dependabot[bot]
1b813320f1 chore(deps): bump cryptography
Bumps the pip group with 1 update in the /application directory: [cryptography](https://github.com/pyca/cryptography).


Updates `cryptography` from 46.0.6 to 46.0.7
- [Changelog](https://github.com/pyca/cryptography/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pyca/cryptography/compare/46.0.6...46.0.7)

---
updated-dependencies:
- dependency-name: cryptography
  dependency-version: 46.0.7
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 21:46:13 +00:00
Manish Madan
02012e9a0b Merge pull request #2366 from arc53/dependabot/pip/application/tzdata-2026.1
chore(deps): bump tzdata from 2025.3 to 2026.1 in /application
2026-04-15 03:14:25 +05:30
dependabot[bot]
c2f027265a chore(deps): bump tzdata from 2025.3 to 2026.1 in /application
Bumps [tzdata](https://github.com/python/tzdata) from 2025.3 to 2026.1.
- [Release notes](https://github.com/python/tzdata/releases)
- [Changelog](https://github.com/python/tzdata/blob/master/NEWS.md)
- [Commits](https://github.com/python/tzdata/compare/2025.3...2026.1)

---
updated-dependencies:
- dependency-name: tzdata
  dependency-version: '2026.1'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 21:30:10 +00:00
Manish Madan
0ae615c10e Merge pull request #2365 from arc53/dependabot/pip/application/langchain-core-1.2.26
chore(deps): bump langchain-core from 1.2.23 to 1.2.26 in /application
2026-04-15 02:58:36 +05:30
dependabot[bot]
881d0da344 chore(deps): bump langchain-core from 1.2.23 to 1.2.26 in /application
Bumps [langchain-core](https://github.com/langchain-ai/langchain) from 1.2.23 to 1.2.26.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/langchain-core==1.2.23...langchain-core==1.2.26)

---
updated-dependencies:
- dependency-name: langchain-core
  dependency-version: 1.2.26
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 21:10:32 +00:00
Manish Madan
1376de6bae Merge pull request #2364 from arc53/dependabot/pip/application/langsmith-0.7.26
chore(deps): bump langsmith from 0.7.23 to 0.7.26 in /application
2026-04-15 02:39:17 +05:30
dependabot[bot]
362ebfcc0a chore(deps): bump langsmith from 0.7.23 to 0.7.26 in /application
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.7.23 to 0.7.26.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.7.23...v0.7.26)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.7.26
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 20:29:55 +00:00
Manish Madan
bc77eed3d8 Merge pull request #2363 from arc53/dependabot/pip/application/jsonpointer-3.1.1
chore(deps): bump jsonpointer from 3.0.0 to 3.1.1 in /application
2026-04-15 01:58:23 +05:30
dependabot[bot]
1f346588e7 chore(deps): bump jsonpointer from 3.0.0 to 3.1.1 in /application
Bumps [jsonpointer](https://github.com/stefankoegl/python-json-pointer) from 3.0.0 to 3.1.1.
- [Commits](https://github.com/stefankoegl/python-json-pointer/compare/v3.0.0...v3.1.1)

---
updated-dependencies:
- dependency-name: jsonpointer
  dependency-version: 3.1.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 20:13:38 +00:00
Alex
2fed5c882b Merge pull request #2383 from arc53/codex/add-zizmor-ci-configuration
Add GitHub Actions zizmor security workflow
2026-04-14 18:02:18 +01:00
Alex
aa938d76d7 Add GitHub Actions zizmor security workflow 2026-04-14 17:56:14 +01:00
Manish Madan
2940628aa6 Merge pull request #2319 from arc53/dependabot/npm_and_yarn/frontend/npm_and_yarn-e5a595f223
chore(deps-dev): bump flatted from 3.4.1 to 3.4.2 in /frontend in the npm_and_yarn group across 1 directory
2026-04-14 21:30:54 +05:30
dependabot[bot]
7f23928134 chore(deps-dev): bump flatted
Bumps the npm_and_yarn group with 1 update in the /frontend directory: [flatted](https://github.com/WebReflection/flatted).


Updates `flatted` from 3.4.1 to 3.4.2
- [Commits](https://github.com/WebReflection/flatted/compare/v3.4.1...v3.4.2)

---
updated-dependencies:
- dependency-name: flatted
  dependency-version: 3.4.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 15:10:18 +00:00
Alex
20e17c84c7 Merge pull request #2379 from arc53/codex/refine-and-review-incident-response-plan
Add INCIDENT_RESPONSE.md and reference it from SECURITY.md
2026-04-14 14:59:04 +01:00
copilot-swe-agent[bot]
389ddf6068 Fix secret references in INCIDENT_RESPONSE.md to match actual DocsGPT config
Agent-Logs-Url: https://github.com/arc53/DocsGPT/sessions/c6bfd68d-4dac-46ec-8404-fe5bfda0e8f3

Co-authored-by: dartpain <15183589+dartpain@users.noreply.github.com>
2026-04-14 10:51:22 +00:00
Alex
1e2443fb90 Update .github/INCIDENT_RESPONSE.md
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 11:49:13 +01:00
Manish Madan
6387bd1892 Merge pull request #2303 from arc53/dependabot/npm_and_yarn/frontend/typescript-eslint/eslint-plugin-8.57.1
chore(deps-dev): bump @typescript-eslint/eslint-plugin from 8.46.3 to 8.57.1 in /frontend
2026-04-14 14:16:35 +05:30
dependabot[bot]
7d22724d1c chore(deps-dev): bump @typescript-eslint/eslint-plugin in /frontend
Bumps [@typescript-eslint/eslint-plugin](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/eslint-plugin) from 8.46.3 to 8.57.1.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/eslint-plugin/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.57.1/packages/eslint-plugin)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/eslint-plugin"
  dependency-version: 8.57.1
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 08:39:33 +00:00
Manish Madan
f6f12f6895 Merge pull request #2302 from arc53/dependabot/npm_and_yarn/frontend/prettier-plugin-tailwindcss-0.7.2
chore(deps-dev): bump prettier-plugin-tailwindcss from 0.7.1 to 0.7.2 in /frontend
2026-04-14 14:07:38 +05:30
dependabot[bot]
934127f323 chore(deps-dev): bump prettier-plugin-tailwindcss in /frontend
Bumps [prettier-plugin-tailwindcss](https://github.com/tailwindlabs/prettier-plugin-tailwindcss) from 0.7.1 to 0.7.2.
- [Release notes](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/compare/v0.7.1...v0.7.2)

---
updated-dependencies:
- dependency-name: prettier-plugin-tailwindcss
  dependency-version: 0.7.2
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 08:25:09 +00:00
Manish Madan
1780e3cc91 Merge pull request #2301 from arc53/dependabot/npm_and_yarn/frontend/react-i18next-16.5.8
chore(deps): bump react-i18next from 16.2.4 to 16.5.8 in /frontend
2026-04-14 13:53:12 +05:30
ManishMadan2882
5e7fab2f34 (chore:fe) i18next 2026-04-14 13:50:03 +05:30
Alex
92ae76f95e Merge pull request #2381 from arc53/pg-3
feat: pre depriciation
2026-04-14 08:33:42 +01:00
Alex
18755bdd9b fix: workflow tests 2026-04-14 00:35:57 +01:00
Alex
0f20adcbf4 feat: pre depriciation 2026-04-14 00:19:50 +01:00
Alex
18e2a829c9 docs: apply revised incident response plan wording 2026-04-13 14:11:45 +01:00
dependabot[bot]
cd44501a71 chore(deps): bump react-i18next from 16.2.4 to 16.5.8 in /frontend
Bumps [react-i18next](https://github.com/i18next/react-i18next) from 16.2.4 to 16.5.8.
- [Changelog](https://github.com/i18next/react-i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/react-i18next/compare/v16.2.4...v16.5.8)

---
updated-dependencies:
- dependency-name: react-i18next
  dependency-version: 16.5.8
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-13 12:36:35 +00:00
Manish Madan
f8ebdf3fd4 Merge pull request #2300 from arc53/dependabot/npm_and_yarn/frontend/i18next-browser-languagedetector-8.2.1
chore(deps): bump i18next-browser-languagedetector from 8.2.0 to 8.2.1 in /frontend
2026-04-13 18:03:46 +05:30
dependabot[bot]
7c6fca18ad chore(deps): bump i18next-browser-languagedetector in /frontend
Bumps [i18next-browser-languagedetector](https://github.com/i18next/i18next-browser-languageDetector) from 8.2.0 to 8.2.1.
- [Changelog](https://github.com/i18next/i18next-browser-languageDetector/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next-browser-languageDetector/compare/v8.2.0...v8.2.1)

---
updated-dependencies:
- dependency-name: i18next-browser-languagedetector
  dependency-version: 8.2.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-13 12:28:26 +00:00
Alex
5fab798707 Merge pull request #2377 from arc53/pg-2
feat: pg-2
2026-04-12 14:09:52 +01:00
Alex
cb30a24e05 feat: fixes on pg2 2026-04-12 13:51:29 +01:00
Alex
530761d08c feat: pg-2 2026-04-12 13:35:32 +01:00
Alex
73fbc28744 Merge pull request #2376 from arc53/pg-1
Pg 1
2026-04-12 12:44:12 +01:00
Alex
b5b6538762 fix: tests 2026-04-12 12:35:23 +01:00
Alex
a9761061fc fix: mini issues 2026-04-12 12:24:58 +01:00
Alex
9388996a15 fix: ruff 2026-04-12 12:18:31 +01:00
Alex
875868b7e5 fix: comment 2026-04-12 12:16:19 +01:00
Alex
502819ae52 feat: pg migration, more tables 2026-04-12 12:15:59 +01:00
Alex
cada1a44fc Merge pull request #2373 from siiddhantt/feat/confluence-connector
feat: add Confluence connector for data ingestion
2026-04-12 11:33:49 +01:00
Alex
6192767451 fix: sanitize attachment filenames, drop dateutil dep, add connector docs 2026-04-12 11:32:24 +01:00
Alex
5c3e6eca54 Merge pull request #2375 from arc53/pg
feat: init pg migration
2026-04-12 10:44:31 +01:00
Alex
59d9d4ac50 fix: comments in settings 2026-04-12 10:42:46 +01:00
Alex
3931ccccee Merge pull request #2374 from ManishMadan2882/main
UX: Conversation scroll experience
2026-04-12 00:37:11 +01:00
Alex
55717043f6 fix: vale 2026-04-12 00:29:23 +01:00
Alex
ececcb8b17 feat: init pg migration 2026-04-12 00:07:24 +01:00
ManishMadan2882
420e9d3dd5 (feat) conversation: scroll experience 2026-04-10 19:27:07 +05:30
Siddhant Rai
749eed3d0b feat: add Confluence integration with authentication and file loading capabilities
- Enhanced settings.py to include Confluence client ID and secret
- Created ConfluenceAuth class for handling authentication with Confluence
- Implemented ConfluenceLoader class for loading data from Confluence
- Updated connector_creator.py to register Confluence as a connector
- Added confluence.svg asset for UI representation
- Modified ConnectorAuth component to support Confluence connection
- Updated FilePicker component to include Confluence as a file source
- Added localization support for Confluence in multiple languages (de, en, es, jp, ru, zh-TW, zh)
- Enhanced Upload component to handle Confluence file selection
- Updated ingestor types to include Confluence and its configuration
2026-04-10 19:10:35 +05:30
Alex
bd03a513e3 Merge pull request #2372 from arc53/fast-ebook
feat: faster ebook parsing
2026-04-09 18:38:13 +01:00
Alex
fcdb4fb5e8 feat: faster ebook parsing 2026-04-09 18:31:06 +01:00
Alex
e787c896eb upd Security.md 2026-04-08 12:49:20 +01:00
Alex
23aeaff5db Merge pull request #2362 from arc53/v1-mini-improvements
feat: history overwrite
2026-04-06 15:02:32 +01:00
Alex
689dd79597 fix: lang 2026-04-06 14:57:51 +01:00
Alex
0c15af90b1 feat: history overwrite 2026-04-06 14:42:01 +01:00
Alex
cdd6ff6557 chore: bump deps 2026-04-04 12:45:34 +01:00
Alex
72b3d94453 fix: tests 2026-04-03 18:30:46 +01:00
Alex
7e88d09e5d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 18:26:37 +01:00
Alex
74a4a237dc fix: bump deps 2026-04-03 18:26:29 +01:00
Alex
c3f01c6619 Merge pull request #2347 from ManishMadan2882/main
Minor frontend updates
2026-04-03 18:17:27 +01:00
Alex
6b408823d4 fix: mini theme color edits 2026-04-03 18:16:07 +01:00
Alex
3fc81ac5d8 fix: clean error 2026-04-03 18:08:38 +01:00
Alex
2652f8a5b0 fix: chatwoot 2026-04-03 18:04:49 +01:00
Alex
d711eefe96 patch: agent usage limits 2026-04-03 18:03:31 +01:00
Alex
79206f3919 fix: harden faiss 2026-04-03 17:57:49 +01:00
Alex
de971d9452 fix: validate mcp url 2026-04-03 17:52:48 +01:00
Alex
1b4d5ca0dd patch: mcp identity 2026-04-03 17:40:22 +01:00
Alex
81989e8258 fix: patch /v1/models 2026-04-03 17:37:09 +01:00
Alex
dc262d1698 patch: error 2026-04-03 17:30:23 +01:00
Alex
69f9c93869 patch: s3 2026-04-03 17:28:09 +01:00
Alex
74bf80b25c patch: sharing convos 2026-04-03 17:20:06 +01:00
Alex
d9a92a7208 feat: improve setup scripts 2026-04-03 17:15:21 +01:00
Alex
02e93d993d patch: available tools 2026-04-03 17:12:36 +01:00
Alex
6b6495f48c patch: key 2026-04-03 17:06:35 +01:00
Alex
249dd9ce37 patch: paths 2026-04-03 16:45:03 +01:00
Alex
9134ab0478 Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 16:40:50 +01:00
Alex
10ef68c9d0 Revise vulnerability reporting process
Updated vulnerability reporting instructions to use GitHub's private reporting flow.
2026-04-03 16:36:10 +01:00
Alex
7d65cf1c2b chore: bump deps 2026-04-03 16:35:10 +01:00
Alex
13c6cc59c1 Merge pull request #2349 from arc53/messages-format
Messages format
2026-04-03 16:26:57 +01:00
Alex
6381f7dd4e fix: remove bad tests 2026-04-03 16:20:15 +01:00
Alex
e6ac4008fe feat: better tool names for llms 2026-04-03 15:35:50 +01:00
Alex
1af09f114d fix: tool mapping 2026-04-03 13:32:55 +01:00
Alex
be7da983e7 fix: remove internal tools when creating tools and better Approval gate UX 2026-04-03 10:36:48 +01:00
Alex
8b9e595d85 fix: structure improvements of messages 2026-04-01 14:58:44 +01:00
Alex
398f3acc8d fix: clean error 2026-04-01 13:01:02 +01:00
Alex
e04baa7ed8 feat: tests and approval gate 2026-04-01 12:49:32 +01:00
Alex
e5586b6f20 feat: fronted connection to api 2026-04-01 10:55:54 +01:00
Alex
addf57cab7 feat: compatible api 2026-03-31 23:10:09 +01:00
ManishMadan2882
648b3f1d20 (fix) lint/fe 2026-04-01 03:30:44 +05:30
ManishMadan2882
a75a9e23f9 (feat:fe) minor good things 2026-04-01 03:19:03 +05:30
Alex
73256389cf feat: client side tools 2026-03-31 22:20:55 +01:00
Alex
d609efca49 feat: continuation messages 2026-03-31 21:30:24 +01:00
Alex
772860b667 fix: mini fe changes 2026-03-31 11:59:38 +01:00
Alex
ea2fd8b04a chore: remove unused deps 2026-03-31 11:57:01 +01:00
Alex
2c73deac20 deps upgrades 2026-03-31 11:32:55 +01:00
Alex
47f3907e5e Merge pull request #2340 from arc53/coverage-3
chore: more tests
2026-03-31 00:50:46 +01:00
Alex
727495c553 fix: mongo in unit tests 2026-03-31 00:34:49 +01:00
Alex
a3b08a5b44 More tests 2026-03-31 00:07:19 +01:00
Alex
81532ada2a Merge pull request #2318 from siiddhantt/feat/standardize-css
feat: update styles and improve accessibility across frontend
2026-03-30 23:26:45 +01:00
ManishMadan2882
43f71374e5 (chore:fe) lint-fix 2026-03-30 23:26:11 +05:30
Alex
d5c0322e2a chore: more tests 2026-03-30 16:13:08 +01:00
Siddhant Rai
3b66a3176c fix: improve option matching logic in Dropdown component + selected style 2026-03-30 18:36:35 +05:30
Alex
dc6db847ca Merge pull request #2339 from arc53/test-handlers
chore: handlers tests
2026-03-30 13:06:53 +01:00
Alex
ed0063aada chore: handlers tests 2026-03-30 12:53:50 +01:00
Siddhant Rai
9a6a55b6da Merge branch 'main' into feat/standardize-css 2026-03-30 13:14:43 +05:30
Siddhant Rai
12a8368216 fix: merge conflicts 2026-03-30 13:12:24 +05:30
Alex
3f6d6f15ea Merge pull request #2338 from arc53/tests-utils
chore: utils tests
2026-03-29 11:54:59 +01:00
Alex
126fa01b14 chore: utils tests 2026-03-29 11:49:35 +01:00
Alex
e06debad5f chore: connector tests 2026-03-29 10:32:48 +01:00
Alex
6492852f7d fix: lint on seeder test 2026-03-29 09:48:46 +01:00
Alex
00a621f33a tests: retriever and seeder 2026-03-29 09:46:07 +01:00
Alex
e92ffc6fdc Merge pull request #2337 from arc53/tests-api
chore: api and tool tests
2026-03-28 22:55:10 +00:00
Alex
fe185e5b8d chore: api and tool tests 2026-03-28 21:51:47 +00:00
Alex
9f3d9ab860 docs: agent types 2026-03-28 18:50:50 +00:00
Alex
1c0adde380 chore: docs update 2026-03-28 17:04:06 +00:00
Alex
3c56bd0d0b docs: fix conflicts 2026-03-28 16:16:15 +00:00
Alex
86664ebda2 Merge pull request #2320 from Alex-wuhu/novita-integration
feat: complete Novita AI provider integration
2026-03-28 13:22:02 +00:00
Alex
db18b743d1 fix tests 2 2026-03-28 13:10:41 +00:00
Alex
9e85cc9065 fix: test errors 2026-03-28 13:04:21 +00:00
Alex
aaaa6f002d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-03-28 12:03:27 +00:00
Alex
47dcbcb74b fix: tests and sources on workflow agent 2026-03-28 12:03:16 +00:00
Alex
ddbfd94193 Adjust demo GIF height in README
Update the height of the demo GIF in README.
2026-03-28 11:09:05 +00:00
Alex
8dec60ab8b Update demo GIF in README
Replaced demo video GIF in README with a new example.
2026-03-28 11:08:10 +00:00
Alex
84b2e4bab4 fix: end node multi input 2026-03-28 10:00:01 +00:00
Siddhant Rai
193ca6fd63 fix: lint errors + redundant css 2026-03-28 15:02:16 +05:30
Alex
2afdd7f026 fix: enable tools in workflow agents 2026-03-27 13:43:09 +00:00
Alex
f364475f64 Merge pull request #2335 from arc53/tests-worker
tests: worker coverage
2026-03-27 12:14:56 +00:00
Alex
b254de6ed6 tests: worker coverage 2026-03-27 12:03:55 +00:00
Alex
08dedcaf95 Merge pull request #2334 from arc53/tests-vectors
tests: vectors
2026-03-26 19:11:20 +00:00
Alex
c726eb8ebd fix: ruff 2026-03-26 18:48:53 +00:00
Alex
5f0d39e5f1 tests: vectors 2026-03-26 18:42:59 +00:00
Alex
8c82fc5495 Merge pull request #2333 from arc53/chore/bump-npm-v0.6.3
chore: bump npm libraries to v0.6.3
2026-03-26 14:17:09 +00:00
github-actions[bot]
6d81a15e97 chore: bump npm libraries to v0.6.3 2026-03-26 14:16:32 +00:00
Alex
5478e4234c chore: bump npm again 2026-03-26 14:06:05 +00:00
Alex
4056278fef chore: bump npm 2026-03-26 13:48:05 +00:00
Alex
ee6530fe00 chore: bump npm 2026-03-26 13:42:51 +00:00
Alex
7c1decbcc3 fix: npm publish 2026-03-26 13:39:09 +00:00
Alex
8a3c724b31 Merge pull request #2332 from arc53/fix-fallbacks-onstream
Fix fallbacks onstream
2026-03-26 13:18:32 +00:00
Alex
15d4e9dbf5 fix: strict json in openai base 2026-03-26 13:08:22 +00:00
Siddhant Rai
174dee0fe6 fix: inconsistencies with prev color patterns 2026-03-26 18:32:57 +05:30
Alex
f7bfd38b28 fix: proper fallback handling within agent during stream 2026-03-26 12:52:30 +00:00
Alex
187e5da61e Merge pull request #2328 from ManishMadan2882/main
Chore(React widget): automate publish via actions
2026-03-26 11:15:52 +00:00
Alex
175ed58d2e Merge pull request #2327 from arc53/research-agent
Research agent
2026-03-26 11:00:23 +00:00
Alex
820ee3a843 fix test 2026-03-25 22:53:09 +00:00
Alex
462f2e9494 mini refactors 2026-03-25 22:34:25 +00:00
ManishMadan2882
c4968a641e (chore)react-widget: add linter 2026-03-26 02:06:49 +05:30
Alex
c6ece177cd fix: small issues 2026-03-25 20:04:03 +00:00
ManishMadan2882
a3e6a5622d (chore)react-widget: build, publish act 2026-03-26 01:03:09 +05:30
Alex
e8d11fdfa6 feat: list files internal tool 2026-03-25 19:21:46 +00:00
Alex
72393dc369 feat: improve research 2026-03-25 17:42:24 +00:00
Alex
556b0a1da5 feat: research init 2026-03-25 15:16:18 +00:00
Siddhant Rai
844167ba06 Merge branch 'feat/standardize-css' of https://github.com/siiddhantt/DocsGPT into feat/standardize-css 2026-03-25 19:36:35 +05:30
Siddhant Rai
6fa3acb1ca style: standardize colors across components according to figma 2026-03-25 19:36:32 +05:30
Alex
32c268a21e refactor: simplify agent architecture and remove ReActAgent 2026-03-25 12:47:17 +00:00
Alex
ed34c2b929 fix: tool name in agent builder 2026-03-25 11:29:31 +00:00
Alex
06e827573c fix: jinja 2026-03-25 00:03:42 +00:00
dependabot[bot]
cdb71a54f0 chore(deps): bump class-variance-authority in /extensions/react-widget
Bumps [class-variance-authority](https://github.com/joe-bell/cva) from 0.7.0 to 0.7.1.
- [Release notes](https://github.com/joe-bell/cva/releases)
- [Commits](https://github.com/joe-bell/cva/compare/v0.7.0...v0.7.1)

---
updated-dependencies:
- dependency-name: class-variance-authority
  dependency-version: 0.7.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-24 20:53:18 +00:00
Manish Madan
74e76d4cda Merge pull request #2323 from arc53/dependabot/npm_and_yarn/extensions/react-widget/styled-components-6.3.12
chore(deps): bump styled-components from 6.1.11 to 6.3.12 in /extensions/react-widget
2026-03-25 01:44:55 +05:30
ManishMadan2882
db5c69ca76 (chore)react widget: build fix 2026-03-25 01:37:09 +05:30
Alex
9fd063266b Mini fixes 2026-03-24 01:40:29 +00:00
dependabot[bot]
05aa9d7cca chore(deps): bump styled-components in /extensions/react-widget
Bumps [styled-components](https://github.com/styled-components/styled-components) from 6.1.11 to 6.3.12.
- [Release notes](https://github.com/styled-components/styled-components/releases)
- [Commits](https://github.com/styled-components/styled-components/compare/v6.1.11...styled-components@6.3.12)

---
updated-dependencies:
- dependency-name: styled-components
  dependency-version: 6.3.12
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-23 20:53:52 +00:00
Manish Madan
dcececd118 Merge pull request #2298 from arc53/dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.305.1
chore(deps): bump flow-bin from 0.305.0 to 0.305.1 in /extensions/react-widget
2026-03-23 19:26:21 +05:30
Alex-wuhu
eaf39bb15b feat: add Novita AI as LLM provider
Add Novita AI (https://novita.ai) as a new LLM provider option.
Novita offers OpenAI-compatible API endpoints with competitive pricing.
2026-03-23 10:52:26 +08:00
dependabot[bot]
6515481624 chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.305.0 to 0.305.1.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.305.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-22 20:29:46 +00:00
Manish Madan
6a7e3b6d77 Merge pull request #2165 from arc53/dependabot/pip/extensions/slack-bot/pip-75f0befae6
chore(deps): bump h11 from 0.14.0 to 0.16.0 in /extensions/slack-bot in the pip group across 1 directory
2026-03-23 01:48:58 +05:30
dependabot[bot]
02804fecce chore(deps): bump h11
Bumps the pip group with 1 update in the /extensions/slack-bot directory: [h11](https://github.com/python-hyper/h11).


Updates `h11` from 0.14.0 to 0.16.0
- [Commits](https://github.com/python-hyper/h11/compare/v0.14.0...v0.16.0)

---
updated-dependencies:
- dependency-name: h11
  dependency-version: 0.16.0
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-22 19:41:56 +00:00
Siddhant Rai
324a8cd4cf refactor: update styles and improve accessibility across frontend
- Updated text colors to use foreground and muted-foreground for better contrast.
- Replaced hardcoded colors with theme-based classes for consistency.
- Enhanced input fields with icons for improved usability.
- Adjusted button styles for a more cohesive design.
- Refactored search input components to use consistent styling and layout.
- Improved layout and spacing in various components for better user experience.
- Updated tool and source titles and subtitles for clarity.
2026-03-20 17:10:27 +05:30
Alex
ce5cd5561a loading animation 2026-03-19 15:52:27 +00:00
Manish Madan
adeefce9aa (fix)widget: since v6 shouldForwardProp is no longer provided by default (#2312) 2026-03-18 15:32:36 +00:00
Alex
5ab43fd12c fix: lang overflows, sst (#2314) 2026-03-18 14:50:29 +00:00
Alex
5894e47189 Agents.md 2026-03-18 00:34:26 +00:00
Manish Madan
ca61d81f4a Merge pull request #2154 from arc53/dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3
chore(deps-dev): bump typescript from 5.4.5 to 5.9.3 in /extensions/react-widget
2026-03-18 01:43:16 +05:30
Alex
b12d0ca7b1 screenshot memo on contributing md 2026-03-17 19:45:37 +00:00
Alex
21996af626 stt init (#2306)
* stt init

* fix: limits

* fix: errors

* fix: error messages
2026-03-17 14:27:48 +00:00
ManishMadan2882
cc3b174e5a Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3 2026-03-16 17:08:06 +05:30
dependabot[bot]
faee58fb1e chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.4.5 to 5.9.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Changelog](https://github.com/microsoft/TypeScript/blob/main/azure-pipelines.release-publish.yml)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.4.5...v5.9.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 5.9.3
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-16 11:37:27 +00:00
Manish Madan
d439e48b39 Merge pull request #2151 from arc53/dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0
chore(deps): bump flow-bin from 0.229.2 to 0.290.0 in /extensions/react-widget
2026-03-16 17:05:33 +05:30
ManishMadan2882
3f0f155d64 Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0 2026-03-16 17:02:35 +05:30
dependabot[bot]
d82d512319 chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.229.2 to 0.290.0.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.290.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-16 11:29:26 +00:00
Manish Madan
76aea1716f Merge pull request #2150 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel-loader-10.0.0
chore(deps-dev): bump babel-loader from 8.3.0 to 10.0.0 in /extensions/react-widget
2026-03-16 16:57:38 +05:30
ManishMadan2882
586649b73f (chore) react-widget: peer deps update 2026-03-16 16:56:02 +05:30
dependabot[bot]
0349a79cb3 chore(deps-dev): bump babel-loader in /extensions/react-widget
Bumps [babel-loader](https://github.com/babel/babel-loader) from 8.3.0 to 10.0.0.
- [Release notes](https://github.com/babel/babel-loader/releases)
- [Changelog](https://github.com/babel/babel-loader/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel-loader/compare/v8.3.0...v10.0.0)

---
updated-dependencies:
- dependency-name: babel-loader
  dependency-version: 10.0.0
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 21:30:12 +00:00
Manish Madan
78a255bdd7 Merge pull request #2148 from arc53/dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2
chore(deps): bump @radix-ui/react-icons from 1.3.0 to 1.3.2 in /extensions/react-widget
2026-03-16 02:58:35 +05:30
ManishMadan2882
5b30e71aa1 (chore)deps fe and react-widget: audit fix 2026-03-16 02:24:02 +05:30
ManishMadan2882
99d84aece9 Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2 2026-03-16 01:57:14 +05:30
dependabot[bot]
525d8eb66d chore(deps): bump @radix-ui/react-icons in /extensions/react-widget
Bumps @radix-ui/react-icons from 1.3.0 to 1.3.2.

---
updated-dependencies:
- dependency-name: "@radix-ui/react-icons"
  dependency-version: 1.3.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 20:25:18 +00:00
Manish Madan
4c810108e0 Merge pull request #2145 from arc53/dependabot/npm_and_yarn/frontend/lint-staged-16.2.6
chore(deps-dev): bump lint-staged from 15.5.2 to 16.2.6 in /frontend
2026-03-16 01:41:34 +05:30
dependabot[bot]
fc03cdc76a chore(deps-dev): bump lint-staged from 15.5.2 to 16.2.6 in /frontend
Bumps [lint-staged](https://github.com/lint-staged/lint-staged) from 15.5.2 to 16.2.6.
- [Release notes](https://github.com/lint-staged/lint-staged/releases)
- [Changelog](https://github.com/lint-staged/lint-staged/blob/main/CHANGELOG.md)
- [Commits](https://github.com/lint-staged/lint-staged/compare/v15.5.2...v16.2.6)

---
updated-dependencies:
- dependency-name: lint-staged
  dependency-version: 16.2.6
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 20:09:03 +00:00
Manish Madan
9779a563f3 Merge pull request #2144 from arc53/dependabot/npm_and_yarn/frontend/vitejs/plugin-react-5.1.0
chore(deps-dev): bump @vitejs/plugin-react from 4.7.0 to 5.1.0 in /frontend
2026-03-16 01:36:30 +05:30
ManishMadan2882
6141c3c348 (chore:deps) fe, vite up 2026-03-16 01:34:24 +05:30
dependabot[bot]
c3726ddfc9 chore(deps-dev): bump @vitejs/plugin-react in /frontend
Bumps [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/tree/HEAD/packages/plugin-react) from 4.7.0 to 5.1.0.
- [Release notes](https://github.com/vitejs/vite-plugin-react/releases)
- [Changelog](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite-plugin-react/commits/plugin-react@5.1.0/packages/plugin-react)

---
updated-dependencies:
- dependency-name: "@vitejs/plugin-react"
  dependency-version: 5.1.0
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 19:03:36 +00:00
Manish Madan
10eaa8143e Merge pull request #2143 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss-4.1.17
chore(deps-dev): bump tailwindcss from 4.1.16 to 4.1.17 in /frontend
2026-03-16 00:31:39 +05:30
dependabot[bot]
0c4f4e1f0c chore(deps-dev): bump tailwindcss from 4.1.16 to 4.1.17 in /frontend
Bumps [tailwindcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/tailwindcss) from 4.1.16 to 4.1.17.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.1.17/packages/tailwindcss)

---
updated-dependencies:
- dependency-name: tailwindcss
  dependency-version: 4.1.17
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 18:56:49 +00:00
Manish Madan
b225c3cd80 Merge pull request #2149 from arc53/dependabot/npm_and_yarn/frontend/i18next-25.6.1
chore(deps): bump i18next from 25.6.0 to 25.6.1 in /frontend
2026-03-16 00:24:57 +05:30
dependabot[bot]
b558645d6b chore(deps): bump i18next from 25.6.0 to 25.6.1 in /frontend
Bumps [i18next](https://github.com/i18next/i18next) from 25.6.0 to 25.6.1.
- [Release notes](https://github.com/i18next/i18next/releases)
- [Changelog](https://github.com/i18next/i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next/compare/v25.6.0...v25.6.1)

---
updated-dependencies:
- dependency-name: i18next
  dependency-version: 25.6.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 18:48:31 +00:00
Manish Madan
03b0889b15 Merge pull request #2152 from arc53/dependabot/npm_and_yarn/frontend/react-syntax-highlighter-16.1.0
chore(deps): bump react-syntax-highlighter from 15.6.6 to 16.1.0 in /frontend
2026-03-15 19:58:29 +05:30
dependabot[bot]
943fe3651c chore(deps): bump react-syntax-highlighter in /frontend
Bumps [react-syntax-highlighter](https://github.com/react-syntax-highlighter/react-syntax-highlighter) from 15.6.6 to 16.1.0.
- [Release notes](https://github.com/react-syntax-highlighter/react-syntax-highlighter/releases)
- [Changelog](https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/CHANGELOG.MD)
- [Commits](https://github.com/react-syntax-highlighter/react-syntax-highlighter/compare/v15.6.6...v16.1.0)

---
updated-dependencies:
- dependency-name: react-syntax-highlighter
  dependency-version: 16.1.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-13 22:58:44 +00:00
Siddhant Rai
65e57be4dd feat: dynamic config rendering + mcp tool enhancement (#2286)
* feat: enhance modal functionality and configuration handling

- Updated WrapperModal to improve click outside detection for closing the modal.
- Refactored ToolConfig to utilize ConfigFieldSpec for better configuration management.
- Added validation and dynamic handling of configuration fields in ToolConfig.
- Introduced reconnect functionality for MCP tools in the Tools component.
- Enhanced user experience with improved error handling and loading states.
- Updated types for better type safety and clarity in configuration requirements.

* refactor: reorganize imports and improve conditional formatting

* fix: revert API_URL to use backend service name in docker-compose

* feat: add MCP auth status endpoint and integrate into user service and tools

* feat: implement logging for Brave, Postgres, and Telegram tools; add transport sanitization and credential extraction for MCP

---------

Co-authored-by: Alex <a@tushynski.me>
2026-03-13 15:58:50 +00:00
Siddhant Rai
13ad3b5dce feat: enhance logging and error handling across various tools; update DuckDuckGo dependency (#2282)
Co-authored-by: Alex <a@tushynski.me>
2026-03-12 16:50:29 +00:00
Manish Madan
918bbf0369 Sharepoint (#2283)
* feat: add Microsoft Entra ID integration

- Updated .env-template and settings.py for Microsoft Entra ID configuration.
- Enhanced ConnectorsCallback to support SharePoint authentication.
- Introduced SharePointAuth and SharePointLoader classes.
- Added required dependencies in requirements.txt.

* feat: agent templates and seeding premade agents (#1910)

* feat: agent templates and seeding premade agents

* fix: ensure ObjectId is used for source reference in agent configuration

* fix: improve source handling in DatabaseSeeder and update tool config processing

* feat: add prompt handling in DatabaseSeeder for agent configuration

* Docs premade agents

* link to prescraped docs

* feat: add template agent retrieval and adopt agent functionality

* feat: simplify agent descriptions in premade_agents.yaml  added docs

---------

Co-authored-by: Pavel <pabin@yandex.ru>
Co-authored-by: Alex <a@tushynski.me>

* feat: add GitHub access token support and fix file content fetching logic (#2032)

* feat: add init for Share Point connector module

* chore(deps): bump mermaid from 11.6.0 to 11.12.0 in /frontend

Bumps [mermaid](https://github.com/mermaid-js/mermaid) from 11.6.0 to 11.12.0.
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.6.0...mermaid@11.12.0)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.12.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Feat: Notification section (#2033)

* Feature/Notification-section

* fix notification ui and add local storage variable to save the state

* add notification component to app.tsx

* refactor: remove MICROSOFT_REDIRECT_URI and update SharePointAuth to use CONNECTOR_REDIRECT_BASE_URI

* feat: Add button to cancel LLM response (#1978)

* feat: Add button to cancel LLM response
- Replace text area with cancel button when loading.
- Add useEffect to change elipsis in cancel button text.
- Add new SVG icon for cancel response.
- Button colors match Figma designs.

* fix: Cancel button UI matches new design
- Delete cancel-response svg.
- Change previous cancel button to match the new Figma design.
- Remove console log in handleCancel function.

* fix: Adjust cancel button rounding

* feat: Update UI for send button
- Add SendArrowIcon component, enables dynamic svg color changes
- Replace original icon
- Update colors and hover effects

* (fix:send-button) minor blink in transition

---------

Co-authored-by: Manish Madan <manishmadan321@gmail.com>

* feat: add SharePoint integration with session validation and UI components

* (feat:oneDrive) file loading for ingestion

* feat(oneDrive): shared user files

* (feat:oneDrive) rm shared file support, as sharedWithMe is degraded

* (feat:sharepoint) shared files for work msa

* (feat:sharepoint) retry on auth failure, decorator

* (fix) tests/ruff

* test: fix sharepoint loader expecting client id

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Abhishek Malviya <abfeb8@gmail.com>
Co-authored-by: Siddhant Rai <47355538+siiddhantt@users.noreply.github.com>
Co-authored-by: Pavel <pabin@yandex.ru>
Co-authored-by: Alex <a@tushynski.me>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mariam Saeed <69825646+Mariam-Saeed@users.noreply.github.com>
Co-authored-by: Rahul <rahulgithub96@gmail.com>
2026-03-12 14:46:26 +00:00
Alex
5006271abb fix stream stuff (#2293) 2026-03-11 11:43:27 +00:00
dependabot[bot]
2c2bdd37d5 chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.4.5 to 5.9.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Changelog](https://github.com/microsoft/TypeScript/blob/main/azure-pipelines.release-publish.yml)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.4.5...v5.9.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 5.9.3
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:19:11 +00:00
dependabot[bot]
6a00319c2d chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.229.2 to 0.290.0.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.290.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:19:00 +00:00
dependabot[bot]
66870279d3 chore(deps): bump @radix-ui/react-icons in /extensions/react-widget
Bumps @radix-ui/react-icons from 1.3.0 to 1.3.2.

---
updated-dependencies:
- dependency-name: "@radix-ui/react-icons"
  dependency-version: 1.3.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:18:48 +00:00
761 changed files with 859788 additions and 38708 deletions

View File

@@ -3,6 +3,14 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
@@ -12,4 +20,20 @@ EMBEDDINGS_KEY=
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt

99
.github/INCIDENT_RESPONSE.md vendored Normal file
View File

@@ -0,0 +1,99 @@
# DocsGPT Incident Response Plan (IRP)
This playbook describes how maintainers respond to confirmed or suspected security incidents.
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
## Severity
| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
## Response workflow
### 1) Triage (target: initial response within 48 hours)
1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
### 2) Investigation
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.
### 3) Containment, fix, and disclosure
1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
### 4) Post-incident
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.
## Scenario playbooks
### Supply-chain compromise
1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.
### Data exposure
1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.
### Critical regression with security impact
1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.
## AI-specific guidance
Treat confirmed AI-specific abuse as security incidents:
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**
## Secret rotation quick reference
| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
## Maintenance
Review this document:
- after every **Critical/High** incident, and
- at least annually.
Changes should be proposed via pull request to `main`.

144
.github/THREAT_MODEL.md vendored Normal file
View File

@@ -0,0 +1,144 @@
# DocsGPT Public Threat Model
**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments
## 1) Overview
DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.
Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)
## 2) Scope and assumptions
In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.
Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.
Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).
## 3) Security objectives
- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.
## 4) Assets
- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).
## 5) Trust boundaries and untrusted input
Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs
Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.
## 6) Main attack surfaces
1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).
## 7) Key threats and expected mitigations
### A. Auth/authz misconfiguration
- Threat: weak/no auth or leaked tokens leads to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.
### B. Untrusted file ingestion
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.
### C. SSRF/outbound abuse
- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
### D. Prompt injection + tool abuse
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Threat: never rely on the model to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.
### E. Dangerous tool capability chaining (SQL/API/MCP)
- Threat: write-capable SQL credentials allow destructive queries.
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.
### F. Frontend/XSS + token theft
- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.
### G. Internal endpoint exposure
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.
### H. DoS and cost abuse
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.
## 8) Example attacker stories
- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.
## 9) Severity calibration
- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.
## 10) Baseline controls for public deployments
1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
3. Restrict CORS and front API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.
## 11) Maintenance
Review this model after major auth, ingestion, connector, tool, or workflow changes.
## References
- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)

View File

@@ -1,46 +1,80 @@
Ollama
Qdrant
Milvus
Chatwoot
Nextra
VSCode
npm
LLMs
Agentic
Anthropic's
api
APIs
Groq
SGLang
LMDeploy
OAuth
Vite
LLM
JSONPath
UIs
Atlassian
automations
autoescaping
Autoescaping
backfill
backfills
bool
boolean
brave_web_search
chatbot
Chatwoot
config
configs
uncomment
qdrant
vectorstore
CSVs
dev
diarization
Docling
docsgpt
llm
docstrings
Entra
env
enqueues
EOL
ESLint
feedbacks
Figma
GPUs
Groq
hardcode
hardcoding
Idempotency
JSONPath
kubectl
Lightsail
enqueues
chatbot
VSCode's
Shareability
feedbacks
automations
llama_cpp
llm
LLM
LLMs
LMDeploy
Milvus
Mixtral
namespace
namespaces
needs_auth
Nextra
Novita
npm
OAuth
Ollama
opencode
parsable
passthrough
PDFs
pgvector
Postgres
Premade
Signup
Pydantic
pytest
Qdrant
qdrant
Repo
repo
env
URl
agentic
llama_cpp
parsable
Sanitization
SDKs
boolean
bool
hardcode
EOL
SGLang
Shareability
Signup
Supabase
UIs
uncomment
URl
vectorstore
Vite
VSCode
VSCode's
widget's

114
.github/workflows/npm-publish.yml vendored Normal file
View File

@@ -0,0 +1,114 @@
name: Publish npm libraries
on:
workflow_dispatch:
inputs:
version:
description: >
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
Applies to both docsgpt and docsgpt-react.
required: true
default: patch
permissions:
contents: write
pull-requests: write
jobs:
publish:
runs-on: ubuntu-latest
environment: npm-release
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
registry-url: https://registry.npmjs.org
- name: Install dependencies
run: npm ci
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
# Uses the `build` script (parcel build src/browser.tsx) and keeps
# the `targets` field so Parcel produces browser-optimised bundles.
- name: Set package name → docsgpt
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt)
id: version_docsgpt
run: |
VERSION="${{ github.event.inputs.version }}"
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
- name: Build docsgpt
run: npm run build
- name: Publish docsgpt
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── docsgpt-react (React library bundle) ─────────────────────────────
# Uses `build:react` script (parcel build src/index.ts) and strips
# the `targets` field so Parcel treats the output as a plain library
# without browser-specific target resolution, producing a smaller bundle.
- name: Reset package.json from source control
run: git checkout -- package.json
- name: Set package name → docsgpt-react
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Remove targets field (react library build)
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt-react) to match docsgpt
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
- name: Clean dist before react build
run: rm -rf dist
- name: Build docsgpt-react
run: npm run build:react
- name: Publish docsgpt-react
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── Commit the bumped version back to the repository ─────────────────
- name: Reset package.json and write final version
run: |
git checkout -- package.json
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
package.json > _tmp.json && mv _tmp.json package.json
npm install --package-lock-only
- name: Commit version bump and create PR
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
git checkout -b "$BRANCH"
git add package.json package-lock.json
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
git push origin "$BRANCH"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Create PR
run: |
gh pr create \
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
--body "Automated version bump after npm publish." \
--base main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,34 @@
name: React Widget Build
on:
push:
paths:
- 'extensions/react-widget/**'
pull_request:
paths:
- 'extensions/react-widget/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
cache: npm
cache-dependency-path: extensions/react-widget/package-lock.json
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build

View File

@@ -11,7 +11,6 @@ on:
permissions:
contents: read
pull-requests: write
jobs:
vale:
@@ -20,11 +19,16 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Vale linter
uses: errata-ai/vale-action@v2
with:
files: docs
fail_on_error: false
version: 3.0.5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Install Vale
run: |
curl -fsSL -o vale.tar.gz \
https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
tar -xzf vale.tar.gz
sudo mv vale /usr/local/bin/vale
vale --version
- name: Sync Vale packages
run: vale sync
- name: Run Vale
run: vale --minAlertLevel=error docs

25
.github/workflows/zizmor.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis
on:
push:
branches: ["master"]
pull_request:
branches: ["**"]
permissions: {}
jobs:
zizmor:
runs-on: ubuntu-latest
permissions:
security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Run zizmor 🌈
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2

12
.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/
experiments
@@ -107,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -180,5 +183,14 @@ application/vectors/
node_modules/
.vscode/settings.json
.vscode/sftp.json
/models/
model/
# E2E test artifacts
.e2e-tmp/
/tmp/docsgpt-e2e/
tests/e2e/node_modules/
tests/e2e/playwright-report/
tests/e2e/test-results/
tests/e2e/.e2e-last-run.json

View File

@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT
[*.{md,mdx}]
BasedOnStyles = DocsGPT

156
AGENTS.md Normal file
View File

@@ -0,0 +1,156 @@
# AGENTS.md
- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
- Try to follow red/green TDD
### Check existing dev prerequisites first
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.
> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.
## Normal local development commands
Use these commands once the dev prerequisites above are satisfied.
### Backend
```bash
source .venv/bin/activate # macOS/Linux
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
```
Run the Flask API (if needed):
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
or to match the production runtime exactly — run the ASGI composition under
uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.
Run the Celery worker in a separate terminal (if needed):
```bash
celery -A application.app.celery worker -l INFO
```
On macOS, prefer the solo pool for Celery:
```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```
### Frontend
Install dependencies only when needed, then run the dev server:
```bash
cd frontend
npm install --include=dev
npm run dev
```
### Docs site
```bash
cd docs
npm install
```
### Python / backend changes validation
```bash
ruff check .
python -m pytest
```
### Frontend changes
```bash
cd frontend && npm run lint
cd frontend && npm run build
```
### Documentation changes
```bash
cd docs && npm run build
```
If Vale is installed locally and you edited prose, also run:
```bash
vale .
```
## Repository map
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules
### Backend
- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
### Backend Abstractions
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
### Frontend
- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.
## PR readiness
Before opening a PR:
- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask your user to attach a screenshot or a video to it

View File

@@ -22,6 +22,11 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
- We have a frontend built on React (Vite) and a backend in Python.
> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -125,7 +130,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.

View File

@@ -7,7 +7,7 @@
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
</p>
<div align="center">
@@ -29,13 +29,14 @@
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -158,4 +159,3 @@ The source code license is [MIT](https://opensource.org/license/mit/), as descri
</a>
</p>

View File

@@ -2,13 +2,21 @@
## Supported Versions
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on Github.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
security@arc53.com
Then click **Report a vulnerability**.
Alternatively, email us at: security@arc53.com
We aim to acknowledge reports within 48 hours.
## Incident Handling
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.

View File

@@ -88,5 +88,15 @@ EXPOSE 7091
# Switch to non-root user
USER appuser
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
CMD ["gunicorn", \
"-w", "1", \
"-k", "uvicorn_worker.UvicornWorker", \
"--bind", "0.0.0.0:7091", \
"--timeout", "180", \
"--graceful-timeout", "120", \
"--keep-alive", "5", \
"--worker-tmp-dir", "/dev/shm", \
"--max-requests", "1000", \
"--max-requests-jitter", "100", \
"--config", "application/gunicorn_conf.py", \
"application.asgi:asgi_app"]

View File

@@ -1,7 +1,8 @@
import logging
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -10,7 +11,9 @@ logger = logging.getLogger(__name__)
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"react": ClassicAgent, # backwards compat: react falls back to classic
"agentic": AgenticAgent,
"research": ResearchAgent,
"workflow": WorkflowAgent,
}

View File

@@ -0,0 +1,63 @@
import logging
from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.logging import LogContext
logger = logging.getLogger(__name__)
class AgenticAgent(BaseAgent):
"""Agent where the LLM controls retrieval via tools.
Unlike ClassicAgent which pre-fetches docs into the prompt,
AgenticAgent gives the LLM an internal_search tool so it can
decide when, what, and whether to search.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
self._prepare_tools(tools_dict)
# 4. Build messages (prompt has NO pre-fetched docs)
messages = self._build_messages(self.prompt, query)
# 5. Call LLM — the handler manages the tool loop
llm_response = self._llm_gen(messages, log_context)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# 6. Collect sources from internal search tool results
self._collect_internal_sources()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _collect_internal_sources(self):
"""Collect retrieved docs from the cached InternalSearchTool instance."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
self.retrieved_docs = tool.retrieved_docs

View File

@@ -1,18 +1,16 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
@@ -40,6 +38,10 @@ class BaseAgent(ABC):
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
llm=None,
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -50,22 +52,42 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
if llm_handler is not None:
self.llm_handler = llm_handler
else:
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
# Tool executor — injected or created
if tool_executor is not None:
self.tool_executor = tool_executor
else:
self.tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
)
self.attachments = attachments or []
self.json_schema = None
if json_schema is not None:
@@ -93,327 +115,219 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
def gen_continuation(
self,
messages: List[Dict],
tools_dict: Dict,
pending_tool_calls: List[Dict],
tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
"""Resume generation after tool actions are resolved.
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Processes the client-provided *tool_actions* (approvals, denials,
or client-side results), appends the resulting messages, then
hands back to the LLM to continue the conversation.
Args:
messages: List of message dicts
Returns:
Total token count
messages: The saved messages array from the pause point.
tools_dict: The saved tools dictionary.
pending_tool_calls: The pending tool call descriptors from the pause.
tool_actions: Client-provided actions resolving the pending calls.
"""
self._prepare_tools(tools_dict)
actions_by_id = {a["call_id"]: a for a in tool_actions}
# Build a single assistant message containing all tool calls so
# the message history matches the format LLM providers expect
# (one assistant message with N tool_calls, followed by N tool results).
tc_objects: List[Dict[str, Any]] = []
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
args_str = (
json.dumps(args) if isinstance(args, dict) else (args or "{}")
)
tc_obj: Dict[str, Any] = {
"id": call_id,
"type": "function",
"function": {
"name": pending["name"],
"arguments": args_str,
},
}
if pending.get("thought_signature"):
tc_obj["thought_signature"] = pending["thought_signature"]
tc_objects.append(tc_obj)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tc_objects,
})
# Now process each pending call and append tool result messages
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
if action.get("decision") == "approved":
# Execute the tool server-side
tc = ToolCall(
id=call_id,
name=pending["name"],
arguments=(
json.dumps(args) if isinstance(args, dict) else args
),
)
tool_gen = self._execute_tool_action(tools_dict, tc)
tool_response = None
while True:
try:
event = next(tool_gen)
yield event
except StopIteration as e:
tool_response, _ = e.value
break
messages.append(
self.llm_handler.create_tool_message(tc, tool_response)
)
elif action.get("decision") == "denied":
comment = action.get("comment", "")
denial = (
f"Tool execution denied by user. Reason: {comment}"
if comment
else "Tool execution denied by user."
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, denial)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"status": "denied",
},
}
elif "result" in action:
result = action["result"]
result_str = (
json.dumps(result)
if not isinstance(result, str)
else result
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, result_str)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"result": (
result_str[:50] + "..."
if len(result_str) > 50
else result_str
),
"status": "completed",
},
}
# Resume the LLM loop with the updated messages
llm_response = self._llm_gen(messages)
yield from self._handle_response(
llm_response, tools_dict, messages, None
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
def tool_calls(self) -> List[Dict]:
return self.tool_executor.tool_calls
@tool_calls.setter
def tool_calls(self, value: List[Dict]):
self.tool_executor.tool_calls = value
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
def _get_user_tools(self, user="local"):
return self.tool_executor._get_user_tools(user)
def _build_tool_parameters(self, action):
return self.tool_executor._build_tool_parameters(action)
def _prepare_tools(self, tools_dict):
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
def _execute_tool_action(self, tools_dict, call):
return self.tool_executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
def _get_truncated_tool_calls(self):
return self.tool_executor.get_truncated_tool_calls()
# ---- Context / token management ----
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
@@ -428,44 +342,32 @@ class BaseAgent(ABC):
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
"""
Truncate text by removing content from the middle, preserving start and end.
Args:
text: Text to truncate
max_tokens: Maximum tokens allowed
Returns:
Truncated text with middle removed if needed
"""
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
# Estimate chars per token (roughly 4 chars per token for English)
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin
target_chars = int(max_tokens * chars_per_token * 0.95)
if target_chars <= 0:
return ""
# Split: keep 40% from start, 40% from end, remove middle
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
# ---- Message building ----
def _build_messages(
self,
system_prompt: str,
@@ -475,7 +377,6 @@ class BaseAgent(ABC):
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
@@ -489,23 +390,18 @@ class BaseAgent(ABC):
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
@@ -520,28 +416,35 @@ class BaseAgent(ABC):
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
messages.append({"role": "user", "content": query})
return messages
@@ -550,16 +453,6 @@ class BaseAgent(ABC):
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
@@ -568,7 +461,6 @@ class BaseAgent(ABC):
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
@@ -588,7 +480,7 @@ class BaseAgent(ABC):
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message) # Maintain chronological order
truncated.insert(0, message)
else:
break
@@ -600,13 +492,13 @@ class BaseAgent(ABC):
return truncated
# ---- LLM generation ----
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
# Pre-flight context validation - fail fast if over limit
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (

View File

@@ -15,11 +15,7 @@ class ClassicAgent(BaseAgent):
) -> Generator[Dict, None, None]:
"""Core generator function for ClassicAgent execution flow"""
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
tools_dict = self.tool_executor.get_tools()
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query)

View File

@@ -1,238 +0,0 @@
import logging
import os
from typing import Any, Dict, Generator, List
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
logger = logging.getLogger(__name__)
MAX_ITERATIONS_REASONING = 10
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
PLANNING_PROMPT_TEMPLATE = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
) as f:
FINAL_PROMPT_TEMPLATE = f.read()
class ReActAgent(BaseAgent):
"""
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
Implements a think-act-observe loop for complex problem-solving:
1. Creates a strategic plan based on the query
2. Executes tools and gathers observations
3. Iteratively refines approach until satisfied
4. Synthesizes final answer from all observations
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
self._reset_state()
tools_dict = (
self._get_tools(self.user_api_key)
if self.user_api_key
else self._get_user_tools(self.user)
)
self._prepare_tools(tools_dict)
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
yield from self._planning_phase(query, log_context)
if not self.plan:
logger.warning(
f"ReActAgent: No plan generated in iteration {iteration}"
)
break
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
if satisfied:
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
break
yield from self._synthesis_phase(query, log_context)
def _reset_state(self):
"""Reset agent state for new query"""
self.plan = ""
self.observations = []
def _planning_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Generate strategic plan for query"""
logger.info("ReActAgent: Creating plan...")
plan_prompt = self._build_planning_prompt(query)
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
if log_context:
log_context.stacks.append(
{"component": "planning_llm", "data": build_stack_data(self.llm)}
)
plan_parts = []
for chunk in plan_stream:
content = self._extract_content(chunk)
if content:
plan_parts.append(content)
yield {"thought": content}
self.plan = "".join(plan_parts)
def _execution_phase(
self, query: str, tools_dict: Dict, log_context: LogContext
) -> Generator[bool, None, None]:
"""Execute plan with tool calls and observations"""
execution_prompt = self._build_execution_prompt(query)
messages = self._build_messages(execution_prompt, query)
llm_response = self._llm_gen(messages, log_context)
initial_content = self._extract_content(llm_response)
if initial_content:
self.observations.append(f"Initial response: {initial_content}")
processed_response = self._llm_handler(
llm_response, tools_dict, messages, log_context
)
for tool_call in self.tool_calls:
observation = (
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
f"with args {tool_call.get('arguments', {})}. "
f"Result: {str(tool_call.get('result', ''))[:200]}"
)
self.observations.append(observation)
final_content = self._extract_content(processed_response)
if final_content:
self.observations.append(f"Response after tools: {final_content}")
if log_context:
log_context.stacks.append(
{
"component": "agent_tool_calls",
"data": {"tool_calls": self.tool_calls.copy()},
}
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
return "SATISFIED" in (final_content or "")
def _synthesis_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Synthesize final answer from all observations"""
logger.info("ReActAgent: Generating final answer...")
final_prompt = self._build_final_answer_prompt(query)
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
log_context.stacks.append(
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
)
for chunk in final_stream:
content = self._extract_content(chunk)
if content:
yield {"answer": content}
def _build_planning_prompt(self, query: str) -> str:
"""Build planning phase prompt"""
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
prompt = prompt.replace("{prompt}", self.prompt or "")
prompt = prompt.replace("{summaries}", "")
prompt = prompt.replace("{observations}", "\n".join(self.observations))
return prompt
def _build_execution_prompt(self, query: str) -> str:
"""Build execution phase prompt with plan and observations"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 20000:
observations_str = observations_str[:20000] + "\n...[truncated]"
return (
f"{self.prompt or ''}\n\n"
f"Follow this plan:\n{self.plan}\n\n"
f"Observations:\n{observations_str}\n\n"
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
f"Otherwise, continue executing the plan."
)
def _build_final_answer_prompt(self, query: str) -> str:
"""Build final synthesis prompt"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 10000:
observations_str = observations_str[:10000] + "\n...[truncated]"
logger.warning("ReActAgent: Observations truncated for final answer")
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
def _extract_content(self, response: Any) -> str:
"""Extract text content from various LLM response formats"""
if not response:
return ""
collected = []
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
if response.message.content:
return response.message.content
if hasattr(response, "choices") and response.choices:
if hasattr(response.choices[0], "message"):
content = response.choices[0].message.content
if content:
return content
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text
try:
for chunk in response:
content_piece = ""
if hasattr(chunk, "choices") and chunk.choices:
if hasattr(chunk.choices[0], "delta"):
delta_content = chunk.choices[0].delta.content
if delta_content:
content_piece = delta_content
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content_piece = chunk.delta.text
elif isinstance(chunk, str):
content_piece = chunk
if content_piece:
collected.append(content_piece)
except (TypeError, AttributeError):
logger.debug(
f"Response not iterable or unexpected format: {type(response)}"
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)

View File

@@ -0,0 +1,698 @@
import json
import logging
import os
import time
from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
logger = logging.getLogger(__name__)
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6
DEFAULT_MAX_SUB_ITERATIONS = 5
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000
DEFAULT_PARALLEL_WORKERS = 3
# Adaptive depth caps per complexity level
COMPLEXITY_CAPS = {
"simple": 2,
"moderate": 4,
"complex": 6,
}
_PROMPTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"prompts",
"research",
)
def _load_prompt(name: str) -> str:
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
return f.read()
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
class CitationManager:
"""Tracks and deduplicates citations across research steps."""
def __init__(self):
self.citations: Dict[int, Dict] = {}
self._counter = 0
def add(self, doc: Dict) -> int:
"""Register a source, return its citation number. Deduplicates by source."""
source = doc.get("source", "")
title = doc.get("title", "")
for num, existing in self.citations.items():
if existing.get("source") == source and existing.get("title") == title:
return num
self._counter += 1
self.citations[self._counter] = doc
return self._counter
def add_docs(self, docs: List[Dict]) -> str:
"""Register multiple docs, return formatted citation mapping text."""
mapping_lines = []
for doc in docs:
num = self.add(doc)
title = doc.get("title", "Untitled")
mapping_lines.append(f"[{num}] {title}")
return "\n".join(mapping_lines)
def format_references(self) -> str:
"""Generate [N] -> source mapping for report footer."""
if not self.citations:
return "No sources found."
lines = []
for num, doc in sorted(self.citations.items()):
title = doc.get("title", "Untitled")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
display = filename or title
lines.append(f"[{num}] {display}{source}")
return "\n".join(lines)
def get_all_docs(self) -> List[Dict]:
return list(self.citations.values())
# ---------------------------------------------------------------------------
# ResearchAgent
# ---------------------------------------------------------------------------
class ResearchAgent(BaseAgent):
"""Multi-step research agent with parallel execution and budget controls.
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
max_steps: int = DEFAULT_MAX_STEPS,
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
token_budget: int = DEFAULT_TOKEN_BUDGET,
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
self.max_steps = max_steps
self.max_sub_iterations = max_sub_iterations
self.timeout_seconds = timeout_seconds
self.token_budget = token_budget
self.parallel_workers = parallel_workers
self.citations = CitationManager()
self._start_time: float = 0
self._tokens_used: int = 0
self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
# ------------------------------------------------------------------
def _is_timed_out(self) -> bool:
return (time.monotonic() - self._start_time) >= self.timeout_seconds
def _elapsed(self) -> float:
return round(time.monotonic() - self._start_time, 1)
def _track_tokens(self, count: int):
self._tokens_used += count
def _budget_remaining(self) -> int:
return max(self.token_budget - self._tokens_used, 0)
def _is_over_budget(self) -> bool:
return self._tokens_used >= self.token_budget
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
# ------------------------------------------------------------------
# Main orchestration
# ------------------------------------------------------------------
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
self._start_time = time.monotonic()
tools_dict = self._setup_tools()
# Phase 0: Clarification (skip if user is responding to a prior clarification)
if not self._is_follow_up():
clarification = self._clarification_phase(query)
if clarification:
yield {"metadata": {"is_clarification": True}}
yield {"answer": clarification}
yield {"sources": []}
yield {"tool_calls": []}
log_context.stacks.append(
{"component": "agent", "data": {"clarification": True}}
)
return
# Phase 1: Planning (with adaptive depth)
yield {"type": "research_progress", "data": {"status": "planning"}}
plan, complexity = self._planning_phase(query)
if not plan:
logger.warning("ResearchAgent: Planning produced no steps, falling back")
plan = [{"query": query, "rationale": "Direct investigation"}]
complexity = "simple"
yield {
"type": "research_plan",
"data": {"steps": plan, "complexity": complexity},
}
# Phase 2: Research each step (yields progress events in real-time)
intermediate_reports = []
for i, step in enumerate(plan):
step_num = i + 1
step_query = step.get("query", query)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
f"({self._elapsed()}s)"
)
break
if self._is_over_budget():
logger.warning(
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
)
break
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "researching",
},
}
report = self._research_step(step_query, tools_dict)
intermediate_reports.append({"step": step, "content": report})
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "complete",
},
}
# Phase 3: Synthesis (streaming)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
f"synthesizing with {len(intermediate_reports)} reports"
)
yield {
"type": "research_progress",
"data": {
"status": "synthesizing",
"elapsed_seconds": self._elapsed(),
"tokens_used": self._tokens_used,
},
}
yield from self._synthesis_phase(
query, plan, intermediate_reports, tools_dict, log_context
)
# Sources and tool calls
self.retrieved_docs = self.citations.get_all_docs()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
logger.info(
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
)
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
# ------------------------------------------------------------------
# Tool setup
# ------------------------------------------------------------------
def _setup_tools(self) -> Dict:
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
tools_dict[THINK_TOOL_ID] = think_entry
self._prepare_tools(tools_dict)
return tools_dict
# ------------------------------------------------------------------
# Phase 0: Clarification
# ------------------------------------------------------------------
def _is_follow_up(self) -> bool:
"""Check if the user is responding to a prior clarification.
Uses the metadata flag stored in the conversation DB — no string matching.
Only skip clarification when the last query was explicitly flagged
as a clarification by this agent.
"""
if not self.chat_history:
return False
last = self.chat_history[-1]
meta = last.get("metadata", {})
return bool(meta.get("is_clarification"))
def _clarification_phase(self, question: str) -> Optional[str]:
"""Ask the LLM whether the question needs clarification.
Returns formatted clarification text if needed, or None to proceed.
Uses response_format to force valid JSON output.
"""
messages = [
{"role": "system", "content": CLARIFICATION_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent clarification response: {text[:300]}")
data = self._parse_clarification_json(text)
if not data or not data.get("needs_clarification"):
return None
questions = data.get("questions", [])
if not questions:
return None
# Format as a friendly response
lines = [
"Before I begin researching, I'd like to clarify a few things:\n"
]
for i, q in enumerate(questions[:3], 1):
lines.append(f"{i}. {q}")
lines.append(
"\nPlease provide these details and I'll start the research."
)
return "\n".join(lines)
except Exception as e:
logger.error(f"Clarification phase failed: {e}", exc_info=True)
return None # proceed with research on failure
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
"""Parse clarification JSON from LLM response."""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting from code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
return json.loads(text[start:end].strip())
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
return json.loads(text[i : j + 1])
except json.JSONDecodeError:
continue
break
return None
# ------------------------------------------------------------------
# Phase 1: Planning (with adaptive depth)
# ------------------------------------------------------------------
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
"""Decompose the question into research steps via LLM.
Returns (steps, complexity) where complexity is simple/moderate/complex.
"""
messages = [
{"role": "system", "content": PLANNING_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
plan_data = self._parse_plan_json(text)
if isinstance(plan_data, dict):
complexity = plan_data.get("complexity", "moderate")
steps = plan_data.get("steps", [])
else:
complexity = "moderate"
steps = plan_data
# Adaptive depth: cap steps based on assessed complexity
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
cap = min(cap, self.max_steps)
steps = steps[:cap]
logger.info(
f"ResearchAgent plan: complexity={complexity}, "
f"steps={len(steps)} (cap={cap})"
)
return steps, complexity
except Exception as e:
logger.error(f"Planning phase failed: {e}", exc_info=True)
return (
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
"simple",
)
def _parse_plan_json(self, text: str):
"""Extract JSON plan from LLM response. Returns dict or list."""
# Try direct parse
try:
data = json.loads(text)
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
# Try extracting from markdown code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
data = json.loads(text[start:end].strip())
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object in text
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
data = json.loads(text[i : j + 1])
if isinstance(data, dict) and "steps" in data:
return data
except json.JSONDecodeError:
continue
break
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
return []
# ------------------------------------------------------------------
# Phase 2: Research step (core loop)
# ------------------------------------------------------------------
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
"""Run a focused research loop for one sub-question (sequential path)."""
report = self._research_step_with_executor(
step_query, tools_dict, self.tool_executor
)
self._collect_step_sources()
return report
def _research_step_with_executor(
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
) -> str:
"""Core research loop. Works with any ToolExecutor instance."""
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": step_query},
]
last_search_empty = False
for iteration in range(self.max_sub_iterations):
# Check timeout and budget
if self._is_timed_out():
logger.info(
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
)
break
if self._is_over_budget():
logger.info(
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
)
break
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
self._track_tokens(self._snapshot_llm_tokens())
except Exception as e:
logger.error(
f"Research step LLM call failed (iteration {iteration}): {e}",
exc_info=True,
)
break
parsed = self.llm_handler.parse_response(response)
if not parsed.requires_tool_call:
return parsed.content or "No findings for this step."
# Execute tool calls
messages, last_search_empty = self._execute_step_tools_with_refinement(
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
)
# Max iterations / timeout / budget — ask for summary
messages.append(
{
"role": "user",
"content": "Please summarize your findings so far based on the information gathered.",
}
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
return text or "Research step completed."
except Exception:
return "Research step completed."
def _execute_step_tools_with_refinement(
self,
tool_calls,
tools_dict: Dict,
messages: List[Dict],
executor: ToolExecutor,
last_search_empty: bool,
) -> tuple[List[Dict], bool]:
"""Execute tool calls with query refinement on empty results.
Returns (updated_messages, was_last_search_empty).
"""
search_returned_empty = False
for call in tool_calls:
gen = executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
result = None
call_id = None
while True:
try:
event = next(gen)
# Log tool_call status events instead of discarding them
if isinstance(event, dict) and event.get("type") == "tool_call":
logger.debug(
"Tool %s status: %s",
event.get("data", {}).get("action_name", ""),
event.get("data", {}).get("status", ""),
)
except StopIteration as e:
result, call_id = e.value
break
# Detect empty search results for refinement
is_search = "search" in (call.name or "").lower()
result_str = str(result) if result else ""
if is_search and "No documents found" in result_str:
search_returned_empty = True
if last_search_empty:
# Two consecutive empty searches — inject refinement hint
result_str += (
"\n\nHint: Previous search also returned no results. "
"Try a very different query with different keywords, "
"or broaden your search terms."
)
result = result_str
import json as _json
args_str = (
_json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {"name": call.name, "arguments": args_str},
}],
})
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)
return messages, search_returned_empty
def _collect_step_sources(self):
"""Collect sources from InternalSearchTool and register with CitationManager."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs"):
for doc in tool.retrieved_docs:
self.citations.add(doc)
# ------------------------------------------------------------------
# Phase 3: Synthesis
# ------------------------------------------------------------------
def _synthesis_phase(
self,
question: str,
plan: List[Dict],
intermediate_reports: List[Dict],
tools_dict: Dict,
log_context: LogContext,
) -> Generator[Dict, None, None]:
"""Compile all findings into a final cited report (streaming)."""
plan_lines = []
for i, step in enumerate(plan, 1):
plan_lines.append(
f"{i}. {step.get('query', 'Unknown')}{step.get('rationale', '')}"
)
plan_summary = "\n".join(plan_lines)
findings_parts = []
for i, report in enumerate(intermediate_reports, 1):
step_query = report["step"].get("query", "Unknown")
content = report["content"]
findings_parts.append(
f"--- Step {i}: {step_query} ---\n{content}"
)
findings = "\n\n".join(findings_parts)
references = self.citations.format_references()
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
synthesis_prompt = synthesis_prompt.replace("{references}", references)
messages = [
{"role": "system", "content": synthesis_prompt},
{"role": "user", "content": f"Please write the research report for: {question}"},
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
from application.logging import build_stack_data
log_context.stacks.append(
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _extract_text(self, response) -> str:
"""Extract text content from a non-streaming LLM response."""
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
return response.message.content or ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text or ""
return str(response) if response else ""

View File

@@ -0,0 +1,494 @@
import logging
import uuid
from collections import Counter
from typing import Dict, List, Optional, Tuple
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.security.encryption import decrypt_credentials
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly
logger = logging.getLogger(__name__)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
Extracted from BaseAgent to separate concerns and enable tool caching.
"""
def __init__(
self,
user_api_key: Optional[str] = None,
user: Optional[str] = None,
decoded_token: Optional[Dict] = None,
):
self.user_api_key = user_api_key
self.user = user
self.decoded_token = decoded_token
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context.
If *client_tools* have been set on this executor, they are
automatically merged into the returned dict.
"""
if self.user_api_key:
tools = self._get_tools_by_api_key(self.user_api_key)
else:
tools = self._get_user_tools(self.user or "local")
if self.client_tools:
self.merge_client_tools(tools, self.client_tools)
return tools
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
# Per-operation session: the answer pipeline spans a long-lived
# generator; wrapping it in a single connection would pin a PG
# conn for the whole stream. Open, fetch, close.
with db_readonly() as conn:
agent_data = AgentsRepository(conn).find_by_key(api_key)
tool_ids = agent_data.get("tools", []) if agent_data else []
if not tool_ids:
return {}
tools_repo = UserToolsRepository(conn)
tools: List[Dict] = []
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
for tid in tool_ids:
row = None
if owner:
row = tools_repo.get_any(str(tid), owner)
if row is not None:
tools.append(row)
return {str(tool["id"]): tool for tool in tools} if tools else {}
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
with db_readonly() as conn:
user_tools = UserToolsRepository(conn).list_active_for_user(user)
return {str(i): tool for i, tool in enumerate(user_tools)}
def merge_client_tools(
self, tools_dict: Dict, client_tools: List[Dict]
) -> Dict:
"""Merge client-provided tool definitions into tools_dict.
Client tools use the standard function-calling format::
[{"type": "function", "function": {"name": "get_weather",
"description": "...", "parameters": {...}}}]
They are stored in *tools_dict* with ``client_side: True`` so that
:meth:`check_pause` returns a pause signal instead of trying to
execute them server-side.
Args:
tools_dict: The mutable server tools dict (will be modified in place).
client_tools: List of tool definitions in function-calling format.
Returns:
The updated *tools_dict* (same reference, for convenience).
"""
for i, ct in enumerate(client_tools):
func = ct.get("function", ct) # tolerate bare {"name":..} too
name = func.get("name", f"clienttool{i}")
tool_id = f"ct{i}"
tools_dict[tool_id] = {
"name": name,
"client_side": True,
"actions": [
{
"name": name,
"description": func.get("description", ""),
"active": True,
"parameters": func.get("parameters", {}),
}
],
}
return tools_dict
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas.
Action names are kept clean for the LLM:
- Unique action names appear as-is (e.g. ``get_weather``).
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
``search_2``).
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
can be routed back to the correct ``(tool_id, action_name)`` without
brittle string splitting.
"""
# Pass 1: collect entries and count action name occurrences
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
name_counts: Counter = Counter()
for tool_id, tool in tools_dict.items():
is_api = tool["name"] == "api_tool"
is_client = tool.get("client_side", False)
if is_api and "actions" not in tool.get("config", {}):
continue
if not is_api and "actions" not in tool:
continue
actions = (
tool["config"]["actions"].values()
if is_api
else tool["actions"]
)
for action in actions:
if not action.get("active", True):
continue
entries.append((tool_id, action["name"], action, is_client))
name_counts[action["name"]] += 1
# Pass 2: assign LLM-visible names and build mappings
self._name_to_tool = {}
self._tool_to_name = {}
collision_counters: Dict[str, int] = {}
all_llm_names: set = set()
result = []
for tool_id, action_name, action, is_client in entries:
if name_counts[action_name] == 1:
llm_name = action_name
else:
counter = collision_counters.get(action_name, 1)
candidate = f"{action_name}_{counter}"
# Skip if candidate collides with a unique action name
while candidate in all_llm_names or (
candidate in name_counts and name_counts[candidate] == 1
):
counter += 1
candidate = f"{action_name}_{counter}"
collision_counters[action_name] = counter + 1
llm_name = candidate
all_llm_names.add(llm_name)
self._name_to_tool[llm_name] = (tool_id, action_name)
self._tool_to_name[(tool_id, action_name)] = llm_name
if is_client:
params = action.get("parameters", {})
else:
params = self._build_tool_parameters(action)
result.append({
"type": "function",
"function": {
"name": llm_name,
"description": action.get("description", ""),
"parameters": params,
},
})
return result
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def check_pause(
self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
"""Check if a tool call requires pausing for approval or client execution.
Returns a dict describing the pending action if pause is needed, None otherwise.
"""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
llm_name = getattr(call, "name", "")
if tool_id is None or action_name is None or tool_id not in tools_dict:
return None # Will be handled as error by execute()
tool_data = tools_dict[tool_id]
# Client-side tools
if tool_data.get("client_side"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "requires_client_execution",
"thought_signature": getattr(call, "thought_signature", None),
}
# Approval required
if tool_data["name"] == "api_tool":
action_data = tool_data.get("config", {}).get("actions", {}).get(
action_name, {}
)
else:
action_data = next(
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
{},
)
if action_data.get("require_approval"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "awaiting_approval",
"thought_signature": getattr(call, "thought_signature", None),
}
return None
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
llm_name = getattr(call, "name", "unknown")
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
if tool is None:
error_message = (
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(error_message)
tool_call_data["result"] = error_message
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return error_message, call_id
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
return self._loaded_tools[cache_key]
tm = ToolManager(config={})
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
if tool_config.get("encrypted_credentials") and self.user:
decrypted = decrypt_credentials(
tool_config["encrypted_credentials"], self.user
)
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
row_id = tool_data.get("id")
if not row_id:
logger.error(
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
"skipping load to avoid binding a non-UUID downstream.",
tool_data.get("name"),
tool_id,
)
return None
tool_config["tool_id"] = str(row_id)
if self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
# Don't cache api_tool since config varies by action
if tool_data["name"] != "api_tool":
self._loaded_tools[cache_key] = tool
return tool
def get_truncated_tool_calls(self) -> List[Dict]:
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]

View File

@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
class Tool(ABC):
internal: bool = False
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -41,7 +46,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
logger.debug("Performing Brave web search for: %s", query)
url = f"{self.base_url}/web/search"
@@ -68,7 +73,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)
if response.status_code == 200:
return {
@@ -94,7 +99,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
logger.debug("Performing Brave image search for: %s", query)
url = f"{self.base_url}/images/search"
@@ -113,7 +118,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)
if response.status_code == 200:
return {
@@ -177,6 +182,10 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
response = requests.get(url, timeout=100)
if response.status_code == 200:
data = response.json()
if currency.upper() in data:

View File

@@ -1,5 +1,14 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
class DuckDuckGoSearchTool(Tool):
@@ -10,71 +19,123 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
try:
results = DDGS().text(
def operation():
client = self._get_ddgs_client()
return client.text(
query,
max_results=max_results,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Web search")
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"parameters": {
"type": "object",
"properties": {
@@ -84,7 +145,15 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
},
},
"required": ["query"],
@@ -92,17 +161,43 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
"description": "Image search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
},
},
"required": ["query"],

View File

@@ -0,0 +1,458 @@
import json
import logging
from typing import Dict, List, Optional
from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
logger = logging.getLogger(__name__)
class InternalSearchTool(Tool):
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
Instead of pre-fetching docs into the prompt, the LLM decides
when and what to search. Supports multiple searches per session.
Optional capabilities (enabled when sources have directory_structure):
- path_filter on search: restrict results to a specific file/folder
- list_files action: browse the file/folder structure
"""
internal = True
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []
self._retriever = None
self._directory_structure: Optional[Dict] = None
self._dir_structure_loaded = False
def _get_retriever(self):
if self._retriever is None:
self._retriever = RetrieverCreator.create_retriever(
self.config.get("retriever_name", "classic"),
source=self.config.get("source", {}),
chat_history=[],
prompt="",
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
api_key=self.config.get("api_key", settings.API_KEY),
decoded_token=self.config.get("decoded_token"),
)
return self._retriever
def _get_directory_structure(self) -> Optional[Dict]:
"""Load directory structure from Postgres for the configured sources."""
if self._dir_structure_loaded:
return self._directory_structure
self._dir_structure_loaded = True
source = self.config.get("source", {})
active_docs = source.get("active_docs", [])
if not active_docs:
return None
try:
# Per-operation session: this tool runs inside the answer
# generator hot path, so we open a short-lived read
# connection for the batch lookup and release immediately.
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.session import db_readonly
if isinstance(active_docs, str):
active_docs = [active_docs]
decoded_token = self.config.get("decoded_token") or {}
user_id = decoded_token.get("sub") if decoded_token else None
merged_structure = {}
with db_readonly() as conn:
repo = SourcesRepository(conn)
for doc_id in active_docs:
try:
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
self._directory_structure = merged_structure if merged_structure else None
except Exception as e:
logger.debug(f"Failed to load directory structures: {e}")
return self._directory_structure
def execute_action(self, action_name: str, **kwargs):
if action_name == "search":
return self._execute_search(**kwargs)
elif action_name == "list_files":
return self._execute_list_files(**kwargs)
return f"Unknown action: {action_name}"
def _execute_search(self, **kwargs) -> str:
query = kwargs.get("query", "")
path_filter = kwargs.get("path_filter", "")
if not query:
return "Error: 'query' parameter is required."
try:
retriever = self._get_retriever()
docs = retriever.search(query)
except Exception as e:
logger.error(f"Internal search failed: {e}", exc_info=True)
return "Search failed: an internal error occurred."
if not docs:
return "No documents found matching your query."
# Apply path filter if specified
if path_filter:
path_lower = path_filter.lower()
docs = [
d
for d in docs
if path_lower in d.get("source", "").lower()
or path_lower in d.get("filename", "").lower()
or path_lower in d.get("title", "").lower()
]
if not docs:
return f"No documents found matching query '{query}' in path '{path_filter}'."
# Accumulate for source tracking
for doc in docs:
if doc not in self.retrieved_docs:
self.retrieved_docs.append(doc)
# Format results for the LLM
formatted = []
for i, doc in enumerate(docs, 1):
title = doc.get("title", "Untitled")
text = doc.get("text", "")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
header = filename or title
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
return "\n\n---\n\n".join(formatted)
def _execute_list_files(self, **kwargs) -> str:
path = kwargs.get("path", "")
dir_structure = self._get_directory_structure()
if not dir_structure:
return "No file structure available for the current sources."
# Navigate to the requested path
current = dir_structure
if path:
for part in path.strip("/").split("/"):
if not part:
continue
if isinstance(current, dict) and part in current:
current = current[part]
else:
return f"Path '{path}' not found in the file structure."
# Format the structure for the LLM
return self._format_structure(current, path or "/")
def _format_structure(self, node: Dict, current_path: str) -> str:
if not isinstance(node, dict):
return f"'{current_path}' is a file, not a directory."
lines = [f"File structure at '{current_path}':\n"]
folders = []
files = []
for name, value in sorted(node.items()):
if isinstance(value, dict):
# Check if it's a file metadata dict or a folder
if "type" in value or "size_bytes" in value or "token_count" in value:
# It's a file with metadata
size = value.get("token_count", "")
ftype = value.get("type", "")
info_parts = []
if ftype:
info_parts.append(ftype)
if size:
info_parts.append(f"{size} tokens")
info = f" ({', '.join(info_parts)})" if info_parts else ""
files.append(f" {name}{info}")
else:
# It's a folder
count = self._count_files(value)
folders.append(f" {name}/ ({count} items)")
else:
files.append(f" {name}")
if folders:
lines.append("Folders:")
lines.extend(folders)
if files:
lines.append("Files:")
lines.extend(files)
if not folders and not files:
lines.append(" (empty)")
return "\n".join(lines)
def _count_files(self, node: Dict) -> int:
count = 0
for value in node.values():
if isinstance(value, dict):
if "type" in value or "size_bytes" in value or "token_count" in value:
count += 1
else:
count += self._count_files(value)
else:
count += 1
return count
def get_actions_metadata(self):
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
},
}
},
}
]
# Add path_filter and list_files only if directory structure exists
has_structure = self.config.get("has_directory_structure", False)
if has_structure:
actions[0]["parameters"]["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return actions
def get_config_requirements(self):
return {}
# Constants for building synthetic tools_dict entries
INTERNAL_TOOL_ID = "internal"
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
"""Build the tools_dict entry for InternalSearchTool.
Dynamically includes list_files and path_filter based on
whether the sources have directory structure.
"""
search_params = {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
}
}
}
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"active": True,
"parameters": search_params,
}
]
if has_directory_structure:
search_params["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"active": True,
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return {"name": "internal_search", "actions": actions}
# Keep backward compat
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have a ``directory_structure`` row."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
# scoping, but callers in the agent build path don't always
# thread the decoded token through here. Use a direct
# short-lived SQL lookup instead of the repo until the call
# sites are updated to propagate user context.
from sqlalchemy import text as _text
from application.storage.db.session import db_readonly
if isinstance(active_docs, str):
active_docs = [active_docs]
with db_readonly() as conn:
for doc_id in active_docs:
try:
value = str(doc_id)
if len(value) == 36 and "-" in value:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE id = CAST(:id AS uuid)"
),
{"id": value},
).fetchone()
else:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE legacy_mongo_id = :lid"
),
{"lid": value},
).fetchone()
if row is not None and row[0]:
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
"""Add the internal search tool to tools_dict if sources are configured.
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
Mutates tools_dict in place.
"""
source = retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if not retriever_config or not has_sources:
return
has_dir = sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
internal_entry["config"] = build_internal_tool_config(
**retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
def build_internal_tool_config(
source: Dict,
retriever_name: str = "classic",
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
api_key: str = None,
decoded_token: Optional[Dict] = None,
has_directory_structure: bool = False,
) -> Dict:
"""Build the config dict for InternalSearchTool."""
return {
"source": source,
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,
"api_key": api_key or settings.API_KEY,
"decoded_token": decoded_token,
"has_directory_structure": has_directory_structure,
}

View File

@@ -1,20 +1,12 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
@@ -24,12 +16,17 @@ from fastmcp.client.transports import (
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
_mcp_clients_cache = {}
@@ -56,11 +53,13 @@ class MCPTool(Tool):
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
raw_url = config.get("server_url", "")
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
@@ -76,23 +75,53 @@ class MCPTool(Tool):
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
self.query_mode = config.get("query_mode", False)
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@staticmethod
def _validate_server_url(server_url: str) -> str:
"""Validate server_url to prevent SSRF to internal networks.
Raises:
ValueError: If the URL points to a private/internal address.
"""
try:
return validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
if explicit:
return explicit.rstrip("/")
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
if connector_base:
parsed = urlparse(connector_base)
if parsed.scheme and parsed.netloc:
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
auth_key = (
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
@@ -109,11 +138,10 @@ class MCPTool(Tool):
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
if time.time() - cached_data["created_at"] < 300:
self._client = cached_data["client"]
return
else:
@@ -123,15 +151,23 @@ class MCPTool(Tool):
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
if self.query_mode:
auth = NonInteractiveOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
user_id=self.user_id,
)
else:
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
@@ -233,38 +269,53 @@ class MCPTool(Tool):
else:
raise Exception(f"Unknown operation: {operation}")
_ERROR_MAP = [
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
(ConnectionRefusedError, lambda *_: "Connection refused"),
]
_ERROR_PATTERNS = {
("403", "Forbidden"): "Access denied (403 Forbidden)",
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
("ECONNREFUSED",): "Connection refused",
("SSL", "certificate"): "SSL/TLS error",
}
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
future = executor.submit(
self._run_in_new_loop, operation, *args, **kwargs
)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
return self._run_in_new_loop(operation, *args, **kwargs)
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
raise self._map_error(operation, e) from e
raise self._map_error(operation, e) from e
def _run_in_new_loop(self, operation, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
def _map_error(self, operation: str, exc: Exception) -> Exception:
for exc_type, msg_fn in self._ERROR_MAP:
if isinstance(exc, exc_type):
return Exception(msg_fn(operation, self.timeout, exc))
error_msg = str(exc)
for patterns, friendly in self._ERROR_PATTERNS.items():
if any(p.lower() in error_msg.lower() for p in patterns):
return Exception(friendly)
logger.error("MCP %s failed: %s", operation, exc)
return exc
def discover_tools(self) -> List[Dict]:
"""
@@ -285,16 +336,6 @@ class MCPTool(Tool):
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
@@ -310,7 +351,37 @@ class MCPTool(Tool):
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
error_msg = str(e)
lower_msg = error_msg.lower()
is_auth_error = (
"401" in error_msg
or "unauthorized" in lower_msg
or "session expired" in lower_msg
or "re-authorize" in lower_msg
)
if is_auth_error:
if self.auth_type == "oauth":
raise Exception(
f"Action '{action_name}' failed: OAuth session expired. "
"Please re-authorize this MCP server in tool settings."
) from e
global _mcp_clients_cache
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as retry_e:
raise Exception(
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
"Your credentials may have expired — please re-authorize in tool settings."
) from retry_e
raise Exception(
f"Failed to execute action '{action_name}': {error_msg}"
) from e
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
@@ -333,23 +404,35 @@ class MCPTool(Tool):
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"message": "No server URL configured",
"tools_count": 0,
}
try:
parsed = urlparse(self.server_url)
if parsed.scheme not in ("http", "https"):
return {
"success": False,
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
"tools_count": 0,
}
except Exception:
return {
"success": False,
"message": "Invalid URL format",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
self._setup_client()
except Exception as e:
return {
"success": False,
"message": f"Client init failed: {str(e)}",
"tools_count": 0,
}
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
@@ -360,56 +443,94 @@ class MCPTool(Tool):
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
ping_ok = False
ping_error = None
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
ping_ok = True
except Exception as e:
ping_error = str(e)
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
tools = self.discover_tools()
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"message": f"Connection failed: {ping_error or str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
if not tools and not ping_ok:
return {
"success": False,
"message": f"Connection failed: {ping_error or 'No tools found'}",
"tools_count": 0,
}
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": tool.get("name", "unknown"),
"description": tool.get("description", ""),
}
for tool in tools
],
}
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
server_url=self.server_url, user_id=self.user_id,
)
loop = asyncio.new_event_loop()
try:
tokens = loop.run_until_complete(storage.get_tokens())
finally:
loop.close()
if tokens and tokens.access_token:
self.query_mode = True
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
tools = self.discover_tools()
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in tools
],
}
except Exception as e:
logger.warning("OAuth token validation failed: %s", e)
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
return self._start_oauth_task()
def _start_oauth_task(self) -> Dict:
task_config = self.config.copy()
task_config.pop("query_mode", None)
result = mcp_oauth_task.delay(task_config, self.user_id)
return {
"success": False,
"requires_oauth": True,
"task_id": result.id,
"message": "OAuth authorization required.",
"tools_count": 0,
}
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -455,110 +576,88 @@ class MCPTool(Tool):
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
transport_enum = ["auto", "sse", "http"]
transport_help = {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
}
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"label": "Server URL",
"description": "URL of the remote MCP server",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": transport_enum,
"default": "auto",
"required": False,
"help": {
**transport_help,
},
"secret": False,
"order": 1,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"label": "Authentication Type",
"description": "Authentication method for the MCP server",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
"secret": False,
"order": 2,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"api_key": {
"type": "string",
"label": "API Key",
"description": "API key for authentication",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
"secret": True,
"order": 3,
"depends_on": {"auth_type": "api_key"},
},
"api_key_header": {
"type": "string",
"label": "API Key Header",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
"required": False,
"secret": False,
"order": 4,
"depends_on": {"auth_type": "api_key"},
},
"bearer_token": {
"type": "string",
"label": "Bearer Token",
"description": "Bearer token for authentication",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "bearer"},
},
"username": {
"type": "string",
"label": "Username",
"description": "Username for basic authentication",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "basic"},
},
"password": {
"type": "string",
"label": "Password",
"description": "Password for basic authentication",
"required": False,
"secret": True,
"order": 4,
"depends_on": {"auth_type": "basic"},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"label": "OAuth Scopes",
"description": "Comma-separated OAuth scopes to request",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "oauth"},
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"type": "number",
"label": "Timeout (seconds)",
"description": "Request timeout in seconds (1-300)",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
"secret": False,
"order": 10,
},
}
@@ -578,31 +677,14 @@ class DocsGPTOAuth(OAuthClientProvider):
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -619,7 +701,9 @@ class DocsGPTOAuth(OAuthClientProvider):
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
server_url=self.server_base_url,
user_id=self.user_id,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
)
super().__init__(
@@ -651,22 +735,20 @@ class DocsGPTOAuth(OAuthClientProvider):
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
logger.info("Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -686,7 +768,7 @@ class DocsGPTOAuth(OAuthClientProvider):
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -711,7 +793,7 @@ class DocsGPTOAuth(OAuthClientProvider):
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -731,67 +813,149 @@ class DocsGPTOAuth(OAuthClientProvider):
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
raise Exception("OAuth timeout: no code received within 5 minutes")
class NonInteractiveOAuth(DocsGPTOAuth):
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
Used during query execution to prevent the streaming response from blocking
while waiting for user authorization that will never come.
"""
def __init__(self, **kwargs):
kwargs.setdefault("task_id", None)
kwargs["skip_redirect_validation"] = True
super().__init__(**kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
async def callback_handler(self) -> tuple[str, str | None]:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
def __init__(
self,
server_url: str,
user_id: str,
expected_redirect_uri: Optional[str] = None,
):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
self.expected_redirect_uri = expected_redirect_uri
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
def _pg_provider(self) -> str:
return f"mcp:{self.get_base_url(self.server_url)}"
def _fetch_session_data(self) -> dict:
"""Read the JSONB ``session_data`` blob for this MCP server row."""
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_readonly
base_url = self.get_base_url(self.server_url)
with db_readonly() as conn:
row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
self.user_id, base_url,
)
if not row:
return {}
data = row.get("session_data") or {}
if isinstance(data, str):
try:
data = json.loads(data)
except ValueError:
return {}
return data if isinstance(data, dict) else {}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
data = await asyncio.to_thread(self._fetch_session_data)
if not data or "tokens" not in data:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
return OAuthToken.model_validate(data["tokens"])
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
logger.error("Could not load tokens: %s", e)
return None
def _merge(self, patch: dict) -> None:
"""Shallow-merge ``patch`` into this row's ``session_data``.
Threads ``server_url`` through to the repository so it lands in
the scalar column — ``get_by_user_and_server_url`` needs that to
resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
"""
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_session
base_url = self.get_base_url(self.server_url)
with db_session() as conn:
ConnectorSessionsRepository(conn).merge_session_data(
self.user_id, self._pg_provider(), base_url, patch,
)
def _delete(self) -> None:
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.session import db_session
with db_session() as conn:
ConnectorSessionsRepository(conn).delete(
self.user_id, self._pg_provider(),
)
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
base_url = self.get_base_url(self.server_url)
token_dump = tokens.model_dump()
await asyncio.to_thread(self._merge, {"tokens": token_dump})
logger.info("Saved tokens for %s", base_url)
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
data = await asyncio.to_thread(self._fetch_session_data)
base_url = self.get_base_url(self.server_url)
if not data or "client_info" not in data:
logger.debug("No client_info in DB for %s", base_url)
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
client_info = OAuthClientInformationFull.model_validate(data["client_info"])
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
]
expected_uri = self.expected_redirect_uri.rstrip("/")
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
base_url,
expected_uri,
stored_uris,
)
# Drop ``tokens`` and ``client_info`` from the JSONB
# blob via merge_session_data's ``None``-drops-key
# semantics — preserves the row + any other keys.
await asyncio.to_thread(
self._merge,
{"tokens": None, "client_info": None},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
logger.error("Could not load client info: %s", e)
return None
def _serialize_client_info(self, info: dict) -> dict:
@@ -801,23 +965,38 @@ class DBTokenStorage(TokenStorage):
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
base_url = self.get_base_url(self.server_url)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
self._merge, {"client_info": serialized_info},
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
logger.info("Saved client info for %s", base_url)
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
await asyncio.to_thread(self._delete)
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
async def clear_all(cls, db_client=None) -> None:
"""Delete every MCP-tagged connector session row.
``db_client`` retained for call-site compatibility but unused —
storage is Postgres-only now.
"""
from sqlalchemy import text
from application.storage.db.session import db_session
def _delete_all() -> None:
with db_session() as conn:
conn.execute(
text(
"DELETE FROM connector_sessions "
"WHERE provider LIKE 'mcp:%'"
)
)
await asyncio.to_thread(_delete_all)
logger.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
@@ -856,7 +1035,7 @@ class MCPOAuthManager:
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:

View File

@@ -1,12 +1,14 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import re
import logging
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.memories import MemoriesRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
class MemoryTool(Tool):
@@ -27,7 +29,7 @@ class MemoryTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
# In production, tool_id is the UUID string from user_tools.id.
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -37,8 +39,35 @@ class MemoryTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["memories"]
def _pg_enabled(self) -> bool:
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
has no row in ``user_tools``, so any storage operation would fail
the foreign-key check. After the Postgres cutover Postgres is the
only store, so for the sentinel case there is nowhere to read or
write — operations become no-ops and the tool returns an
explanatory error to the caller.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
logger.debug(
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
tool_id,
)
return False
from application.storage.db.base_repository import looks_like_uuid
if not looks_like_uuid(tool_id):
logger.debug(
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
tool_id,
)
return False
return True
# -----------------------------
# Action implementations
@@ -56,6 +85,12 @@ class MemoryTool(Tool):
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: MemoryTool is not configured with a persistent tool_id; "
"memory storage is unavailable for this session."
)
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
@@ -282,14 +317,10 @@ class MemoryTool(Tool):
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
# Find all files that start with this directory path
query = {
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
}
docs = list(self.collection.find(query, {"path": 1}))
with db_readonly() as conn:
docs = MemoriesRepository(conn).list_by_prefix(
self.user_id, self.tool_id, search_path
)
if not docs:
return f"Directory: {path}\n(empty)"
@@ -310,7 +341,10 @@ class MemoryTool(Tool):
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
with db_readonly() as conn:
doc = MemoriesRepository(conn).get_by_path(
self.user_id, self.tool_id, path
)
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
@@ -344,16 +378,10 @@ class MemoryTool(Tool):
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": file_text,
"updated_at": datetime.now()
}
},
upsert=True
)
with db_session() as conn:
MemoriesRepository(conn).upsert(
self.user_id, self.tool_id, validated_path, file_text
)
return f"File created: {validated_path}"
@@ -366,30 +394,29 @@ class MemoryTool(Tool):
if not old_str:
return "Error: old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Replace the string (case-insensitive)
import re as regex_module
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
# Case-insensitive replace
import re as regex_module
updated_content = regex_module.sub(
regex_module.escape(old_str),
new_str,
current_content,
flags=regex_module.IGNORECASE,
)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
return f"File updated: {validated_path}"
@@ -402,31 +429,25 @@ class MemoryTool(Tool):
if not insert_text:
return "Error: insert_text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
return f"Text inserted at line {insert_line} in {validated_path}"
@@ -438,39 +459,36 @@ class MemoryTool(Tool):
if validated_path == "/":
# Delete all files for this user and tool
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
return f"Deleted {result.deleted_count} file(s) from memory."
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_all(
self.user_id, self.tool_id
)
return f"Deleted {deleted} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
# Delete all files in directory
result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_path)}"}
})
return f"Deleted directory and {result.deleted_count} file(s)."
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_by_prefix(
self.user_id, self.tool_id, validated_path
)
return f"Deleted directory and {deleted} file(s)."
# Try to delete as directory first (without trailing slash)
# Check if any files start with this path + /
# Try as directory first (without trailing slash)
search_path = validated_path + "/"
directory_result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
})
with db_session() as conn:
repo = MemoriesRepository(conn)
directory_deleted = repo.delete_by_prefix(
self.user_id, self.tool_id, search_path
)
if directory_deleted > 0:
return f"Deleted directory and {directory_deleted} file(s)."
if directory_result.deleted_count > 0:
return f"Deleted directory and {directory_result.deleted_count} file(s)."
# Otherwise delete a single file
file_deleted = repo.delete_by_path(
self.user_id, self.tool_id, validated_path
)
# Delete single file
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_path
})
if result.deleted_count:
if file_deleted:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
@@ -485,62 +503,46 @@ class MemoryTool(Tool):
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Check if renaming a directory
# Directory rename: do all path updates inside one transaction so
# the rename is atomic from the caller's perspective.
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_old)}"}
}))
if not docs:
return f"Error: Directory not found: {validated_old}"
# Update paths for all files
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
self.collection.update_one(
{"_id": doc["_id"]},
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
with db_session() as conn:
repo = MemoriesRepository(conn)
docs = repo.list_by_prefix(
self.user_id, self.tool_id, validated_old
)
if not docs:
return f"Error: Directory not found: {validated_old}"
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(
validated_old, validated_new, 1
)
repo.update_path(
self.user_id, self.tool_id, old_file_path, new_file_path
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Rename single file
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_old
})
# Single-file rename: lookup, collision check, and update in one txn.
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
if not doc:
return f"Error: File not found: {validated_old}"
if not doc:
return f"Error: File not found: {validated_old}"
existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
if existing:
return f"Error: File already exists at {validated_new}"
# Check if new path already exists
existing = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new
})
if existing:
return f"Error: File already exists at {validated_new}"
# Delete the old document and create a new one with the new path
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
self.collection.insert_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new,
"content": doc.get("content", ""),
"updated_at": datetime.now()
})
repo.update_path(
self.user_id, self.tool_id, validated_old, validated_new
)
return f"Renamed: {validated_old} -> {validated_new}"

View File

@@ -1,10 +1,16 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.session import db_readonly, db_session
# Stable synthetic title used in the Postgres ``notes.title`` column.
# The notes tool stores one note per (user_id, tool_id); there is no
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
# constant alongside the actual note body in ``content``.
_NOTE_TITLE = "note"
class NotesTool(Tool):
@@ -25,7 +31,6 @@ class NotesTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -35,11 +40,25 @@ class NotesTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
fallback is neither a UUID nor a ``user_tools`` row, so any DB
operation would crash. Mirror MemoryTool's guard and no-op.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,7 +73,13 @@ class NotesTool(Tool):
A human-readable string result.
"""
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
return "Error: NotesTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: NotesTool is not configured with a persistent "
"tool_id; note storage is unavailable for this session."
)
self._last_artifact_id = None
@@ -135,37 +160,45 @@ class NotesTool(Tool):
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
def _fetch_note(self) -> Optional[dict]:
"""Read the note row for this (user, tool) from Postgres."""
with db_readonly() as conn:
return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
doc = self._fetch_note()
# ``content`` is the PG column; expose as ``note`` to callers via the
# textual return value. Frontends that read the artifact via the
# repo dict get ``content`` (PG-native) plus the artifact id below.
body = (doc or {}).get("content")
if not doc or not body:
return "No note found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
return str(body)
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True,
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, content
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
if not old_str:
return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
return "No note found."
current_note = str(doc["note"])
current_note = str(existing)
# Case-insensitive search
if old_str.lower() not in current_note.lower():
@@ -175,24 +208,24 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
if not text:
return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
return "No note found."
current_note = str(doc["note"])
current_note = str(existing)
lines = current_note.split("\n")
# Convert to 0-indexed and validate
@@ -203,21 +236,23 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Text inserted."
def _delete_note(self) -> str:
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
# Capture the id (for artifact tracking) before deleting.
existing = self._fetch_note()
if not existing:
return "No note found to delete."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
with db_session() as conn:
deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
if not deleted:
return "No note found to delete."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return "Note deleted."

View File

@@ -71,7 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data)
response = requests.post(url, headers=headers, data=data, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):
@@ -116,12 +116,13 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {"type": "string", "description": "Access token for authentication"},
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,6 +1,12 @@
import psycopg2
import logging
import psycopg
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -17,25 +23,25 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = [desc[0] for desc in cur.description] if cur.description else []
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
results = []
rows = cur.fetchall()
for row in rows:
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
cur.close()
return {
@@ -52,28 +60,29 @@ class PostgresTool(Tool):
"response_data": response_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL execute_sql error: %s", e)
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute("""
cur.execute(
"""
SELECT
table_name,
column_name,
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
""")
"""
)
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
cur.close()
return {
@@ -108,16 +120,16 @@ class PostgresTool(Tool):
"schema": schema_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL get_schema error: %s", e)
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
def get_actions_metadata(self):
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
},
}
}

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -18,24 +23,22 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -0,0 +1,70 @@
from application.agents.tools.base import Tool
THINK_TOOL_ID = "think"
THINK_TOOL_ENTRY = {
"name": "think",
"actions": [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"active": True,
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
class ThinkTool(Tool):
"""Pseudo-tool that captures chain-of-thought reasoning.
Returns a short acknowledgment so the LLM can continue.
The reasoning content is captured in tool_call data for transparency.
"""
internal = True
def __init__(self, config=None):
pass
def execute_action(self, action_name: str, **kwargs):
return "Continue."
def get_actions_metadata(self):
return [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
]
def get_config_requirements(self):
return {}

View File

@@ -1,10 +1,19 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.session import db_readonly, db_session
def _status_from_completed(completed: Any) -> str:
"""Translate the PG ``completed`` boolean to the legacy status string.
The frontend (and prior LLM-facing tool output) expects
``"open"`` / ``"completed"``. Keeping that contract at the tool
boundary insulates callers from the schema change.
"""
return "completed" if bool(completed) else "open"
class TodoListTool(Tool):
@@ -25,7 +34,6 @@ class TodoListTool(Tool):
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
@@ -35,11 +43,27 @@ class TodoListTool(Tool):
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
``default_{uid}`` fallback is neither a UUID nor a row in
``user_tools`` — binding it would crash ``invalid input syntax for
type uuid`` and even if it didn't the FK would reject it. Mirror
the MemoryTool guard and no-op in that case.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
@@ -56,6 +80,12 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: TodoListTool is not configured with a persistent "
"tool_id; todo storage is unavailable for this session."
)
self._last_artifact_id = None
if action_name == "list":
@@ -191,28 +221,10 @@ class TodoListTool(Tool):
return None
def _get_next_todo_id(self) -> int:
"""Get the next sequential todo_id for this user and tool.
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find the maximum todo_id
max_id = 0
for todo in todos:
todo_id = self._coerce_todo_id(todo.get("todo_id"))
if todo_id is not None:
max_id = max(max_id, todo_id)
return max_id + 1
def _list(self) -> str:
"""List all todos for the user."""
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
with db_readonly() as conn:
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
if not todos:
return "No todos found."
@@ -221,7 +233,7 @@ class TodoListTool(Tool):
for doc in todos:
todo_id = doc.get("todo_id")
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
status = _status_from_completed(doc.get("completed"))
line = f"[{todo_id}] {title} ({status})"
result_lines.append(line)
@@ -229,27 +241,23 @@ class TodoListTool(Tool):
return "\n".join(result_lines)
def _create(self, title: str) -> str:
"""Create a new todo item."""
"""Create a new todo item.
``TodosRepository.create`` allocates the per-tool monotonic
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
scoped to ``tool_id``), so we no longer need a separate read-then-
write step here.
"""
title = (title or "").strip()
if not title:
return "Error: Title is required."
now = datetime.now()
todo_id = self._get_next_todo_id()
with db_session() as conn:
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
doc = {
"todo_id": todo_id,
"user_id": self.user_id,
"tool_id": self.tool_id,
"title": title,
"status": "open",
"created_at": now,
"updated_at": now,
}
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
todo_id = row.get("todo_id")
if row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -258,21 +266,21 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
with db_readonly() as conn:
doc = TodosRepository(conn).get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
status = _status_from_completed(doc.get("completed"))
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
return result
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
def _update(self, todo_id: Optional[Any], title: str) -> str:
"""Update a todo's title by ID."""
@@ -284,16 +292,19 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.update_title_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id, title
)
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} updated to: {title}"
@@ -303,16 +314,17 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} marked as completed."
@@ -322,12 +334,18 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.delete_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type):
def __init__(self, llm_type, name_mapping=None):
self.llm_type = llm_type
self.name_mapping = name_mapping
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
@@ -16,22 +17,33 @@ class ToolActionParser:
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _resolve_via_mapping(self, call_name):
"""Look up (tool_id, action_name) from the name mapping if available."""
if self.name_mapping and call_name in self.name_mapping:
return self.name_mapping[call_name]
return None
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
@@ -45,19 +57,24 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."

View File

@@ -19,7 +19,7 @@ class ToolManager:
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -12,9 +12,13 @@ from application.agents.workflows.schemas import (
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
@@ -103,10 +107,8 @@ class WorkflowAgent(BaseAgent):
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
if not self.workflow_id:
logger.error("Missing workflow ID for load")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
@@ -117,61 +119,61 @@ class WorkflowAgent(BaseAgent):
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
with db_readonly() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
if workflow_row is None:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible "
f"for user {owner_id}"
)
return None
pg_workflow_id = str(workflow_row["id"])
graph_version = workflow_row.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
node_rows = WorkflowNodesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
workflow = Workflow(
name=workflow_row.get("name"),
description=workflow_row.get("description"),
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
nodes = [
WorkflowNode(
id=n["node_id"],
workflow_id=pg_workflow_id,
type=n["node_type"],
title=n.get("title") or "Node",
description=n.get("description"),
position=n.get("position") or {"x": 0, "y": 0},
config=n.get("config") or {},
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
for n in node_rows
]
edges = [
WorkflowEdge(
id=e["edge_id"],
workflow_id=pg_workflow_id,
source=e.get("source_id"),
target=e.get("target_id"),
sourceHandle=e.get("source_handle"),
targetHandle=e.get("target_handle"),
)
for e in edge_rows
]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
@@ -181,13 +183,13 @@ class WorkflowAgent(BaseAgent):
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
user=owner_id,
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
@@ -196,7 +198,28 @@ class WorkflowAgent(BaseAgent):
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
if not self.workflow_id or not owner_id:
return
with db_session() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(
self.workflow_id, owner_id,
)
if workflow_row is None:
return
WorkflowRunsRepository(conn).create(
str(workflow_row["id"]),
owner_id,
run.status.value,
inputs=run.inputs,
result=run.outputs,
steps=[step.model_dump(mode="json") for step in run.steps],
started_at=run.created_at,
ended_at=run.completed_at,
)
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")

View File

@@ -2,9 +2,10 @@
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.schemas import AgentType
@@ -36,7 +37,8 @@ class ToolFilterMixin:
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
def __init__(
self,
@@ -57,32 +59,25 @@ class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeReActAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
}
@classmethod

View File

@@ -2,7 +2,6 @@ from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -18,6 +17,8 @@ class NodeType(str, Enum):
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):
@@ -79,24 +80,7 @@ class WorkflowEdgeCreate(BaseModel):
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
pass
class WorkflowNodeCreate(BaseModel):
@@ -118,25 +102,7 @@ class WorkflowNodeCreate(BaseModel):
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
pass
class WorkflowCreate(BaseModel):
@@ -147,26 +113,10 @@ class WorkflowCreate(BaseModel):
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
@@ -207,29 +157,12 @@ class WorkflowRunCreate(BaseModel):
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
id: Optional[str] = None
workflow_id: str
user: Optional[str] = None
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}

View File

@@ -7,6 +7,7 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
@@ -18,6 +19,7 @@ from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
@@ -99,6 +101,7 @@ class WorkflowEngine:
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
@@ -106,9 +109,9 @@ class WorkflowEngine:
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": str(e),
"error": user_friendly_error,
}
yield {"type": "error", "error": str(e)}
yield {"type": "error", "error": user_friendly_error}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
@@ -197,6 +200,9 @@ class WorkflowEngine:
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.sources:
self._retrieve_node_sources(node_config)
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
@@ -221,18 +227,32 @@ class WorkflowEngine:
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
@@ -438,6 +458,29 @@ class WorkflowEngine:
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
"""Retrieve documents from the node's sources for template resolution."""
from application.retriever.retriever_creator import RetrieverCreator
query = self.state.get("query", "")
if not query:
return
try:
retriever = RetrieverCreator.create_retriever(
node_config.retriever or "classic",
source={"active_docs": node_config.sources},
chat_history=[],
prompt="",
chunks=int(node_config.chunks) if node_config.chunks else 2,
decoded_token=self.agent.decoded_token,
)
docs = retriever.search(query)
if docs:
self.agent.retrieved_docs = docs
except Exception:
logger.exception("Failed to retrieve docs for workflow node")
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(

52
application/alembic.ini Normal file
View File

@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
# alembic -c application/alembic.ini upgrade head
[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os
# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.
The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""
import sys
from logging.config import fileConfig
from pathlib import Path
# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from alembic import context # noqa: E402
from sqlalchemy import engine_from_config, pool # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.models import metadata as target_metadata # noqa: E402
config = context.config
# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
if config.config_file_name is not None:
fileConfig(config.config_file_name)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
url = config.get_main_option("sqlalchemy.url")
if not url:
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode against a live connection."""
if not config.get_main_option("sqlalchemy.url"):
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,927 @@
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:
Create Date: 2026-04-13
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ------------------------------------------------------------------
# Extensions
# ------------------------------------------------------------------
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
# ------------------------------------------------------------------
# Trigger functions
# ------------------------------------------------------------------
op.execute(
"""
CREATE FUNCTION set_updated_at() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
NEW.updated_at = now();
RETURN NEW;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION ensure_user_exists() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
IF NEW.user_id IS NOT NULL THEN
INSERT INTO users (user_id) VALUES (NEW.user_id)
ON CONFLICT (user_id) DO NOTHING;
END IF;
RETURN NEW;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
UPDATE conversation_messages
SET attachments = array_remove(attachments, OLD.id)
WHERE OLD.id = ANY(attachments);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
UPDATE agents
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
WHERE OLD.id = ANY(extra_source_ids);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
LANGUAGE plpgsql AS $$
DECLARE
agent_id_text text := OLD.id::text;
BEGIN
UPDATE users
SET agent_preferences = jsonb_set(
jsonb_set(
agent_preferences,
'{pinned}',
COALESCE((
SELECT jsonb_agg(e)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
) e
WHERE (e #>> '{}') <> agent_id_text
), '[]'::jsonb)
),
'{shared_with_me}',
COALESCE((
SELECT jsonb_agg(e)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
) e
WHERE (e #>> '{}') <> agent_id_text
), '[]'::jsonb)
)
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
IF NEW.user_id IS NULL THEN
SELECT user_id INTO NEW.user_id
FROM conversations
WHERE id = NEW.conversation_id;
END IF;
RETURN NEW;
END;
$$;
"""
)
# ------------------------------------------------------------------
# Tables
# ------------------------------------------------------------------
op.execute(
"""
CREATE TABLE users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL UNIQUE,
agent_preferences JSONB NOT NULL
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE prompts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE user_tools (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
custom_name TEXT,
display_name TEXT,
description TEXT,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
status BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE token_usage (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
api_key TEXT,
agent_id UUID,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
generated_tokens INTEGER NOT NULL DEFAULT 0,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
);
"""
)
op.execute(
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
)
op.execute(
"""
CREATE TABLE user_logs (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
endpoint TEXT,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
data JSONB,
mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE stack_logs (
id BIGSERIAL PRIMARY KEY,
activity_id TEXT NOT NULL,
endpoint TEXT,
level TEXT,
user_id TEXT,
api_key TEXT,
query TEXT,
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE agent_folders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE sources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
language TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
model TEXT,
type TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
retriever TEXT,
sync_frequency TEXT,
tokens TEXT,
file_path TEXT,
remote_data JSONB,
directory_structure JSONB,
file_name_map JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE agents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
agent_type TEXT,
status TEXT NOT NULL,
key CITEXT UNIQUE,
image TEXT,
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
chunks INTEGER,
retriever TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
json_schema JSONB,
models JSONB,
default_model_id TEXT,
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
workflow_id UUID,
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
token_limit INTEGER,
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
request_limit INTEGER,
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
shared BOOLEAN NOT NULL DEFAULT false,
shared_token CITEXT UNIQUE,
shared_metadata JSONB,
incoming_webhook_token CITEXT UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
last_used_at TIMESTAMPTZ,
legacy_mongo_id TEXT
);
"""
)
op.execute(
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
)
op.execute(
"""
CREATE TABLE attachments (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
filename TEXT NOT NULL,
upload_path TEXT NOT NULL,
mime_type TEXT,
size BIGINT,
content TEXT,
token_count INTEGER,
openai_file_id TEXT,
google_file_uri TEXT,
metadata JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
path TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE todos (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
todo_id INTEGER,
title TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT false,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE notes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE connector_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
provider TEXT NOT NULL,
server_url TEXT,
session_token TEXT UNIQUE,
user_email TEXT,
status TEXT,
token_info JSONB,
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
name TEXT,
api_key TEXT,
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
shared_token TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
compression_metadata JSONB,
legacy_mongo_id TEXT,
CONSTRAINT conversations_api_key_nonempty_chk
CHECK (api_key IS NULL OR api_key <> '')
);
"""
)
op.execute(
"""
CREATE TABLE conversation_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
position INTEGER NOT NULL,
prompt TEXT,
response TEXT,
thought TEXT,
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
model_id TEXT,
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
feedback JSONB,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
user_id TEXT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE shared_conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
is_promptable BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
uuid UUID NOT NULL,
first_n_queries INTEGER NOT NULL DEFAULT 0,
api_key TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
chunks INTEGER,
CONSTRAINT shared_conversations_api_key_nonempty_chk
CHECK (api_key IS NULL OR api_key <> '')
);
"""
)
op.execute(
"""
CREATE TABLE pending_tool_state (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
messages JSONB NOT NULL,
pending_tool_calls JSONB NOT NULL,
tools_dict JSONB NOT NULL,
tool_schemas JSONB NOT NULL,
agent_config JSONB NOT NULL,
client_tools JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
expires_at TIMESTAMPTZ NOT NULL
);
"""
)
op.execute(
"""
CREATE TABLE workflows (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
current_graph_version INTEGER NOT NULL DEFAULT 1,
legacy_mongo_id TEXT
);
"""
)
# Backfill the agents.workflow_id FK now that workflows exists.
# The column was created without a FK (forward reference to a table
# that hadn't been declared yet); add the constraint here so workflow
# deletion still cascades through to agent unset.
op.execute(
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
)
op.execute(
"""
CREATE TABLE workflow_nodes (
id UUID DEFAULT gen_random_uuid() NOT NULL,
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
node_type TEXT NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
node_id TEXT NOT NULL,
title TEXT,
description TEXT,
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
legacy_mongo_id TEXT,
PRIMARY KEY (id),
CONSTRAINT workflow_nodes_id_wf_ver_key
UNIQUE (id, workflow_id, graph_version)
);
"""
)
op.execute(
"""
CREATE TABLE workflow_edges (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
from_node_id UUID NOT NULL,
to_node_id UUID NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
edge_id TEXT NOT NULL,
source_handle TEXT,
target_handle TEXT,
CONSTRAINT workflow_edges_from_node_fk
FOREIGN KEY (from_node_id, workflow_id, graph_version)
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
CONSTRAINT workflow_edges_to_node_fk
FOREIGN KEY (to_node_id, workflow_id, graph_version)
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
);
"""
)
op.execute(
"""
CREATE TABLE workflow_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
status TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
ended_at TIMESTAMPTZ,
result JSONB,
inputs JSONB,
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
legacy_mongo_id TEXT,
CONSTRAINT workflow_runs_status_chk
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
);
"""
)
# ------------------------------------------------------------------
# Indexes
# ------------------------------------------------------------------
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
op.execute(
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
op.execute(
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
# MCP and OAuth connectors share the ``provider`` slot, so the
# dedup key is ``(user_id, server_url, provider)``: MCP rows
# differentiate by server_url (one per MCP server), OAuth rows
# have server_url = NULL and differentiate by provider alone.
# COALESCE lets NULL server_url participate in the constraint.
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
)
op.execute(
"CREATE INDEX connector_sessions_expiry_idx "
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
)
op.execute(
"CREATE INDEX connector_sessions_server_url_idx "
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
"ON conversation_messages (conversation_id, position);"
)
op.execute(
"CREATE INDEX conversation_messages_user_ts_idx "
"ON conversation_messages (user_id, timestamp DESC);"
)
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
op.execute(
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
)
op.execute(
"CREATE INDEX conversations_api_key_date_idx "
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
"ON memories (user_id, tool_id, path);"
)
op.execute(
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
"ON memories (user_id, path) WHERE tool_id IS NULL;"
)
op.execute(
"CREATE INDEX memories_path_prefix_idx "
"ON memories (user_id, tool_id, path text_pattern_ops);"
)
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
op.execute(
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
"ON pending_tool_state (conversation_id, user_id);"
)
op.execute(
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
)
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
op.execute(
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
op.execute(
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
)
op.execute(
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
)
op.execute(
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
)
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
op.execute(
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
op.execute(
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
op.execute(
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
)
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
op.execute(
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
"ON workflow_edges (workflow_id, graph_version, edge_id);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
"ON workflow_nodes (workflow_id, graph_version, node_id);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
op.execute(
"CREATE INDEX workflow_runs_status_started_idx "
"ON workflow_runs (status, started_at DESC);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
op.execute(
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
# ------------------------------------------------------------------
# user_id foreign keys (deferrable so backfills can stage rows)
# ------------------------------------------------------------------
user_fk_tables = (
"agent_folders",
"agents",
"attachments",
"connector_sessions",
"conversation_messages",
"conversations",
"memories",
"notes",
"pending_tool_state",
"prompts",
"shared_conversations",
"sources",
"stack_logs",
"todos",
"token_usage",
"user_logs",
"user_tools",
"workflow_runs",
"workflows",
)
for table in user_fk_tables:
op.execute(
f"ALTER TABLE {table} "
f"ADD CONSTRAINT {table}_user_id_fk "
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
# ------------------------------------------------------------------
# Triggers
# ------------------------------------------------------------------
updated_at_tables = (
"agent_folders",
"agents",
"conversation_messages",
"conversations",
"memories",
"notes",
"prompts",
"sources",
"todos",
"user_tools",
"users",
"workflows",
)
for table in updated_at_tables:
op.execute(
f"CREATE TRIGGER {table}_set_updated_at "
f"BEFORE UPDATE ON {table} "
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
f"EXECUTE FUNCTION set_updated_at();"
)
ensure_user_tables = (
"agent_folders",
"agents",
"attachments",
"connector_sessions",
"conversation_messages",
"conversations",
"memories",
"notes",
"pending_tool_state",
"prompts",
"shared_conversations",
"sources",
"stack_logs",
"todos",
"token_usage",
"user_logs",
"user_tools",
"workflow_runs",
"workflows",
)
for table in ensure_user_tables:
op.execute(
f"CREATE TRIGGER {table}_ensure_user "
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER conversation_messages_fill_user "
"BEFORE INSERT ON conversation_messages "
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
)
op.execute(
"CREATE TRIGGER attachments_cleanup_message_refs "
"AFTER DELETE ON attachments "
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
)
op.execute(
"CREATE TRIGGER agents_cleanup_user_prefs "
"AFTER DELETE ON agents "
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
)
op.execute(
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
"AFTER DELETE ON sources "
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
)
# ------------------------------------------------------------------
# Seed sentinel __system__ user (system/template sources attribute here)
# ------------------------------------------------------------------
op.execute(
"INSERT INTO users (user_id) VALUES ('__system__') "
"ON CONFLICT (user_id) DO NOTHING;"
)
def downgrade() -> None:
# Nuclear downgrade: drop everything this migration created. The
# ordering drops FK-bearing children before parents; CASCADE would
# also work but explicit ordering is easier to reason about in code
# review.
tables_in_drop_order = (
"workflow_edges",
"workflow_runs",
"workflow_nodes",
"workflows",
"pending_tool_state",
"shared_conversations",
"conversation_messages",
"conversations",
"connector_sessions",
"notes",
"todos",
"memories",
"attachments",
"agents",
"sources",
"agent_folders",
"stack_logs",
"user_logs",
"token_usage",
"user_tools",
"prompts",
"users",
)
for table in tables_in_drop_order:
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
for fn in (
"conversation_messages_fill_user_id",
"cleanup_user_agent_prefs",
"cleanup_agent_extra_source_refs",
"cleanup_message_attachment_refs",
"ensure_user_exists",
"set_updated_at",
):
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")

View File

@@ -0,0 +1,37 @@
"""0002 app_metadata — singleton key/value table for instance-wide state.
Used by the startup version-check client to persist the anonymous
instance UUID and a one-shot "notice shown" flag. Both values are tiny
plain-text strings; this is a deliberate generic-config table rather
than dedicated columns so future one-off settings (telemetry opt-in
timestamps, feature-flag overrides, etc.) don't each need their own
migration.
Revision ID: 0002_app_metadata
Revises: 0001_initial
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_app_metadata"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE app_metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS app_metadata;")

View File

@@ -74,68 +74,76 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
docs_together, docs_list = processor.pre_fetch_docs(
data.get("question", "")
)
tools_data = processor.pre_fetch_tools()
if error := self.check_usage(processor.agent_config):
return error
agent = processor.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
}
if structured_info:
result.update(structured_info)
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
@@ -13,9 +14,13 @@ from application.core.model_utils import (
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -28,17 +33,22 @@ class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
@@ -80,8 +90,8 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response(
@@ -102,41 +112,32 @@ class BaseAnswerResource:
)
token_limit = int(
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
request_limit = int(
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
token_usage_collection = self.db["token_usage"]
end_date = datetime.datetime.now()
end_date = datetime.datetime.now(datetime.timezone.utc)
start_date = end_date - datetime.timedelta(hours=24)
match_query = {
"timestamp": {"$gte": start_date, "$lte": end_date},
"api_key": api_key,
}
if limited_token_mode:
token_pipeline = [
{"$match": match_query},
{
"$group": {
"_id": None,
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
},
]
token_result = list(token_usage_collection.aggregate(token_pipeline))
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
if limited_token_mode or limited_request_mode:
with db_readonly() as conn:
token_repo = TokenUsageRepository(conn)
if limited_token_mode:
daily_token_usage = token_repo.sum_tokens_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_repo.count_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_request_usage = 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
@@ -176,6 +177,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -205,9 +207,23 @@ class BaseAnswerResource:
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
for line in agent.gen(query=question):
if "answer" in line:
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
@@ -240,8 +256,21 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -251,6 +280,93 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -287,6 +403,7 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
@@ -340,7 +457,18 @@ class BaseAnswerResource:
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
try:
with db_session() as conn:
UserLogsRepository(conn).insert(
user_id=log_data.get("user"),
endpoint="stream_answer",
data=log_data,
)
except Exception as log_err:
logger.error(
f"Failed to persist stream_answer user log: {log_err}",
exc_info=True,
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
@@ -376,6 +504,7 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
@@ -412,8 +541,13 @@ class BaseAnswerResource:
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
conversation_id = ""
response_full = ""
source_log_docs = []
@@ -422,6 +556,7 @@ class BaseAnswerResource:
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
@@ -440,11 +575,22 @@ class BaseAnswerResource:
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"], None
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -452,18 +598,30 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
result["extra"] = {"pending_tool_calls": pending_tool_calls}
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
result["extra"] = {"structured": True, "schema": schema_info}
return result
def error_stream_generate(self, err_response):

View File

@@ -1,28 +1,21 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
"""Fast search endpoint for retrieving relevant documents."""
search_model = answer_ns.model(
"SearchModel",
@@ -39,116 +32,10 @@ class SearchResource(Resource):
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if sources and isinstance(sources, list) and len(sources) > 0:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
data = request.get_json() or {}
question = data.get("question")
api_key = data.get("api_key")
@@ -156,31 +43,13 @@ class SearchResource(Resource):
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
return make_response({"error": "Search failed"}, 500)

View File

@@ -79,8 +79,48 @@ class StreamResource(Resource, BaseAnswerResource):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
@@ -88,13 +128,6 @@ class StreamResource(Resource, BaseAnswerResource):
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together, docs=docs_list, tools_data=tools_data
)
if error := self.check_usage(processor.agent_config):
return error
return Response(

View File

@@ -1,5 +1,6 @@
"""Message reconstruction utilities for compression."""
import json
import logging
import uuid
from typing import Dict, List, Optional
@@ -49,28 +50,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
@@ -180,28 +188,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
rebuilt_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
rebuilt_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:

View File

@@ -0,0 +1,157 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to Postgres so the client can
resume later by sending tool_actions.
"""
import logging
from typing import Any, Dict, List, Optional
from uuid import UUID
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
"""Recursively coerce non-JSON values into JSON-safe forms.
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
our writers produce them anymore.
"""
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
"""Manages pending tool-call state in Postgres."""
def __init__(self):
# No-op constructor retained for call-site compatibility. State
# lives in Postgres now; each operation opens its own short-lived
# session rather than holding a connection on the service.
pass
def save_state(
self,
conversation_id: str,
user: str,
messages: List[Dict],
pending_tool_calls: List[Dict],
tools_dict: Dict,
tool_schemas: List[Dict],
agent_config: Dict,
client_tools: Optional[List[Dict]] = None,
) -> str:
"""Save execution state for later continuation.
``conversation_id`` may be a Postgres UUID or the legacy Mongo
``ObjectId`` string — the latter is resolved via
``conversations.legacy_mongo_id`` to find the matching row.
Args:
conversation_id: The conversation this state belongs to.
user: Owner user ID.
messages: Full messages array at the pause point.
pending_tool_calls: Tool calls awaiting client action.
tools_dict: Serializable tools configuration dict.
tool_schemas: LLM-formatted tool schemas (agent.tools).
agent_config: Config needed to recreate the agent on resume.
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID (conversation_id as provided) of the saved state.
"""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
# would raise and poison the save. Surface the mismatch so
# the caller can decide (the stream loop in routes/base.py
# already wraps this in try/except).
raise ValueError(
f"Cannot save continuation state: conversation_id "
f"{conversation_id!r} is neither a PG UUID nor a "
f"backfilled legacy Mongo id."
)
PendingToolStateRepository(conn).save_state(
pg_conv_id,
user,
messages=_make_serializable(messages),
pending_tool_calls=_make_serializable(pending_tool_calls),
tools_dict=_make_serializable(tools_dict),
tool_schemas=_make_serializable(tool_schemas),
agent_config=_make_serializable(agent_config),
client_tools=_make_serializable(client_tools) if client_tools else None,
)
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
return conversation_id
def load_state(
self, conversation_id: str, user: str
) -> Optional[Dict[str, Any]]:
"""Load pending continuation state.
Returns:
The state dict, or None if no pending state exists.
"""
with db_readonly() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → no state can exist for it.
return None
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
if not doc:
return None
return doc
def delete_state(self, conversation_id: str, user: str) -> bool:
"""Delete pending state after successful resumption.
Returns:
True if a row was deleted.
"""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → nothing to delete.
return False
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
if deleted:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
return deleted

View File

@@ -1,44 +1,51 @@
"""Conversation persistence service backed by Postgres.
Handles create / append / update / compression for conversations during
the answer-streaming path. Connections are opened per-operation rather
than held for the duration of a stream.
"""
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.core.mongo_db import MongoDB
from sqlalchemy import text as sql_text
from application.core.settings import settings
from bson import ObjectId
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
class ConversationService:
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.conversations_collection = db["conversations"]
self.agents_collection = db["agents"]
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with proper access control"""
"""Retrieve a conversation with owner-or-shared access control.
Returns a dict in the legacy Mongo shape — ``queries`` is a list
of message dicts (prompt/response/...) — for compatibility with
the streaming pipeline that consumes this shape.
"""
if not conversation_id or not user_id:
return None
try:
conversation = self.conversations_collection.find_one(
{
"_id": ObjectId(conversation_id),
"$or": [{"user": user_id}, {"shared_with": user_id}],
}
)
if not conversation:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
conversation["_id"] = str(conversation["_id"])
return conversation
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
messages = repo.get_messages(str(conv["id"]))
conv["queries"] = messages
conv["_id"] = str(conv["id"])
return conv
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
@@ -60,8 +67,13 @@ class ConversationService:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in the database"""
"""Save or update a conversation in Postgres.
Returns the string conversation id (PG UUID as string, or the
caller-provided id if it was already a UUID).
"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
@@ -69,72 +81,47 @@ class ConversationService:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
# Trim huge inline source text to a reasonable max before persist.
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
message_payload = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
}
if metadata:
message_payload["metadata"] = metadata
if conversation_id is not None and index is not None:
# Update existing conversation with new query
result = self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": sources,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
repo.update_message_at(conv_pg_id, index, message_payload)
repo.truncate_after(conv_pg_id, index)
return conversation_id
elif conversation_id:
# Append new message to existing conversation
result = self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id), "user": user_id},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
# append_message expects 'metadata' key either way; normalise.
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return conversation_id
else:
# Create new conversation
messages_summary = [
{
"role": "system",
@@ -156,68 +143,64 @@ class ConversationService:
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}
resolved_api_key: Optional[str] = None
resolved_agent_id: Optional[str] = None
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
agent = self.agents_collection.find_one({"key": api_key})
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
resolved_api_key = agent.get("key")
if agent_id:
resolved_agent_id = agent_id
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.create(
user_id,
completion,
agent_id=resolved_agent_id,
api_key=resolved_api_key,
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
shared_token=(
shared_token
if (resolved_agent_id and is_shared_usage)
else None
),
)
conv_pg_id = str(conv["id"])
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return conv_pg_id
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
"""Persist compression flags and append a compression point.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
``compression_metadata`` but goes through the PG repo API.
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
with db_session() as conn:
repo = ConversationsRepository(conn)
# conversation_id here comes from the streaming pipeline
# which has already resolved it; accept either UUID or
# legacy id for safety.
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.set_compression_flags(
conv_pg_id,
is_compressed=True,
last_compression_at=compression_metadata.get("timestamp"),
)
repo.append_compression_point(
conv_pg_id,
compression_metadata,
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
@@ -230,34 +213,34 @@ class ConversationService:
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
"""Append a synthetic compression summary message to the conversation."""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
timestamp = compression_metadata.get(
"timestamp", datetime.now(timezone.utc)
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.append_message(conv_pg_id, {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"attachments": [],
"model_id": compression_metadata.get("model_used"),
"timestamp": timestamp,
})
logger.info(
f"Appended compression summary to conversation {conversation_id}"
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
@@ -266,20 +249,30 @@ class ConversationService:
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
"""Fetch the stored compression metadata JSONB blob for a conversation."""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
# Fallback to UUID lookup without user scoping — the
# caller already holds an authenticated conversation
# id from the streaming path. Gate on id shape so a
# non-UUID (legacy ObjectId that wasn't backfilled)
# doesn't reach CAST — the cast raises and spams the
# logs with a stack trace on every call.
if not looks_like_uuid(conversation_id):
return None
result = conn.execute(
sql_text(
"SELECT compression_metadata FROM conversations "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
row = result.fetchone()
return row[0] if row is not None else None
return conv.get("compression_metadata") if conv else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True

View File

@@ -5,10 +5,6 @@ import os
from pathlib import Path
from typing import Any, Dict, Optional, Set
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
@@ -20,8 +16,16 @@ from application.core.model_utils import (
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from sqlalchemy import text as sql_text
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
@@ -32,18 +36,41 @@ logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).
The ``prompts_collection`` parameter is retained for backwards
compatibility with call sites that still pass it positionally; it is
ignored post-cutover.
"""
Get a prompt by preset name or MongoDB ID
"""
del prompts_collection # unused — retained for call-site compatibility
# Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
# plain string ("default"/"creative"/legacy ObjectId). Normalise to str
# so both the preset lookup and the UUID-vs-legacy branching work.
# ``None`` / empty means "use the default prompt" — agents that never
# set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
if prompt_id is None or prompt_id == "":
prompt_id = "default"
elif not isinstance(prompt_id, str):
prompt_id = str(prompt_id)
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
CLASSIC_PRESETS = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}
preset_mapping = {
**CLASSIC_PRESETS,
**{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
@@ -53,14 +80,18 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
with db_readonly() as conn:
repo = PromptsRepository(conn)
prompt_doc = None
if looks_like_uuid(prompt_id):
prompt_doc = repo.get_for_rendering(prompt_id)
if prompt_doc is None:
prompt_doc = repo.get_by_legacy_id(prompt_id)
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
@@ -69,12 +100,9 @@ class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
# Legacy attribute retained as None for any external callers that
# introspect the processor; all DB access uses per-op connections.
self.prompts_collection = None
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
@@ -91,6 +119,7 @@ class StreamProcessor:
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -101,6 +130,7 @@ class StreamProcessor:
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
self._agent_data: Optional[Dict[str, Any]] = None
def initialize(self):
"""Initialize all required components for processing"""
@@ -111,6 +141,29 @@ class StreamProcessor:
self._load_conversation_history()
self._process_attachments()
def build_agent(self, question: str):
"""One call to go from request data to a ready-to-run agent.
Combines initialize(), pre_fetch_docs(), pre_fetch_tools(), and
create_agent() into a single convenience method.
"""
self.initialize()
agent_type = self.agent_config.get("agent_type", "classic")
# Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
if agent_type in ("agentic", "research"):
tools_data = self.pre_fetch_tools()
return self.create_agent(tools_data=tools_data)
docs_together, docs_list = self.pre_fetch_docs(question)
tools_data = self.pre_fetch_tools()
return self.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
@@ -124,9 +177,17 @@ class StreamProcessor:
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history
# Original behavior - load all history (include metadata if present)
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
{
"prompt": query["prompt"],
"response": query["response"],
**(
{"metadata": query["metadata"]}
if "metadata" in query
else {}
),
}
for query in conversation.get("queries", [])
]
else:
@@ -135,14 +196,8 @@ class StreamProcessor:
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
"""Handle conversation compression logic using orchestrator."""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
@@ -153,12 +208,15 @@ class StreamProcessor:
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
@@ -169,17 +227,27 @@ class StreamProcessor:
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
# Preserve metadata from recent queries (as_history only has prompt/response)
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
for i, entry in enumerate(self.history):
# Match by index from the end of recent queries
offset = len(recent) - len(self.history)
qi = offset + i
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
entry["metadata"] = recent[qi]["metadata"]
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
for query in conversation.get("queries", [])
]
@@ -191,24 +259,24 @@ class StreamProcessor:
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
try:
with db_readonly() as conn:
repo = AttachmentsRepository(conn)
for attachment_id in attachment_ids:
try:
attachment_doc = repo.get_any(str(attachment_id), user_id)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
return attachments
def _validate_and_set_model(self):
@@ -232,7 +300,6 @@ class StreamProcessor:
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
@@ -240,99 +307,127 @@ class StreamProcessor:
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
"""Get API key for agent with access control."""
if not agent_id:
return None, False, None
try:
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
with db_readonly() as conn:
# Lookup without user scoping — access control is done
# against ``user_id`` / ``shared_with`` / ``shared`` flags
# right below, matching the legacy Mongo semantics.
repo = AgentsRepository(conn)
agent = None
if looks_like_uuid(str(agent_id)):
result = conn.execute(
sql_text(
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
),
{"id": str(agent_id)},
)
row = result.fetchone()
if row is not None:
agent = row_to_dict(row)
if agent is None:
agent = repo.get_by_legacy_id(str(agent_id))
if agent is None:
raise Exception("Agent not found")
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
agent_owner = agent.get("user_id")
is_owner = agent_owner == user_id
is_shared_with_user = bool(agent.get("shared", False))
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
now = datetime.datetime.now(datetime.timezone.utc)
try:
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), agent_owner,
{"last_used_at": now},
)
except Exception:
logger.warning(
"Failed to update last_used_at for agent",
exc_info=True,
)
return (
str(agent["key"]) if agent.get("key") else None,
not is_owner,
agent.get("shared_token"),
)
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
raise Exception("Invalid API Key, please generate a new key", 401)
sources_repo = SourcesRepository(conn)
# The repo dict uses "user_id" — the streaming path expects
# a "user" key (legacy Mongo shape) for identity propagation.
data: Dict[str, Any] = dict(agent)
data["user"] = agent.get("user_id")
# Resolve the primary source row (if any) for retriever/chunks.
source_id = agent.get("source_id")
if source_id:
source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
if source_doc:
data["source"] = str(source_doc["id"])
data["retriever"] = source_doc.get(
"retriever", data.get("retriever")
)
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
extra = agent.get("extra_source_ids") or []
if extra:
for sid in extra:
source_doc = sources_repo.get(str(sid), agent.get("user_id"))
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
# Preserve model configuration from agent
sources_list.append(
{
"id": str(source_doc["id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get(
"chunks", data.get("chunks", "2")
),
}
)
data["sources"] = sources_list
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
"""Configure the source based on agent data.
if api_key:
agent_data = self._get_data_from_api_key(api_key)
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
@@ -345,84 +440,106 @@ class StreamProcessor:
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
"""Return True if a real document source is configured for retrieval."""
active_docs = self.source.get("active_docs") if self.source else None
if not active_docs:
return False
if active_docs == "default":
return False
return True
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
"""Configure the agent based on request data.
Unified flow: resolve the effective API key, then extract config once.
"""
agent_id = self._resolve_agent_id()
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
# Determine the effective API key (explicit > agent-derived)
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
self._agent_data = self._get_data_from_api_key(effective_key)
if self._agent_data.get("_id"):
self.agent_id = str(self._agent_data.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"prompt_id": self._agent_data.get("prompt_id", "default"),
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"json_schema": self._agent_data.get("json_schema"),
"default_model_id": self._agent_data.get("default_model_id", ""),
"models": self._agent_data.get("models", []),
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = self._agent_data.get("user")
self.decoded_token = {"sub": self._agent_data.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": self._agent_data.get("user")}
# PG row exposes the workflow as ``workflow_id`` (UUID column);
# legacy Mongo shape used the key ``workflow``. Accept either so
# API-key-invoked workflow agents bind correctly downstream.
wf_ref = self._agent_data.get("workflow") or self._agent_data.get(
"workflow_id"
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
if wf_ref:
self.agent_config["workflow"] = str(wf_ref)
self.agent_config["workflow_owner"] = self._agent_data.get("user")
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
@@ -443,14 +560,45 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
try:
chunks = int(self._agent_data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -471,9 +619,12 @@ class StreamProcessor:
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False):
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
if not self._has_active_docs():
logger.info("Pre-fetch skipped: no active docs configured")
return None, None
try:
retriever = self.create_retriever()
logger.info(
@@ -505,12 +656,7 @@ class StreamProcessor:
return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
"""Pre-fetch tool data for template rendering before agent creation
Can be controlled via:
1. Global setting: ENABLE_TOOL_PREFETCH in .env
2. Per-request: disable_tool_prefetch in request data
"""
"""Pre-fetch tool data for template rendering before agent creation"""
if not settings.ENABLE_TOOL_PREFETCH:
logger.info(
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
@@ -525,12 +671,9 @@ class StreamProcessor:
filtering_enabled = required_tool_actions is not None
try:
user_tools_collection = self.db["user_tools"]
user_id = self.initial_user_id or "local"
user_tools = list(
user_tools_collection.find({"user": user_id, "status": True})
)
with db_readonly() as conn:
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
if not user_tools:
return None
@@ -722,6 +865,121 @@ class StreamProcessor:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def resume_from_tool_actions(
self,
tool_actions: list,
conversation_id: str,
):
"""Resume a paused agent from saved continuation state.
Loads the pending state from MongoDB, recreates the agent with
the saved configuration, and returns an agent ready to call
``gen_continuation()``.
Args:
tool_actions: Client-provided actions (approvals / results).
conversation_id: The conversation being resumed.
Returns:
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
"""
from application.api.answer.services.continuation_service import (
ContinuationService,
)
from application.agents.agent_creator import AgentCreator
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
cont_service = ContinuationService()
state = cont_service.load_state(conversation_id, self.initial_user_id)
if not state:
raise ValueError("No pending tool state found for this conversation")
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
tool_schemas = state.get("tool_schemas", [])
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
agent_id = agent_config.get("agent_id")
prompt = agent_config.get("prompt", "")
json_schema = agent_config.get("json_schema")
retriever_config = agent_config.get("retriever_config")
# Recreate dependencies
system_api_key = api_key or get_api_key_for_provider(llm_name)
llm = LLMCreator.create_llm(
llm_name,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.initial_user_id,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = conversation_id
# Restore client tools so they stay available for subsequent LLM calls
saved_client_tools = state.get("client_tools")
if saved_client_tools:
tool_executor.client_tools = saved_client_tools
# Re-merge into tools_dict (they may have been stripped during serialization)
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
agent_type = agent_config.get("agent_type", "ClassicAgent")
# Map class names back to agent creator keys
type_map = {
"ClassicAgent": "classic",
"AgenticAgent": "agentic",
"ResearchAgent": "research",
"WorkflowAgent": "workflow",
}
agent_key = type_map.get(agent_type, "classic")
agent_kwargs = {
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
"prompt": prompt,
"chat_history": [],
"decoded_token": self.decoded_token,
"json_schema": json_schema,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
if agent_key in ("agentic", "research") and retriever_config:
agent_kwargs["retriever_config"] = retriever_config
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
agent.conversation_id = conversation_id
agent.initial_user_id = self.initial_user_id
agent.tools = tool_schemas
# Store config for the route layer
self.model_id = model_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
self,
docs_together: Optional[str] = None,
@@ -729,22 +987,40 @@ class StreamProcessor:
tools_data: Optional[Dict[str, Any]] = None,
):
"""Create and return the configured agent with rendered prompt"""
agent_type = self.agent_config["agent_type"]
# For agentic agents, swap standard presets for their agentic
# counterparts (which include search tool instructions instead of
# {summaries}). Custom / user-provided prompts pass through as-is.
raw_prompt = self._get_prompt_content()
if raw_prompt is None:
raw_prompt = get_prompt(
self.agent_config["prompt_id"], self.prompts_collection
)
prompt_id = self.agent_config.get("prompt_id", "default")
agentic_presets = {"default", "creative", "strict"}
if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
raw_prompt = get_prompt(
f"agentic_{prompt_id}", self.prompts_collection
)
else:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
# Allow API callers to override the system prompt when the agent
# has opted in via allow_system_prompt_override.
if (
self.agent_config.get("allow_system_prompt_override", False)
and self.data.get("system_prompt_override")
):
rendered_prompt = self.data["system_prompt_override"]
else:
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
@@ -753,7 +1029,41 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent_type = self.agent_config["agent_type"]
# Create LLM and handler (dependency injection)
from application.llm.llm_creator import LLMCreator
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.agents.tool_executor import ToolExecutor
# Compute backup models: agent's configured models minus the active one.
# PG agents may carry an explicit ``models: NULL`` (not absent), so
# ``.get("models", [])`` isn't enough — coerce None → [].
agent_models = self.agent_config.get("models") or []
backup_models = [m for m in agent_models if m != self.model_id]
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
)
user = self.decoded_token.get("sub") if self.decoded_token else None
tool_executor = ToolExecutor(
user_api_key=self.agent_config["user_api_key"],
user=user,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools so they get merged in get_tools()
client_tools = self.data.get("client_tools")
if client_tools:
tool_executor.client_tools = client_tools
# Base agent kwargs
agent_kwargs = {
@@ -770,10 +1080,31 @@ class StreamProcessor:
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
# Type-specific kwargs
if agent_type in ("agentic", "research"):
agent_kwargs["retriever_config"] = {
"source": self.source,
"retriever_name": self.retriever_config.get(
"retriever_name", "classic"
),
"chunks": self.retriever_config.get("chunks", 2),
"doc_token_limit": self.retriever_config.get(
"doc_token_limit", 50000
),
"model_id": self.model_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,
"api_key": system_api_key,
"decoded_token": self.decoded_token,
}
elif agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config

View File

@@ -1,12 +1,10 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
from flask import (
Blueprint,
current_app,
@@ -17,22 +15,18 @@ from flask import (
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import (
ingest_connector_task,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
sessions_collection = db["connector_sessions"]
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
@@ -68,16 +62,14 @@ class ConnectorAuth(Resource):
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
with db_session() as conn:
session_row = ConnectorSessionsRepository(conn).upsert(
user_id, provider, status="pending",
)
session_pg_id = str(session_row["id"])
state_dict = {
"provider": provider,
"object_id": str(result.inserted_id)
"object_id": session_pg_id,
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
@@ -146,32 +138,39 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sanitized_token_info = auth.sanitize_token_info(token_info)
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# ``object_id`` in the OAuth state is the PG session row
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
# issued state). Try UUID update first; fall back to
# legacy id path.
patch = {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized",
}
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
if state_object_id:
value = str(state_object_id)
updated = False
if len(value) == 36 and "-" in value:
updated = repo.update(value, patch)
if not updated:
repo.update_by_legacy_id(value, patch)
# Redirect to success page with session token and user email
return redirect(build_callback_redirect({
@@ -201,12 +200,12 @@ class ConnectorsCallback(Resource):
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
"search_query": fields.String(required=False),
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -214,11 +213,8 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -226,20 +222,20 @@ class ConnectorFiles(Resource):
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session:
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
k: v for k, v in data.items() if k not in generic_keys
}
if search_query:
input_config['search_query'] = search_query
input_config['list_only'] = True
documents = loader.load_data(input_config)
@@ -295,8 +291,11 @@ class ConnectorValidateSession(Resource):
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session or "token_info" not in session:
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user or not session.get("token_info"):
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
@@ -306,16 +305,12 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
)
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
row = repo.get_by_session_token(session_token)
if row:
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
@@ -328,12 +323,18 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@@ -353,8 +354,11 @@ class ConnectorDisconnect(Resource):
if session_token:
sessions_collection.delete_one({"session_token": session_token})
with db_session() as conn:
ConnectorSessionsRepository(conn).delete_by_session_token(
session_token,
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
@@ -391,32 +395,28 @@ class ConnectorSync(Resource):
}),
400
)
source = sources_collection.find_one({"_id": ObjectId(source_id)})
user_id = decoded_token.get('sub')
with db_readonly() as conn:
source = SourcesRepository(conn).get_any(source_id, user_id)
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
}),
404
)
if source.get('user') != decoded_token.get('sub'):
return make_response(
jsonify({
"success": False,
"error": "Unauthorized access to source"
}),
403
)
# ``get_any`` already scopes by ``user_id``; an extra guard
# here would be dead code.
remote_data = {}
try:
if source.get('remote_data'):
remote_data = json.loads(source.get('remote_data'))
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
remote_data = source.get('remote_data') or {}
if isinstance(remote_data, str):
try:
remote_data = json.loads(remote_data)
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
@@ -444,7 +444,7 @@ class ConnectorSync(Resource):
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=source_id,
doc_id=str(source.get('id') or source_id),
sync_frequency=source.get('sync_frequency', 'never')
)

View File

@@ -3,18 +3,16 @@ import datetime
import json
from flask import Blueprint, request, send_from_directory, jsonify
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_session
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -26,12 +24,20 @@ internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])
@@ -48,21 +54,21 @@ def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = request.form["user"]
user = request.form["user"]
if "name" not in request.form:
return {"status": "no name"}
job_name = request.form["name"]
tokens = request.form["tokens"]
retriever = request.form["retriever"]
id = request.form["id"]
source_id = request.form["id"]
type = request.form["type"]
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
@@ -81,8 +87,8 @@ def upload_index_files():
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
index_base_path = f"indexes/{source_id}"
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
logger.error("No file_faiss part")
@@ -103,46 +109,48 @@ def upload_index_files():
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
now = datetime.datetime.now(datetime.timezone.utc)
update_fields = {
"name": job_name,
"type": type,
"language": job_name,
"date": now,
"model": settings.EMBEDDINGS_NAME,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{"$set": update_fields},
)
else:
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
with db_session() as conn:
repo = SourcesRepository(conn)
existing = None
if looks_like_uuid(source_id):
existing = repo.get(source_id, user)
if existing is None:
existing = repo.get_by_legacy_id(source_id, user)
if existing is not None:
repo.update(str(existing["id"]), user, update_fields)
else:
repo.create(
job_name,
source_id=source_id if looks_like_uuid(source_id) else None,
user_id=user,
type=type,
tokens=tokens,
retriever=retriever,
remote_data=remote_data,
sync_frequency=sync_frequency,
file_path=file_path,
directory_structure=directory_structure,
file_name_map=file_name_map,
language=job_name,
model=settings.EMBEDDINGS_NAME,
date=now,
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
)
return {"status": "ok"}

View File

@@ -3,22 +3,50 @@ Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import jsonify, make_response, request
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from sqlalchemy import text as _sql_text
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
if not folder_id:
return None
if looks_like_uuid(folder_id):
row = repo.get(folder_id, user)
if row is not None:
return row
return repo.get_by_legacy_id(folder_id, user)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
def _serialize_folder(f: dict) -> dict:
created_at = f.get("created_at")
updated_at = f.get("updated_at")
return {
"id": str(f["id"]),
"name": f.get("name"),
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
}
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
@@ -28,20 +56,12 @@ class AgentFolders(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
with db_readonly() as conn:
folders = AgentFoldersRepository(conn).list_for_user(user)
result = [_serialize_folder(f) for f in folders]
return make_response(jsonify({"folders": result}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
@api.doc(description="Create a new folder")
@api.expect(
@@ -62,28 +82,38 @@ class AgentFolders(Resource):
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
parent_id_input = data.get("parent_id")
description = data.get("description")
try:
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
with db_session() as conn:
repo = AgentFoldersRepository(conn)
pg_parent_id = None
if parent_id_input:
parent = _resolve_folder_id(repo, parent_id_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
pg_parent_id = str(parent["id"])
folder = repo.create(
user, data["name"],
description=description,
parent_id=pg_parent_id,
)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
jsonify(
{
"id": str(folder["id"]),
"name": folder["name"],
"parent_id": pg_parent_id,
}
),
201,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to create folder", err)
@agents_folders_ns.route("/<string:folder_id>")
@@ -95,30 +125,55 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
with db_readonly() as conn:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_rows = conn.execute(
_sql_text(
"SELECT id, name, description FROM agents "
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
"ORDER BY created_at DESC"
),
{"user_id": user, "fid": pg_folder_id},
).fetchall()
agents_list = [
{
"id": str(row._mapping["id"]),
"name": row._mapping["name"],
"description": row._mapping.get("description", "") or "",
}
for row in agents_rows
]
subfolders = folders_repo.list_children(pg_folder_id, user)
subfolders_list = [
{"id": str(sf["id"]), "name": sf["name"]}
for sf in subfolders
]
return make_response(
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
jsonify(
{
"id": pg_folder_id,
"name": folder["name"],
"parent_id": (
str(folder["parent_id"]) if folder.get("parent_id") else None
),
"agents": agents_list,
"subfolders": subfolders_list,
}
),
200,
)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to fetch folder", err)
@api.doc(description="Update a folder")
def put(self, folder_id):
@@ -131,22 +186,60 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
update_fields: dict = {}
if "name" in data:
update_fields["name"] = data["name"]
if "description" in data:
update_fields["description"] = data["description"]
if "parent_id" in data:
parent_input = data.get("parent_id")
if parent_input:
if parent_input == folder_id or parent_input == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
parent = _resolve_folder_id(repo, parent_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
if str(parent["id"]) == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
update_fields["parent_id"] = str(parent["id"])
else:
update_fields["parent_id"] = None
if update_fields:
repo.update(pg_folder_id, user, update_fields)
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
@@ -155,18 +248,27 @@ class AgentFolder(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
# Clear folder assignments from agents; self-FK
# ``ON DELETE SET NULL`` handles child folders.
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
deleted = repo.delete(pg_folder_id, user)
if not deleted:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
@agents_folders_ns.route("/move_agent")
@@ -190,29 +292,32 @@ class MoveAgentToFolder(Resource):
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
agent_id_input = data["agent_id"]
folder_id_input = data.get("folder_id")
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
with db_session() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.get_any(agent_id_input, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}),
404,
)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
@agents_folders_ns.route("/bulk_move")
@@ -237,25 +342,25 @@ class BulkMoveAgents(Resource):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id = data.get("folder_id")
folder_id_input = data.get("folder_id")
try:
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
with db_session() as conn:
agents_repo = AgentsRepository(conn)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
for agent_id_input in agent_ids:
agent = agents_repo.get_any(agent_id_input, user)
if agent is not None:
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
except Exception as err:
return _folder_error_response("Failed to move agents", err)

File diff suppressed because it is too large Load Diff

View File

@@ -3,21 +3,17 @@
import datetime
import secrets
from bson import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.core.settings import settings
from application.api.user.base import (
agents_collection,
db,
ensure_user_doc,
resolve_tool_details,
user_tools_collection,
users_collection,
)
from application.api.user.base import resolve_tool_details
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
@@ -25,6 +21,38 @@ agents_sharing_ns = Namespace(
)
def _serialize_agent_basic(agent: dict) -> dict:
"""Shape a PG agent row into the API response dict."""
source_id = agent.get("source_id")
return {
"id": str(agent["id"]),
"user": agent.get("user_id", ""),
"name": agent.get("name", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"description": agent.get("description", ""),
"source": str(source_id) if source_id else "",
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
"retriever": agent.get("retriever", "classic") or "classic",
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
@api.doc(
@@ -41,70 +69,33 @@ class SharedAgent(Resource):
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
query = {
"shared_publicly": True,
"shared_token": shared_token,
}
shared_agent = agents_collection.find_one(query)
with db_readonly() as conn:
shared_agent = AgentsRepository(conn).find_by_shared_token(
shared_token,
)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"image": (
generate_image_url(shared_agent["image"])
if shared_agent.get("image")
else ""
),
"description": shared_agent.get("description", ""),
"source": (
str(source_doc["_id"])
if isinstance(shared_agent.get("source"), DBRef)
and (source_doc := db.dereference(shared_agent.get("source")))
else ""
),
"chunks": shared_agent.get("chunks", "0"),
"retriever": shared_agent.get("retriever", "classic"),
"prompt_id": shared_agent.get("prompt_id", "default"),
"tools": shared_agent.get("tools", []),
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"json_schema": shared_agent.get("json_schema"),
"limited_token_mode": shared_agent.get("limited_token_mode", False),
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": shared_agent.get("limited_request_mode", False),
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
agent_id = str(shared_agent["id"])
data = _serialize_agent_basic(shared_agent)
if data["tools"]:
enriched_tools = []
for tool in data["tools"]:
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
for detail in data["tool_details"]:
enriched_tools.append(detail.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
owner_id = shared_agent.get("user")
owner_id = shared_agent.get("user_id")
if user_id != owner_id:
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
with db_session() as conn:
users_repo = UsersRepository(conn)
users_repo.upsert(user_id)
users_repo.add_shared(user_id, agent_id)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -121,52 +112,73 @@ class SharedAgents(Resource):
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
user_doc = ensure_user_doc(user_id)
shared_with_ids = user_doc.get("agent_preferences", {}).get(
"shared_with_me", []
)
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
shared_agents_cursor = agents_collection.find(
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
)
shared_agents = list(shared_agents_cursor)
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
if stale_ids:
users_collection.update_one(
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
with db_session() as conn:
users_repo = UsersRepository(conn)
user_doc = users_repo.upsert(user_id)
shared_with_ids = (
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
list_shared_agents = [
{
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"pinned": str(agent["_id"]) in pinned_ids,
"shared": agent.get("shared_publicly", False),
"shared_token": agent.get("shared_token", ""),
"shared_metadata": agent.get("shared_metadata", {}),
}
for agent in shared_agents
]
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM agents "
"WHERE id = ANY(CAST(:ids AS uuid[])) "
"AND shared = true"
),
{"ids": uuid_ids},
)
shared_agents = [dict(row._mapping) for row in result.fetchall()]
else:
shared_agents = []
found_ids_set = {str(agent["id"]) for agent in shared_agents}
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
stale_ids.extend(non_uuid_ids)
if stale_ids:
users_repo.remove_shared_bulk(user_id, stale_ids)
pinned_ids = set(
user_doc.get("agent_preferences", {}).get("pinned", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
)
list_shared_agents = []
for agent in shared_agents:
agent_id_str = str(agent["id"])
list_shared_agents.append(
{
"id": agent_id_str,
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(
agent.get("tools", []) or []
),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"pinned": agent_id_str in pinned_ids,
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
)
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
@@ -220,44 +232,43 @@ class ShareAgent(Resource):
),
400,
)
shared_token = None
try:
try:
agent_oid = ObjectId(agent_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid agent ID"}), 400
)
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(datetime.timezone.utc),
}
shared_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{
"$set": {
"shared_publicly": shared,
"shared_metadata": shared_metadata,
with db_session() as conn:
repo = AgentsRepository(conn)
agent = repo.get_any(agent_id, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
shared_token = secrets.token_urlsafe(32)
repo.update(
str(agent["id"]), user,
{
"shared": True,
"shared_token": shared_token,
}
},
)
else:
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{"$set": {"shared_publicly": shared, "shared_token": None}},
{"$unset": {"shared_metadata": ""}},
)
"shared_metadata": shared_metadata,
},
)
else:
repo.update(
str(agent["id"]), user,
{
"shared": False,
"shared_token": None,
"shared_metadata": None,
},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)

View File

@@ -2,14 +2,15 @@
import secrets
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, require_agent
from application.api.user.base import require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session
agents_webhooks_ns = Namespace(
@@ -34,9 +35,8 @@ class AgentWebhook(Resource):
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": user}
)
with db_readonly() as conn:
agent = AgentsRepository(conn).get_any(agent_id, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
@@ -44,10 +44,11 @@ class AgentWebhook(Resource):
webhook_token = agent.get("incoming_webhook_token")
if not webhook_token:
webhook_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": ObjectId(agent_id), "user": user},
{"$set": {"incoming_webhook_token": webhook_token}},
)
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), user,
{"incoming_webhook_token": webhook_token},
)
base_url = settings.API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
except Exception as err:

View File

@@ -2,26 +2,84 @@
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.api.user.base import (
agents_collection,
conversations_collection,
generate_date_range,
generate_hourly_range,
generate_minute_range,
token_usage_collection,
user_logs_collection,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly
analytics_ns = Namespace(
"analytics", description="Analytics and reporting operations", path="/api"
)
_FILTER_BUCKETS = {
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
}
def _range_for_filter(filter_option: str):
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
Returns ``None`` on invalid filter.
"""
if filter_option not in _FILTER_BUCKETS:
return None
end_date = datetime.datetime.now(datetime.timezone.utc)
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
else:
days = {
"last_7_days": 6,
"last_15_days": 14,
"last_30_days": 29,
}[filter_option]
start_date = end_date - datetime.timedelta(days=days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
return start_date, end_date, bucket_unit, pg_fmt
def _intervals_for_filter(filter_option, start_date, end_date):
if filter_option == "last_hour":
return generate_minute_range(start_date, end_date)
if filter_option == "last_24_hour":
return generate_hourly_range(start_date, end_date)
return generate_date_range(start_date, end_date)
def _resolve_api_key(conn, api_key_id, user_id):
"""Look up the ``agents.key`` value for a given agent id.
Scoped by ``user_id`` so an authenticated caller can't probe another
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
"""
if not api_key_id:
return None
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
return (agent or {}).get("key") if agent else None
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
@@ -32,13 +90,7 @@ class GetMessageAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@@ -50,88 +102,54 @@ class GetMessageAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Count messages per bucket, filtered by the conversation's
# owner (user_id) and optionally the agent api_key. The
# ``user_id`` filter is always applied post-cutover to
# prevent cross-tenant leakage on admin dashboards.
clauses = [
"c.user_id = :user_id",
"m.timestamp >= :start",
"m.timestamp <= :end",
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
try:
match_stage = {
"$match": {
"user": user,
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{
"$match": {
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
{
"$group": {
"_id": {
"$dateToString": {
"format": group_format,
"date": "$queries.timestamp",
}
},
"count": {"$sum": 1},
}
},
{"$sort": {"_id": 1}},
]
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
"COUNT(*) AS count "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
rows = conn.execute(_sql_text(sql), params).fetchall()
message_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_messages = {interval: 0 for interval in intervals}
for entry in message_data:
daily_messages[entry["_id"]] = entry["count"]
for row in rows:
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
except Exception as err:
current_app.logger.error(
f"Error getting message analytics: {err}", exc_info=True
@@ -152,13 +170,7 @@ class GetTokenAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@@ -170,123 +182,36 @@ class GetTokenAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
start_date, end_date, bucket_unit, _pg_fmt = window
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
group_stage = {
"$group": {
"_id": {
"minute": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
group_stage = {
"$group": {
"_id": {
"hour": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
group_stage = {
"$group": {
"_id": {
"day": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
try:
match_stage = {
"$match": {
"user_id": user,
"timestamp": {"$gte": start_date, "$lte": end_date},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
token_usage_data = token_usage_collection.aggregate(
[
match_stage,
group_stage,
{"$sort": {"_id": 1}},
]
)
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# ``bucketed_totals`` applies user_id / api_key filters
# directly — no need to reshape a Mongo pipeline.
rows = TokenUsageRepository(conn).bucketed_totals(
bucket_unit=bucket_unit,
user_id=user,
api_key=api_key,
timestamp_gte=start_date,
timestamp_lt=end_date,
)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in token_usage_data:
if filter_option == "last_hour":
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
elif filter_option == "last_24_hour":
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
else:
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
for entry in rows:
daily_token_usage[entry["bucket"]] = int(
entry["prompt_tokens"] + entry["generated_tokens"]
)
except Exception as err:
current_app.logger.error(
f"Error getting token analytics: {err}", exc_info=True
@@ -307,13 +232,7 @@ class GetFeedbackAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@@ -325,128 +244,64 @@ class GetFeedbackAnalytics(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Feedback lives inside the ``conversation_messages.feedback``
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
# There is no scalar ``feedback_timestamp`` column — extract
# the timestamp from the JSONB and cast it to timestamptz for
# the range filter + bucket grouping.
clauses = [
"c.user_id = :user_id",
"m.feedback IS NOT NULL",
"(m.feedback->>'timestamp')::timestamptz >= :start",
"(m.feedback->>'timestamp')::timestamptz <= :end",
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char("
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
") AS bucket, "
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
try:
match_stage = {
"$match": {
"queries.feedback_timestamp": {
"$gte": start_date,
"$lte": end_date,
},
"queries.feedback": {"$exists": True},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{"$match": {"queries.feedback": {"$exists": True}}},
{
"$group": {
"_id": {"time": date_field, "feedback": "$queries.feedback"},
"count": {"$sum": 1},
}
},
{
"$group": {
"_id": "$_id.time",
"positive": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "LIKE"]},
"$count",
0,
]
}
},
"negative": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "DISLIKE"]},
"$count",
0,
]
}
},
}
},
{"$sort": {"_id": 1}},
]
rows = conn.execute(_sql_text(sql), params).fetchall()
feedback_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_feedback = {
interval: {"positive": 0, "negative": 0} for interval in intervals
}
for entry in feedback_data:
daily_feedback[entry["_id"]] = {
"positive": entry["positive"],
"negative": entry["negative"],
for row in rows:
bucket = row._mapping["bucket"]
daily_feedback[bucket] = {
"positive": int(row._mapping["positive"] or 0),
"negative": int(row._mapping["negative"] or 0),
}
except Exception as err:
current_app.logger.error(
@@ -484,47 +339,89 @@ class GetUserLogs(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
page_size = int(data.get("page_size", 10))
skip = (page - 1) * page_size
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
if api_key_id
else None
)
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
logs_repo = UserLogsRepository(conn)
if api_key:
# ``find_by_api_key`` filters on ``data->>'api_key'``
# — the PG shape of the legacy top-level ``api_key``
# filter. Paginate client-side using offset/limit.
all_rows = logs_repo.find_by_api_key(api_key)
offset = (page - 1) * page_size
window = all_rows[offset: offset + page_size + 1]
items = window
else:
items, has_more_flag = logs_repo.list_paginated(
user_id=user,
page=page,
page_size=page_size,
)
# list_paginated already trims to page_size and
# returns has_more separately.
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more_flag,
}
),
200,
)
has_more = len(items) > page_size
items = items[:page_size]
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
current_app.logger.error(
f"Error getting user logs: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
query = {"user": user}
if api_key:
query = {"api_key": api_key}
items_cursor = (
user_logs_collection.find(query)
.sort("timestamp", -1)
.skip(skip)
.limit(page_size + 1)
)
items = list(items_cursor)
results = [
{
"id": str(item.get("_id")),
"action": item.get("action"),
"level": item.get("level"),
"user": item.get("user"),
"question": item.get("question"),
"sources": item.get("sources"),
"retriever_params": item.get("retriever_params"),
"timestamp": item.get("timestamp"),
}
for item in items[:page_size]
]
has_more = len(items) > page_size
return make_response(
jsonify(

View File

@@ -1,15 +1,39 @@
"""File attachments and media routes."""
import os
import tempfile
from pathlib import Path
import uuid
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.cache import get_redis_instance
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.stt.live_session import (
apply_live_stt_hypothesis,
create_live_stt_session,
delete_live_stt_session,
finalize_live_stt_session,
get_live_stt_transcript_text,
load_live_stt_session,
save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
@@ -19,6 +43,73 @@ attachments_ns = Namespace(
)
def _resolve_authenticated_user():
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
if decoded_token:
return safe_filename(decoded_token.get("sub"))
if api_key:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user_id"))
return None
def _get_uploaded_file_size(file) -> int:
try:
current_position = file.stream.tell()
file.stream.seek(0, os.SEEK_END)
size_bytes = file.stream.tell()
file.stream.seek(current_position)
return size_bytes
except Exception:
return 0
def _is_supported_audio_mimetype(mimetype: str) -> bool:
if not mimetype:
return True
normalized = mimetype.split(";")[0].strip().lower()
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
if not is_audio_filename(filename):
return
size_bytes = _get_uploaded_file_size(file)
if size_bytes:
enforce_audio_file_size_limit(size_bytes)
def _get_store_attachment_user_error(exc: Exception) -> str:
if isinstance(exc, AudioFileTooLargeError):
return build_stt_file_size_limit_message()
return "Failed to process file"
def _require_live_stt_redis():
redis_client = get_redis_instance()
if redis_client:
return redis_client
return make_response(
jsonify({"success": False, "message": "Live transcription is unavailable"}),
503,
)
def _parse_bool_form_value(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
@@ -36,8 +127,9 @@ class StoreAttachment(Resource):
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
files = request.files.getlist("file")
if not files:
@@ -51,30 +143,25 @@ class StoreAttachment(Resource):
400,
)
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
user = auth_user
if not user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
from application.api.user.tasks import store_attachment
from application.api.user.base import storage
tasks = []
errors = []
original_file_count = len(files)
for idx, file in enumerate(files):
try:
attachment_id = ObjectId()
attachment_id = uuid.uuid4()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
@@ -90,15 +177,31 @@ class StoreAttachment(Resource):
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
"upload_index": idx,
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"upload_index": idx,
"filename": file.filename,
"error": str(file_err)
"error": _get_store_attachment_user_error(file_err),
})
if not tasks:
if errors and all(
error.get("error") == build_stt_file_size_limit_message()
for error in errors
):
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
"errors": errors,
}
),
413,
)
return make_response(
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
@@ -135,11 +238,389 @@ class StoreAttachment(Resource):
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/stt")
class SpeechToText(Resource):
@api.expect(
api.model(
"SpeechToTextModel",
{
"file": fields.Raw(required=True, description="Audio file"),
"language": fields.String(
required=False, description="Optional transcription language hint"
),
},
)
)
@api.doc(description="Transcribe an uploaded audio file")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=request.form.get("language") or settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
return make_response(jsonify({"success": True, **transcript}), 200)
except Exception as err:
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
@api.doc(description="Start a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_state = create_live_stt_session(
user=auth_user,
language=payload.get("language") or settings.STT_LANGUAGE,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_state["session_id"],
"language": session_state.get("language"),
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"finalized_text": "",
"pending_text": "",
"transcript_text": "",
}
),
200,
)
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
@api.expect(
api.model(
"LiveSpeechToTextChunkModel",
{
"session_id": fields.String(
required=True, description="Live transcription session ID"
),
"chunk_index": fields.Integer(
required=True, description="Sequential chunk index"
),
"is_silence": fields.Boolean(
required=False,
description="Whether the latest capture window was mostly silence",
),
"file": fields.Raw(required=True, description="Audio chunk"),
},
)
)
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
session_id = request.form.get("session_id", "").strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
chunk_index_raw = request.form.get("chunk_index", "").strip()
if chunk_index_raw == "":
return make_response(
jsonify({"success": False, "message": "Missing chunk_index"}),
400,
)
try:
chunk_index = int(chunk_index_raw)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid chunk_index"}),
400,
)
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
session_language = session_state.get("language") or settings.STT_LANGUAGE
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=session_language,
timestamps=False,
diarize=False,
)
if not session_state.get("language") and transcript.get("language"):
session_state["language"] = transcript["language"]
try:
apply_live_stt_hypothesis(
session_state,
str(transcript.get("text", "")),
chunk_index,
is_silence=is_silence,
)
except ValueError:
current_app.logger.warning(
"Invalid live transcription chunk",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"message": "Invalid live transcription chunk",
}
),
409,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"chunk_index": chunk_index,
"chunk_text": transcript.get("text", ""),
"is_silence": is_silence,
"language": session_state.get("language"),
"committed_text": session_state.get("committed_text", ""),
"mutable_text": session_state.get("mutable_text", ""),
"previous_hypothesis": session_state.get(
"previous_hypothesis", ""
),
"latest_hypothesis": session_state.get(
"latest_hypothesis", ""
),
"finalized_text": session_state.get("committed_text", ""),
"pending_text": session_state.get("mutable_text", ""),
"transcript_text": get_live_stt_transcript_text(session_state),
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error transcribing live audio chunk: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
@api.doc(description="Finish a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
final_text = finalize_live_stt_session(session_state)
delete_live_stt_session(redis_client, session_id)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"language": session_state.get("language"),
"text": final_text,
}
),
200,
)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
@@ -154,6 +635,10 @@ class ServeImage(Resource):
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(

View File

@@ -8,13 +8,15 @@ import uuid
from functools import wraps
from typing import Optional, Tuple
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, Response
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.core.mongo_db import MongoDB
from sqlalchemy import text as _sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
@@ -22,56 +24,6 @@ from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
agents_collection.create_index(
[("shared", 1)],
name="shared_index",
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
@@ -103,58 +55,95 @@ def generate_date_range(start_date, end_date):
def ensure_user_doc(user_id):
"""
Ensure user document exists with proper agent preferences structure.
Ensure a Postgres ``users`` row exists for ``user_id``.
Returns the row as a dict with the shape legacy callers expect — in
particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
``shared_with_me`` list keys always present).
Args:
user_id: The user ID to ensure
Returns:
The user document
The user document as a dict.
"""
default_prefs = {
"pinned": [],
"shared_with_me": [],
}
with db_session() as conn:
user_doc = UsersRepository(conn).upsert(user_id)
user_doc = users_collection.find_one_and_update(
{"user_id": user_id},
{"$setOnInsert": {"agent_preferences": default_prefs}},
upsert=True,
return_document=ReturnDocument.AFTER,
)
prefs = user_doc.get("agent_preferences", {})
updates = {}
if "pinned" not in prefs:
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
prefs = user_doc.get("agent_preferences") or {}
if not isinstance(prefs, dict):
prefs = {}
prefs.setdefault("pinned", [])
prefs.setdefault("shared_with_me", [])
user_doc["agent_preferences"] = prefs
return user_doc
def resolve_tool_details(tool_ids):
"""
Resolve tool IDs to their details.
Resolve tool IDs to their display details.
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
lists are supported — each id is looked up via ``get_any``, which
resolves to whichever column matches). Unknown ids are silently
skipped.
Args:
tool_ids: List of tool IDs
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
Returns:
List of tool details with id, name, and display_name
List of tool details with ``id``, ``name``, and ``display_name``.
"""
tools = user_tools_collection.find(
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
if not tool_ids:
return []
uuid_ids: list[str] = []
legacy_ids: list[str] = []
for tid in tool_ids:
if not tid:
continue
tid_str = str(tid)
if looks_like_uuid(tid_str):
uuid_ids.append(tid_str)
else:
legacy_ids.append(tid_str)
if not uuid_ids and not legacy_ids:
return []
rows: list[dict] = []
with db_readonly() as conn:
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE id = ANY(CAST(:ids AS uuid[]))"
),
{"ids": uuid_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
if legacy_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE legacy_mongo_id = ANY(:ids)"
),
{"ids": legacy_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
"name": tool.get("name", "") or "",
"display_name": (
tool.get("custom_name")
or tool.get("display_name")
or tool.get("name", "")
or ""
),
}
for tool in tools
for tool in rows
]
@@ -224,14 +213,15 @@ def require_agent(func):
@wraps(func)
def wrapper(*args, **kwargs):
from application.storage.db.repositories.agents import AgentsRepository
webhook_token = kwargs.get("webhook_token")
if not webhook_token:
return make_response(
jsonify({"success": False, "message": "Webhook token missing"}), 400
)
agent = agents_collection.find_one(
{"incoming_webhook_token": webhook_token}, {"_id": 1}
)
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
if not agent:
current_app.logger.warning(
f"Webhook attempt with invalid token: {webhook_token}"
@@ -240,7 +230,7 @@ def require_agent(func):
jsonify({"success": False, "message": "Agent not found"}), 404
)
kwargs["agent"] = agent
kwargs["agent_id_str"] = str(agent["_id"])
kwargs["agent_id_str"] = str(agent["id"])
return func(*args, **kwargs)
return wrapper

View File

@@ -2,12 +2,13 @@
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import attachments_collection, conversations_collection
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
conversations_ns = Namespace(
@@ -30,10 +31,13 @@ class DeleteConversation(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token["sub"]
try:
conversations_collection.delete_one(
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is not None:
repo.delete(str(conv["id"]), user_id)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
@@ -53,7 +57,8 @@ class DeleteAllConversations(Resource):
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
conversations_collection.delete_many({"user": user_id})
with db_session() as conn:
ConversationsRepository(conn).delete_all_for_user(user_id)
except Exception as err:
current_app.logger.error(
f"Error deleting all conversations: {err}", exc_info=True
@@ -71,26 +76,21 @@ class GetConversations(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
conversations = (
conversations_collection.find(
{
"$or": [
{"api_key": {"$exists": False}},
{"agent_id": {"$exists": True}},
],
"user": decoded_token.get("sub"),
}
with db_readonly() as conn:
conversations = ConversationsRepository(conn).list_for_user(
user_id, limit=30
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{
"id": str(conversation["_id"]),
"id": str(conversation["id"]),
"name": conversation["name"],
"agent_id": conversation.get("agent_id", None),
"agent_id": (
str(conversation["agent_id"])
if conversation.get("agent_id")
else None
),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
@@ -119,38 +119,67 @@ class GetSingleConversation(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token.get("sub")
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conversation = repo.get_any(conversation_id, user_id)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
conv_pg_id = str(conversation["id"])
messages = repo.get_messages(conv_pg_id)
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
# Resolve attachment details (id, fileName) for each message.
attachments_repo = AttachmentsRepository(conn)
queries = []
for msg in messages:
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"timestamp": msg.get("timestamp"),
"model_id": msg.get("model_id"),
}
if msg.get("metadata"):
query["metadata"] = msg["metadata"]
# Feedback on conversation_messages is a JSONB blob with
# shape {"text": <str>, "timestamp": <iso>}. The legacy
# frontend consumed a flat scalar feedback string, so
# unwrap the ``text`` field for compat.
feedback = msg.get("feedback")
if feedback is not None:
if isinstance(feedback, dict):
query["feedback"] = feedback.get("text")
if feedback.get("timestamp"):
query["feedback_timestamp"] = feedback["timestamp"]
else:
query["feedback"] = feedback
attachments = msg.get("attachments") or []
if attachments:
attachment_details = []
for attachment_id in attachments:
try:
att = attachments_repo.get_any(
str(attachment_id), user_id
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
if att:
attachment_details.append(
{
"id": str(att["id"]),
"fileName": att.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
queries.append(query)
except Exception as err:
current_app.logger.error(
f"Error retrieving conversation: {err}", exc_info=True
@@ -158,7 +187,9 @@ class GetSingleConversation(Resource):
return make_response(jsonify({"success": False}), 400)
data = {
"queries": queries,
"agent_id": conversation.get("agent_id"),
"agent_id": (
str(conversation["agent_id"]) if conversation.get("agent_id") else None
),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
@@ -190,11 +221,13 @@ class UpdateConversationName(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
try:
conversations_collection.update_one(
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
{"$set": {"name": data["name"]}},
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["id"], user_id)
if conv is not None:
repo.rename(str(conv["id"]), user_id, data["name"])
except Exception as err:
current_app.logger.error(
f"Error updating conversation name: {err}", exc_info=True
@@ -237,43 +270,33 @@ class SubmitFeedback(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
feedback_value = data["feedback"]
question_index = int(data["question_index"])
# Normalize string feedback to lowercase so analytics queries
# (which match 'like'/'dislike') count rows correctly. Tolerate
# legacy uppercase clients on ingest. Non-string values pass through.
if isinstance(feedback_value, str):
feedback_value = feedback_value.lower()
feedback_payload = (
None
if feedback_value is None
else {
"text": feedback_value,
"timestamp": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
)
try:
if data["feedback"] is None:
# Remove feedback and feedback_timestamp if feedback is null
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
else:
# Set feedback and feedback_timestamp if feedback has a value
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$set": {
f"queries.{data['question_index']}.feedback": data[
"feedback"
],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
datetime.timezone.utc
),
}
},
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["conversation_id"], user_id)
if conv is None:
return make_response(
jsonify({"success": False, "message": "Not found"}), 404
)
repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -2,12 +2,13 @@
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir, prompts_collection
from application.api.user.base import current_dir
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
prompts_ns = Namespace(
@@ -40,15 +41,9 @@ class CreatePrompt(Resource):
return missing_fields
user = decoded_token.get("sub")
try:
resp = prompts_collection.insert_one(
{
"name": data["name"],
"content": data["content"],
"user": user,
}
)
new_id = str(resp.inserted_id)
with db_session() as conn:
prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
new_id = str(prompt["id"])
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -64,17 +59,17 @@ class GetPrompts(Resource):
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
prompts = prompts_collection.find({"user": user})
with db_readonly() as conn:
prompts = PromptsRepository(conn).list_for_user(user)
list_prompts = [
{"id": "default", "name": "default", "type": "public"},
{"id": "creative", "name": "creative", "type": "public"},
{"id": "strict", "name": "strict", "type": "public"},
]
for prompt in prompts:
list_prompts.append(
{
"id": str(prompt["_id"]),
"id": str(prompt["id"]),
"name": prompt["name"],
"type": "private",
}
@@ -119,9 +114,12 @@ class GetSinglePrompt(Resource):
) as f:
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
prompt = prompts_collection.find_one(
{"_id": ObjectId(prompt_id), "user": user}
)
with db_readonly() as conn:
prompt = PromptsRepository(conn).get_any(prompt_id, user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}), 404
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -148,7 +146,15 @@ class DeletePrompt(Resource):
if missing_fields:
return missing_fields
try:
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.delete(str(prompt["id"]), user)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -181,10 +187,15 @@ class UpdatePrompt(Resource):
if missing_fields:
return missing_fields
try:
prompts_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.update(str(prompt["id"]), user, data["name"], data["content"])
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -2,26 +2,126 @@
import uuid
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, inputs, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.api.user.base import (
agents_collection,
attachments_collection,
conversations_collection,
shared_conversations_collections,
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import (
SharedConversationsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)
def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.
Scoped by ``user_id`` so a caller can't link another user's prompt
into their share record. Returns ``None`` for sentinel values
(``"default"``) or unresolved ids.
"""
if not prompt_id_raw or prompt_id_raw == "default":
return None
value = str(prompt_id_raw)
# Already UUID — trust it but still require ownership. A shape-gate
# (rather than a loose ``len == 36 and '-' in value`` check) keeps
# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
# and poison the readonly transaction otherwise.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None
# Legacy Mongo ObjectId fallback.
row = conn.execute(
_sql_text(
"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
"AND user_id = :uid"
),
{"pid": value, "uid": user_id},
).fetchone()
return str(row[0]) if row else None
def _resolve_source_pg_id(conn, source_raw):
"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
if not source_raw:
return None
value = str(source_raw)
# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
if looks_like_uuid(value):
row = conn.execute(
_sql_text(
"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
),
{"sid": value},
).fetchone()
return str(row[0]) if row else None
row = conn.execute(
_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
{"sid": value},
).fetchone()
return str(row[0]) if row else None
def _find_reusable_share_agent(
conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
):
"""Find an existing share-as-agent key row matching these parameters.
Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
check. Used to reuse an api key across repeated shares of the same
conversation with the same prompt/chunks/source/retriever.
"""
clauses = ["user_id = :uid", "key IS NOT NULL"]
params: dict = {"uid": user_id}
if prompt_pg_id is None:
clauses.append("prompt_id IS NULL")
else:
clauses.append("prompt_id = CAST(:pid AS uuid)")
params["pid"] = prompt_pg_id
if chunks is None:
clauses.append("chunks IS NULL")
else:
clauses.append("chunks = :chunks")
params["chunks"] = int(chunks)
if source_pg_id is None:
clauses.append("source_id IS NULL")
else:
clauses.append("source_id = CAST(:sid AS uuid)")
params["sid"] = source_pg_id
if retriever is None:
clauses.append("retriever IS NULL")
else:
clauses.append("retriever = :retr")
params["retr"] = retriever
sql = (
"SELECT * FROM agents WHERE "
+ " AND ".join(clauses)
+ " LIMIT 1"
)
row = conn.execute(_sql_text(sql), params).fetchone()
if row is None:
return None
mapping = dict(row._mapping)
mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
return mapping
@sharing_ns.route("/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
@@ -56,146 +156,94 @@ class ShareConversation(Resource):
conversation_id = data["conversation_id"]
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
)
if conversation is None:
return make_response(
jsonify(
{
"status": "error",
"message": "Conversation does not exist",
}
),
404,
)
current_n_queries = len(conversation["queries"])
explicit_binary = Binary.from_uuid(
uuid.uuid4(), UuidRepresentation.STANDARD
)
with db_session() as conn:
conv_repo = ConversationsRepository(conn)
shared_repo = SharedConversationsRepository(conn)
agents_repo = AgentsRepository(conn)
if is_promptable:
prompt_id = data.get("prompt_id", "default")
chunks = data.get("chunks", "2")
name = conversation["name"] + "(shared)"
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
if pre_existing_api_document:
api_uuid = pre_existing_api_document["key"]
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
conversation = conv_repo.get_any(conversation_id, user)
if conversation is None:
return make_response(
jsonify(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
"status": "error",
"message": "Conversation does not exist",
}
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
else:
api_uuid = str(uuid.uuid4())
new_api_key_data["key"] = api_uuid
new_api_key_data["name"] = name
),
404,
)
conv_pg_id = str(conversation["id"])
current_n_queries = conv_repo.message_count(conv_pg_id)
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
if is_promptable:
prompt_id_raw = data.get("prompt_id", "default")
chunks_raw = data.get("chunks", "2")
try:
chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
except (TypeError, ValueError):
chunks_int = None
prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
retriever = data.get("retriever")
reusable = _find_reusable_share_agent(
conn, user,
prompt_pg_id=prompt_pg_id,
chunks=chunks_int,
source_pg_id=source_pg_id,
retriever=retriever,
)
if reusable:
api_uuid = reusable.get("key")
else:
api_uuid = str(uuid.uuid4())
name = (conversation.get("name") or "") + "(shared)"
agents_repo.create(
user,
name,
"published",
key=api_uuid,
retriever=retriever,
chunks=chunks_int,
prompt_id=prompt_pg_id,
source_id=source_pg_id,
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
agents_collection.insert_one(new_api_key_data)
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
share = shared_repo.get_or_create(
conv_pg_id,
user,
is_promptable=True,
first_n_queries=current_n_queries,
api_key=api_uuid,
prompt_id=prompt_pg_id,
chunks=chunks_int,
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
"identifier": str(share["uuid"]),
}
),
201,
201 if reusable is None else 200,
)
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
if pre_existing is not None:
# Non-promptable share path.
share = shared_repo.get_or_create(
conv_pg_id,
user,
is_promptable=False,
first_n_queries=current_n_queries,
api_key=None,
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
"identifier": str(share["uuid"]),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
return make_response(
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
except Exception as err:
@@ -210,37 +258,13 @@ class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
try:
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
with db_readonly() as conn:
shared_repo = SharedConversationsRepository(conn)
conv_repo = ConversationsRepository(conn)
attach_repo = AttachmentsRepository(conn)
if (
shared
and "conversation_id" in shared
):
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)
if conversation is None:
shared = shared_repo.find_by_uuid(identifier)
if not shared or not shared.get("conversation_id"):
return make_response(
jsonify(
{
@@ -250,22 +274,60 @@ class GetPubliclySharedConversations(Resource):
),
404,
)
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
conv_pg_id = str(shared["conversation_id"])
owner_user = shared.get("user_id")
for query in conversation_queries:
if "attachments" in query and query["attachments"]:
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
if conversation is None:
# Fall back to any-user lookup in case shared row's
# user_id is missing — still keyed by PG UUID.
row = conn.execute(
_sql_text(
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
),
{"id": conv_pg_id},
).fetchone()
if row is None:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
conversation = dict(row._mapping)
messages = conv_repo.get_messages(conv_pg_id)
first_n = shared.get("first_n_queries") or 0
conversation_queries = []
for msg in messages[:first_n]:
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"timestamp": (
msg["timestamp"].isoformat()
if hasattr(msg.get("timestamp"), "isoformat")
else msg.get("timestamp")
),
"feedback": msg.get("feedback"),
}
attachments = msg.get("attachments") or []
if attachments:
attachment_details = []
for attachment_id in query["attachments"]:
for attachment_id in attachments:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
attachment = attach_repo.get_any(
str(attachment_id), owner_user,
) if owner_user else None
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"id": str(attachment["id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
@@ -277,26 +339,23 @@ class GetPubliclySharedConversations(Resource):
exc_info=True,
)
query["attachments"] = attachment_details
else:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
conversation_queries.append(query)
created = conversation.get("created_at") or conversation.get("date")
date_iso = (
created.isoformat()
if hasattr(created, "isoformat")
else (str(created) if created is not None else None)
)
date = conversation["_id"].generation_time.isoformat()
res = {
"success": True,
"queries": conversation_queries,
"title": conversation["name"],
"timestamp": date,
}
if shared["isPromptable"] and "api_key" in shared:
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
res = {
"success": True,
"queries": conversation_queries,
"title": conversation.get("name"),
"timestamp": date_iso,
}
if shared.get("is_promptable") and shared.get("api_key"):
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
except Exception as err:
current_app.logger.error(
f"Error getting shared conversation: {err}", exc_info=True

View File

@@ -1,11 +1,12 @@
"""Source document management chunk management."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import get_vector_store, sources_collection
from application.api.user.base import get_vector_store
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly
from application.utils import check_required_fields, num_tokens_from_string
sources_chunks_ns = Namespace(
@@ -13,6 +14,15 @@ sources_chunks_ns = Namespace(
)
def _resolve_source(doc_id: str, user: str):
"""Resolve a source (UUID or legacy ObjectId) for the caller.
Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
"""
with db_readonly() as conn:
return SourcesRepository(conn).get_any(doc_id, user)
@sources_chunks_ns.route("/get_chunks")
class GetChunks(Resource):
@api.doc(
@@ -36,36 +46,34 @@ class GetChunks(Resource):
path = request.args.get("path")
search_term = request.args.get("search", "").strip().lower()
if not ObjectId.is_valid(doc_id):
if not doc_id:
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
resolved_id = str(doc["id"])
try:
store = get_vector_store(doc_id)
store = get_vector_store(resolved_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
# Filter by path if provided
if path:
chunk_source = metadata.get("source", "")
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
if not (source_match or file_path_match):
continue
# Filter by search term if provided
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
@@ -132,15 +140,17 @@ class AddChunk(Resource):
token_count = num_tokens_from_string(text)
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
store = get_vector_store(str(doc["id"]))
chunk_id = store.add_chunk(text, metadata)
return make_response(
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
@@ -165,15 +175,17 @@ class DeleteChunk(Resource):
doc_id = request.args.get("id")
chunk_id = request.args.get("chunk_id")
if not ObjectId.is_valid(doc_id):
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
store = get_vector_store(str(doc["id"]))
deleted = store.delete_chunk(chunk_id)
if deleted:
return make_response(
@@ -232,15 +244,17 @@ class UpdateChunk(Resource):
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
try:
doc = _resolve_source(doc_id, user)
except Exception as e:
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
store = get_vector_store(str(doc["id"]))
chunks = store.get_chunks()
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)

View File

@@ -3,14 +3,14 @@
import json
import math
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
@@ -56,11 +56,20 @@ class CombinedJson(Resource):
]
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
with db_readonly() as conn:
indexes = SourcesRepository(conn).list_for_user(user)
# list_for_user sorts by created_at DESC; legacy shape sorted by
# "date" DESC. Both are monotonic on creation so the ordering is
# equivalent for dev; re-sort defensively.
indexes = sorted(
indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
reverse=True,
)
for index in indexes:
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["_id"]),
"id": str(index["id"]),
"name": index.get("name"),
"date": index.get("date"),
"model": settings.EMBEDDINGS_NAME,
@@ -70,9 +79,7 @@ class CombinedJson(Resource):
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
), # Add type field with default "file"
"type": index.get("type", "file"),
}
)
except Exception as err:
@@ -89,61 +96,55 @@ class PaginatedSources(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# add .strip() to remove leading and trailing whitespaces
search_term = request.args.get(
"search", ""
).strip() # add search for filter documents
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i", # using case-insensitive search
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(
max(1, page), total_pages
) # add this to make sure page inbound is within the range
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
sort_field = request.args.get("sort", "date")
sort_order = request.args.get("order", "desc")
page = max(1, int(request.args.get("page", 1)))
rows_per_page = max(1, int(request.args.get("rows", 10)))
search_term = request.args.get("search", "").strip() or None
try:
documents = (
sources_collection.find(query)
.sort(sort_field, sort_order)
.skip(skip)
.limit(rows_per_page)
)
with db_readonly() as conn:
repo = SourcesRepository(conn)
total_documents = repo.count_for_user(
user, search_term=search_term,
)
# Prior in-Python implementation returned ``totalPages = 1``
# for empty result sets (``max(1, ceil(0/rows))``); we
# preserve that contract so the frontend pager stays stable.
total_pages = max(1, math.ceil(total_documents / rows_per_page))
effective_page = min(page, total_pages)
offset = (effective_page - 1) * rows_per_page
window = repo.list_for_user(
user,
limit=rows_per_page,
offset=offset,
search_term=search_term,
sort_field=sort_field,
sort_order=sort_order,
)
paginated_docs = []
for doc in documents:
for doc in window:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
paginated_docs.append(
{
"id": str(doc["id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
)
response = {
"total": total_documents,
"totalPages": total_pages,
"currentPage": page,
"currentPage": effective_page,
"paginated": paginated_docs,
}
return make_response(jsonify(response), 200)
@@ -154,28 +155,6 @@ class PaginatedSources(Resource):
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
params={"path": "Comma-separated list of IDs"},
)
def get(self):
ids = request.args.get("path")
if not ids:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
result = sources_collection.delete_index(ids=ids)
if result:
return make_response(jsonify({"success": True}), 200)
except Exception as err:
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
@@ -186,30 +165,33 @@ class DeleteOldIndexes(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
resolved_id = str(doc["id"])
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{str(doc['_id'])}"
index_path = f"indexes/{resolved_id}"
if storage.file_exists(f"{index_path}/index.faiss"):
storage.delete_file(f"{index_path}/index.faiss")
if storage.file_exists(f"{index_path}/index.pkl"):
storage.delete_file(f"{index_path}/index.pkl")
else:
vectorstore = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id=str(doc["_id"])
settings.VECTOR_STORE, source_id=resolved_id
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
@@ -227,7 +209,14 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
try:
with db_session() as conn:
SourcesRepository(conn).delete(resolved_id, user)
except Exception as err:
current_app.logger.error(
f"Error deleting source row: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@@ -272,15 +261,16 @@ class ManageSync(Resource):
return make_response(
jsonify({"success": False, "message": "Invalid frequency"}), 400
)
update_data = {"$set": {"sync_frequency": sync_frequency}}
try:
sources_collection.update_one(
{
"_id": ObjectId(source_id),
"user": user,
},
update_data,
)
with db_session() as conn:
repo = SourcesRepository(conn)
doc = repo.get_any(source_id, user)
if doc is None:
return make_response(
jsonify({"success": False, "message": "Source not found"}),
404,
)
repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
except Exception as err:
current_app.logger.error(
f"Error updating sync frequency: {err}", exc_info=True
@@ -309,19 +299,20 @@ class SyncSource(Resource):
if missing_fields:
return missing_fields
source_id = data["source_id"]
if not ObjectId.is_valid(source_id):
try:
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(source_id, user)
except Exception as err:
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type.startswith("connector"):
if source_type and source_type.startswith("connector"):
return make_response(
jsonify(
{
@@ -344,7 +335,7 @@ class SyncSource(Resource):
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=source_id,
doc_id=str(doc["id"]),
)
except Exception as err:
current_app.logger.error(
@@ -370,10 +361,9 @@ class DirectoryStructure(Resource):
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
with db_readonly() as conn:
doc = SourcesRepository(conn).get_any(doc_id, user)
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
@@ -387,6 +377,8 @@ class DirectoryStructure(Resource):
if isinstance(remote_data, str) and remote_data:
remote_data_obj = json.loads(remote_data)
provider = remote_data_obj.get("provider")
elif isinstance(remote_data, dict):
provider = remote_data.get("provider")
except Exception as e:
current_app.logger.warning(
f"Failed to parse remote_data for doc {doc_id}: {e}"
@@ -406,4 +398,7 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
return make_response(
jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
500,
)

View File

@@ -5,16 +5,23 @@ import os
import tempfile
import zipfile
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.utils import check_required_fields, safe_filename
@@ -23,6 +30,12 @@ sources_upload_ns = Namespace(
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
enforce_audio_file_size_limit(os.path.getsize(file_path))
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
@@ -78,6 +91,7 @@ class UploadFile(Resource):
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
_enforce_audio_path_size_limit(temp_file_path, safe_file)
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
@@ -102,6 +116,10 @@ class UploadFile(Resource):
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
_enforce_audio_path_size_limit(
os.path.join(root, extracted_file),
extracted_file,
)
with open(
os.path.join(root, extracted_file), "rb"
@@ -124,29 +142,23 @@ class UploadFile(Resource):
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -317,15 +329,8 @@ class ManageSourceFiles(Resource):
400,
)
try:
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
with db_readonly() as conn:
source = SourcesRepository(conn).get_any(source_id, user)
if not source:
return make_response(
jsonify(
@@ -341,6 +346,7 @@ class ManageSourceFiles(Resource):
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
resolved_source_id = str(source["id"])
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -399,15 +405,18 @@ class ManageSourceFiles(Resource):
map_updated = True
if map_updated:
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
return make_response(
jsonify(
@@ -451,6 +460,16 @@ class ManageSourceFiles(Resource):
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid file path",
}
),
400,
)
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
@@ -463,15 +482,18 @@ class ManageSourceFiles(Resource):
map_updated = True
if map_updated and isinstance(file_name_map, dict):
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
return make_response(
jsonify(
@@ -559,16 +581,19 @@ class ManageSourceFiles(Resource):
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
with db_session() as conn:
SourcesRepository(conn).update(
resolved_source_id, user,
{"file_name_map": dict(file_name_map)},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
task = reingest_source_task.delay(
source_id=resolved_source_id, user=user
)
return make_response(
jsonify(

View File

@@ -134,6 +134,17 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
# Replaces Mongo's TTL index on pending_tool_state.expires_at.
sender.add_periodic_task(
timedelta(seconds=60),
cleanup_pending_tool_state.s(),
name="cleanup-pending-tool-state",
)
sender.add_periodic_task(
timedelta(hours=7),
version_check_task.s(),
name="version-check",
)
@celery.task(bind=True)
@@ -146,3 +157,40 @@ def mcp_oauth_task(self, config, user):
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
@celery.task(bind=True)
def cleanup_pending_tool_state(self):
"""Delete pending_tool_state rows past their TTL.
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
no native TTL, so this task runs every 60 seconds to keep
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
configured (keeps the task runnable in Mongo-only environments).
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
engine = get_engine()
with engine.begin() as conn:
deleted = PendingToolStateRepository(conn).cleanup_expired()
return {"deleted": deleted}
@celery.task(bind=True)
def version_check_task(self):
"""Periodic anonymous version check.
Complements the ``worker_ready`` boot trigger so long-running
deployments (>6h cache TTL) still refresh advisories. ``run_check``
is fail-silent and coordinates across replicas via Redis lock +
cache (see ``application.updates.version_check``).
"""
from application.updates.version_check import run_check
run_check()

View File

@@ -1,21 +1,80 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import unquote, urlencode
from urllib.parse import urlencode, urlparse
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from flask_restx import Namespace, Resource, fields
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.security.encryption import encrypt_credentials
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
def _sanitize_mcp_transport(config):
"""Normalise and validate the transport_type field.
Strips ``command`` / ``args`` keys that are only valid for local STDIO
transports and returns the cleaned transport type string.
"""
transport_type = (config.get("transport_type") or "auto").lower()
if transport_type not in _ALLOWED_TRANSPORTS:
raise ValueError(f"Unsupported transport_type: {transport_type}")
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
return transport_type
def _extract_auth_credentials(config):
"""Build an ``auth_credentials`` dict from the raw MCP config."""
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if config.get("api_key"):
auth_credentials["api_key"] = config["api_key"]
if config.get("api_key_header"):
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if config.get("bearer_token"):
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if config.get("username"):
auth_credentials["username"] = config["username"]
if config.get("password"):
auth_credentials["password"] = config["password"]
return auth_credentials
def _validate_mcp_server_url(config: dict) -> None:
"""Validate the server_url in an MCP config to prevent SSRF.
Raises:
ValueError: If the URL is missing or points to a blocked address.
"""
server_url = (config.get("server_url") or "").strip()
if not server_url:
raise ValueError("server_url is required")
try:
validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid server URL: {exc}") from exc
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@@ -43,49 +102,63 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
auth_type = config.get("auth_type", "none")
_validate_mcp_server_url(config)
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
# Sanitize the response to avoid exposing internal error details
if not result.get("success") and "message" in result:
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
result["message"] = "Connection test failed"
if result.get("requires_oauth"):
safe_result = {
k: v
for k, v in result.items()
if k in ("success", "requires_oauth", "auth_url")
}
return make_response(jsonify(safe_result), 200)
return make_response(jsonify(result), 200)
if not result.get("success"):
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
return make_response(
jsonify(
{
"success": False,
"message": "Connection test failed",
"tools_count": 0,
}
),
200,
)
safe_result = {
"success": True,
"message": result.get("message", "Connection successful"),
"tools_count": result.get("tools_count", 0),
"tools": result.get("tools", []),
}
return make_response(jsonify(safe_result), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server test request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": "Connection test failed"}
),
jsonify({"success": False, "error": "Connection test failed"}),
500,
)
@@ -125,32 +198,18 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = {}
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
@@ -188,79 +247,135 @@ class MCPServerSave(Resource):
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
tool_id = data.get("id")
existing_doc = None
existing_encrypted = None
if tool_id:
with db_readonly() as conn:
repo = UserToolsRepository(conn)
existing_doc = repo.get_any(tool_id, user)
if existing_doc and existing_doc.get("name") == "mcp_tool":
existing_encrypted = (existing_doc.get("config") or {}).get(
"encrypted_credentials"
)
else:
existing_doc = None
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
if existing_encrypted:
existing_secrets = decrypt_credentials(existing_encrypted, user)
existing_secrets.update(auth_credentials)
auth_credentials = existing_secrets
storage_config["encrypted_credentials"] = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
elif existing_encrypted:
storage_config["encrypted_credentials"] = existing_encrypted
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
"redirect_uri",
]:
storage_config.pop(field, None)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
"customName": data["displayName"],
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
"config": storage_config,
"actions": transformed_actions,
"status": data.get("status", True),
"user": user,
}
transformed_actions = transform_actions(actions_metadata)
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
display_name = data["displayName"]
description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
status_bool = bool(data.get("status", True))
with db_session() as conn:
repo = UserToolsRepository(conn)
if existing_doc:
repo.update(
str(existing_doc["id"]), user,
{
"display_name": display_name,
"custom_name": display_name,
"description": description,
"config": storage_config,
"actions": transformed_actions,
"status": status_bool,
},
)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
result = user_tools_collection.insert_one(tool_data)
tool_id = str(result.inserted_id)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
saved_id = str(existing_doc["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
# Fall back to find_by_user_and_name — the original
# dual-write path also ran an existence check before
# deciding between insert and update.
existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
if tool_id is None and existing_by_name and (
(existing_by_name.get("config") or {}).get("server_url")
== storage_config.get("server_url")
):
repo.update(
str(existing_by_name["id"]), user,
{
"display_name": display_name,
"custom_name": display_name,
"description": description,
"config": storage_config,
"actions": transformed_actions,
"status": status_bool,
},
)
saved_id = str(existing_by_name["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
created = repo.create(
user, "mcp_tool",
config=storage_config,
custom_name=display_name,
display_name=display_name,
description=description,
config_requirements={},
actions=transformed_actions,
status=status_bool,
)
saved_id = str(created["id"])
response_data = {
"success": True,
"id": saved_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
if tool_id and existing_doc is None:
# Client requested update on a non-existent tool id.
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": "Failed to save MCP server"}
),
jsonify({"success": False, "error": "Failed to save MCP server"}),
500,
)
@@ -291,7 +406,7 @@ class MCPOAuthCallback(Resource):
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool"
"provider": "mcp_tool",
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
@@ -304,7 +419,6 @@ class MCPOAuthCallback(Resource):
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
@@ -327,10 +441,6 @@ class MCPOAuthCallback(Resource):
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
@@ -338,6 +448,14 @@ class MCPOAuthStatus(Resource):
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
@@ -345,17 +463,103 @@ class MCPOAuthStatus(Resource):
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
404,
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
description="Batch check auth status for all MCP tools. "
"Lightweight DB-only check — no network calls to MCP servers."
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
tools_repo = UserToolsRepository(conn)
sessions_repo = ConnectorSessionsRepository(conn)
all_tools = tools_repo.list_for_user(user)
mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
if not mcp_tools:
return make_response(
jsonify({"success": True, "statuses": {}}), 200
)
oauth_server_urls: dict = {}
statuses: dict = {}
for tool in mcp_tools:
tool_id = str(tool["id"])
config = tool.get("config") or {}
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
# Look up a session per distinct base URL. MCP sessions
# are stored with ``provider = "mcp:<server_url>"``
# and the URL in ``server_url``; reuse the repo's
# per-URL accessor rather than an ad-hoc $in query.
url_has_tokens: dict = {}
for base_url in set(oauth_server_urls.values()):
session = sessions_repo.get_by_user_and_server_url(
user, base_url,
)
tokens = (
(session or {}).get("session_data", {}) or {}
).get("tokens", {}) or {}
# MCP code also stashes tokens into token_info on
# the row; consider either present as "connected".
token_info = (session or {}).get("token_info") or {}
url_has_tokens[base_url] = bool(
tokens.get("access_token")
or token_info.get("access_token")
)
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
current_app.logger.error(
"Error checking MCP auth status: %s", e, exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to check auth status"}),
500,
)

View File

@@ -1,20 +1,167 @@
"""Tool management routes."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields, validate_function_name
tool_config = {}
tool_manager = ToolManager(config=tool_config)
# ---------------------------------------------------------------------------
# Shape translation helpers
# ---------------------------------------------------------------------------
# The frontend speaks camelCase (``displayName`` / ``customName`` /
# ``configRequirements``). The PG ``user_tools`` table stores snake_case
# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
# translation localized to this module so repositories stay pure.
_CAMEL_TO_SNAKE = {
"displayName": "display_name",
"customName": "custom_name",
"configRequirements": "config_requirements",
}
_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}
def _row_to_api(row: dict) -> dict:
"""Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
out = dict(row)
for snake, camel in _SNAKE_TO_CAMEL.items():
if snake in out:
out[camel] = out.pop(snake)
# ``user_id`` is exposed as ``user`` in the legacy API shape.
if "user_id" in out:
out["user"] = out.pop("user_id")
return out
def _api_to_update_fields(data: dict) -> dict:
"""Rename incoming camelCase update keys to the repo's snake_case columns."""
fields_out: dict = {}
for key, value in data.items():
fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
return fields_out
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret") and key in config and config[key]
]
if not secret_keys:
return config
storage_config = config.copy()
secret_values = {k: config[k] for k in secret_keys}
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
for key in secret_keys:
storage_config.pop(key, None)
return storage_config
def _validate_config(config, config_requirements, has_existing_secrets=False):
errors = {}
for key, spec in config_requirements.items():
depends_on = spec.get("depends_on")
if depends_on:
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
continue
if spec.get("required") and not config.get(key):
if has_existing_secrets and spec.get("secret"):
continue
errors[key] = f"{spec.get('label', key)} is required"
value = config.get(key)
if value is not None and value != "":
if spec.get("type") == "number":
try:
num = float(value)
if key == "timeout" and (num < 1 or num > 300):
errors[key] = "Timeout must be between 1 and 300"
except (ValueError, TypeError):
errors[key] = f"{spec.get('label', key)} must be a number"
if spec.get("enum") and value not in spec["enum"]:
errors[key] = f"Invalid value for {spec.get('label', key)}"
return errors
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
"""Merge incoming config with existing encrypted secrets and re-encrypt.
For updates, the client may omit unchanged secret values. This helper
decrypts any previously stored secrets, overlays whatever the client *did*
send, strips plain-text secrets from the stored config, and re-encrypts
the merged result.
Returns the final ``config`` dict ready for persistence.
"""
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret")
]
if not secret_keys:
return new_config
existing_secrets = {}
if "encrypted_credentials" in existing_config:
existing_secrets = decrypt_credentials(
existing_config["encrypted_credentials"], user_id
)
merged_secrets = existing_secrets.copy()
for key in secret_keys:
if key in new_config and new_config[key]:
merged_secrets[key] = new_config[key]
# Start from existing non-secret values, then overlay incoming non-secrets
storage_config = {
k: v for k, v in existing_config.items()
if k not in secret_keys and k != "encrypted_credentials"
}
storage_config.update(
{k: v for k, v in new_config.items() if k not in secret_keys}
)
if merged_secrets:
storage_config["encrypted_credentials"] = encrypt_credentials(
merged_secrets, user_id
)
else:
storage_config.pop("encrypted_credentials", None)
storage_config.pop("has_encrypted_credentials", None)
return storage_config
def transform_actions(actions_metadata):
"""Set default flags on action metadata for storage.
Marks each action as active, sets ``filled_by_llm`` and ``value`` on every
parameter property. Used by both the generic create_tool and MCP save routes.
"""
transformed = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
props = action["parameters"].get("properties", {})
for param_details in props.values():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed.append(action)
return transformed
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@@ -22,6 +169,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
if not request.decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
@@ -29,12 +178,15 @@ class AvailableTools(Resource):
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
config_req = tool_instance.get_config_requirements()
actions = tool_instance.get_actions_metadata()
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"configRequirements": config_req,
"actions": actions,
}
)
except Exception as err:
@@ -54,12 +206,26 @@ class GetTools(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
tools = user_tools_collection.find({"user": user})
with db_readonly() as conn:
rows = UserToolsRepository(conn).list_for_user(user)
user_tools = []
for tool in tools:
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
for row in rows:
tool_copy = _row_to_api(row)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
tool_instance = tool_manager.tools.get(tool_copy.get("name"))
if tool_instance:
config_req = tool_instance.get_config_requirements()
tool_copy["configRequirements"] = config_req
has_secrets = any(
spec.get("secret") for spec in config_req.values()
) if config_req else False
if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
tool_copy["config"]["has_encrypted_credentials"] = True
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
@@ -110,41 +276,61 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
try:
if data["name"] == "mcp_tool":
server_url = (data.get("config", {}).get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
transformed_actions = transform_actions(actions_metadata)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
config_requirements = tool_instance.get_config_requirements()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements
)
if validation_errors:
return make_response(
jsonify(
{
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}
),
400,
)
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
with db_session() as conn:
created = UserToolsRepository(conn).create(
user,
data["name"],
config=storage_config,
custom_name=data.get("customName", ""),
display_name=data["displayName"],
description=data["description"],
config_requirements=config_requirements,
actions=transformed_actions,
status=bool(data.get("status", True)),
)
new_id = str(created["id"])
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -182,17 +368,10 @@ class UpdateTool(Resource):
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "customName" in data:
update_data["customName"] = data["customName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
update_data: dict = {}
for key in ("name", "displayName", "customName", "description", "actions"):
if key in data:
update_data[key] = data[key]
if "config" in data:
if "actions" in data["config"]:
for action_name in list(data["config"]["actions"].keys()):
@@ -207,66 +386,61 @@ class UpdateTool(Resource):
),
400,
)
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements()
if tool_instance
else {}
)
existing_config = tool_doc.get("config", {}) or {}
has_existing_secrets = "encrypted_credentials" in existing_config
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
if "status" in data:
update_data["status"] = bool(data["status"])
repo.update(
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
)
else:
if "status" in data:
update_data["status"] = bool(data["status"])
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": update_data},
)
repo.update(
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
)
except Exception as err:
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -298,10 +472,50 @@ class UpdateToolConfig(Resource):
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {}) or {}
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
repo.update(str(tool_doc["id"]), user, {"config": final_config})
except Exception as err:
current_app.logger.error(
f"Error updating tool config: {err}", exc_info=True
@@ -337,10 +551,17 @@ class UpdateToolActions(Resource):
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"actions": data["actions"]},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool actions: {err}", exc_info=True
@@ -374,10 +595,17 @@ class UpdateToolStatus(Resource):
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
repo.update(
str(tool_doc["id"]), user, {"status": bool(data["status"])},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool status: {err}", exc_info=True
@@ -406,15 +634,18 @@ class DeleteTool(Resource):
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
with db_session() as conn:
repo = UserToolsRepository(conn)
tool_doc = repo.get_any(data["id"], user)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
repo.delete(str(tool_doc["id"]), user)
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/parse_spec")
@@ -479,71 +710,88 @@ class GetArtifact(Resource):
user_id = decoded_token.get("sub")
try:
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
with db_readonly() as conn:
notes_repo = NotesRepository(conn)
todos_repo = TodosRepository(conn)
# Artifact IDs may be PG UUIDs (post-cutover) or legacy
# Mongo ObjectIds embedded in older conversation history.
# Both repos' ``get_any`` handles the id-shape branching
# internally so a non-UUID input never reaches
# ``CAST(:id AS uuid)`` (which would poison the readonly
# transaction and break the fallback below).
note_doc = notes_repo.get_any(artifact_id, user_id)
if note_doc:
content = note_doc.get("note", "") or note_doc.get("content", "")
line_count = len(content.split("\n")) if content else 0
updated = note_doc.get("updated_at")
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
updated.isoformat()
if hasattr(updated, "isoformat")
else updated
),
},
}
return make_response(
jsonify({"success": True, "artifact": artifact}), 200
)
todo_doc = todos_repo.get_any(artifact_id, user_id)
if todo_doc:
tool_id = todo_doc.get("tool_id")
all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
items = []
open_count = 0
completed_count = 0
for t in all_todos:
# PG ``todos`` stores a ``completed BOOLEAN`` column;
# the legacy Mongo shape used a ``status`` string.
# Keep the response shape stable by translating here.
status = "completed" if t.get("completed") else "open"
if status == "open":
open_count += 1
else:
completed_count += 1
created = t.get("created_at")
updated = t.get("updated_at")
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
created.isoformat()
if hasattr(created, "isoformat")
else created
),
"updated_at": (
updated.isoformat()
if hasattr(updated, "isoformat")
else updated
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(
jsonify({"success": True, "artifact": artifact}), 200
)
except Exception as err:
current_app.logger.error(
f"Error retrieving artifact: {err}", exc_info=True
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
# Return all todos for the tool
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404

View File

@@ -1,283 +1,61 @@
"""Centralized utilities for API routes."""
"""Centralized utilities for API routes.
Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
"""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Optional
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
from flask import (
Response,
jsonify,
make_response,
request,
)
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
"""Extract user ID from decoded JWT token, or None if unauthenticated."""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
"""Decorator to require authentication. Returns 401 when absent."""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
data=None, message: Optional[str] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
"""Shape a successful JSON response."""
body = {"success": True}
if data is not None:
body["data"] = data
if message is not None:
body["message"] = message
return make_response(jsonify(body), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
"""Shape an error JSON response; any kwargs are merged into the body."""
body = {"success": False, "error": message, **kwargs}
return make_response(jsonify(body), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def require_fields(required: list) -> Callable:
"""Decorator: return 400 if any listed field is missing/falsy in the JSON body."""
def decorator(func: Callable) -> Callable:
@wraps(func)
@@ -287,92 +65,11 @@ def require_fields(required: List[str]) -> Callable:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return error_response(
f"Missing required fields: {', '.join(missing)}"
)
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order

View File

@@ -1,62 +1,142 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def _workflow_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return error_response(message)
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
"""Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
if not workflow_id:
return None
if looks_like_uuid(workflow_id):
row = repo.get(workflow_id, user_id)
if row is not None:
return row
return repo.get_by_legacy_id(workflow_id, user_id)
def _write_graph(
conn,
pg_workflow_id: str,
graph_version: int,
nodes_data: List[Dict],
edges_data: List[Dict],
) -> List[Dict]:
"""Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.
Edges arrive with source/target as user-provided node-id strings. We
insert nodes first, capture their ``node_id → UUID`` map, then
translate edges before insertion. Edges referencing missing nodes are
dropped with a warning.
"""
nodes_repo = WorkflowNodesRepository(conn)
edges_repo = WorkflowEdgesRepository(conn)
if nodes_data:
created_nodes = nodes_repo.bulk_create(
pg_workflow_id, graph_version,
[
{
"node_id": n["id"],
"node_type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
],
)
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
else:
created_nodes = []
node_uuid_by_str = {}
if edges_data:
translated_edges: List[Dict] = []
for e in edges_data:
src = e.get("source")
tgt = e.get("target")
from_uuid = node_uuid_by_str.get(src)
to_uuid = node_uuid_by_str.get(tgt)
if not from_uuid or not to_uuid:
current_app.logger.warning(
"Workflow graph write: dropping edge %s; node refs unresolved "
"(source=%s, target=%s)",
e.get("id"), src, tgt,
)
continue
translated_edges.append({
"edge_id": e["id"],
"from_node_id": from_uuid,
"to_node_id": to_uuid,
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
})
if translated_edges:
edges_repo.bulk_create(
pg_workflow_id, graph_version, translated_edges,
)
return created_nodes
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
"""Serialize workflow row to API response format."""
created_at = w.get("created_at")
updated_at = w.get("updated_at")
return {
"id": str(w["_id"]),
"id": str(w["id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
"""Serialize workflow node row to API response format."""
return {
"id": n["id"],
"type": n["type"],
"id": n["node_id"],
"type": n["node_type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
"data": n.get("config", {}) or {},
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
"""Serialize workflow edge row to API response format."""
return {
"id": e["id"],
"id": e["edge_id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
@@ -65,7 +145,7 @@ def serialize_edge(e: Dict) -> Dict:
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
"""Get current graph version with fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
@@ -74,22 +154,6 @@ def get_workflow_graph_version(workflow: Dict) -> int:
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
@@ -310,49 +374,6 @@ def _can_reach_end(
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@@ -364,6 +385,7 @@ class WorkflowList(Resource):
data = request.get_json()
name = data.get("name", "").strip()
description = data.get("description", "")
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
@@ -374,35 +396,16 @@ class WorkflowList(Resource):
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = repo.create(user_id, name, description=description)
pg_workflow_id = str(workflow["id"])
_write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
except Exception as err:
return _workflow_error_response("Failed to create workflow", err)
return success_response({"id": workflow_id}, 201)
return success_response({"id": pg_workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
@@ -412,23 +415,22 @@ class WorkflowDetail(Resource):
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
try:
with db_readonly() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
pg_workflow_id = str(workflow["id"])
graph_version = get_workflow_graph_version(workflow)
nodes = WorkflowNodesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
edges = WorkflowEdgesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
except Exception as err:
return _workflow_error_response("Failed to fetch workflow", err)
return success_response(
{
@@ -443,18 +445,9 @@ class WorkflowDetail(Resource):
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
description = data.get("description", "")
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
@@ -465,55 +458,36 @@ class WorkflowDetail(Resource):
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
pg_workflow_id = str(workflow["id"])
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
_write_graph(
conn, pg_workflow_id, next_graph_version,
nodes_data, edges_data,
)
repo.update(
pg_workflow_id, user_id,
{
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"description": description,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
},
)
WorkflowNodesRepository(conn).delete_other_versions(
pg_workflow_id, next_graph_version,
)
WorkflowEdgesRepository(conn).delete_other_versions(
pg_workflow_id, next_graph_version,
)
except Exception as err:
return _workflow_error_response("Failed to update workflow", err)
return success_response()
@@ -521,21 +495,15 @@ class WorkflowDetail(Resource):
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
with db_session() as conn:
repo = WorkflowsRepository(conn)
workflow = _resolve_workflow(repo, workflow_id, user_id)
if workflow is None:
return error_response("Workflow not found", 404)
# ON DELETE CASCADE on workflow_nodes/edges cleans children.
repo.delete(str(workflow["id"]), user_id)
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
return success_response()

View File

@@ -0,0 +1,3 @@
from application.api.v1.routes import v1_bp
__all__ = ["v1_bp"]

View File

@@ -0,0 +1,331 @@
"""Standard chat completions API routes.
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
follow the widely-adopted chat completions protocol so external tools
(opencode, continue, etc.) can connect to DocsGPT agents.
"""
import json
import logging
import time
import traceback
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
from application.api.answer.routes.base import BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
from application.api.v1.translator import (
translate_request,
translate_response,
translate_stream_event,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
logger = logging.getLogger(__name__)
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
def _extract_bearer_token() -> Optional[str]:
"""Extract API key from Authorization: Bearer header."""
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
return auth[7:].strip()
return None
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
with db_readonly() as conn:
return AgentsRepository(conn).find_by_key(api_key)
except Exception:
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
"""Return agent name for display as model name."""
if agent:
return agent.get("name", api_key)
return api_key
class _V1AnswerHelper(BaseAnswerResource):
"""Thin wrapper to access complete_stream / process_response_stream."""
pass
@v1_bp.route("/chat/completions", methods=["POST"])
def chat_completions():
"""Handle POST /v1/chat/completions."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
data = request.get_json()
if not data or not data.get("messages"):
return make_response(
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
400,
)
is_stream = data.get("stream", False)
agent_doc = _lookup_agent(api_key)
model_name = _get_model_name(agent_doc, api_key)
try:
internal_data = translate_request(data, api_key)
except Exception as e:
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity. The PG
# ``agents`` row exposes the owner via ``user_id`` (``user`` is the
# legacy Mongo field name kept in ``row_to_dict`` only for the
# mapping ``id``/``_id``).
agent_user = (
(agent_doc.get("user_id") or agent_doc.get("user"))
if agent_doc else None
)
decoded_token = {"sub": agent_user or "api_key_user"}
try:
processor = StreamProcessor(internal_data, decoded_token)
if internal_data.get("tool_actions"):
# Continuation mode
conversation_id = internal_data.get("conversation_id")
if not conversation_id:
return make_response(
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
400,
)
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
internal_data["tool_actions"], conversation_id
)
continuation = {
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
}
question = ""
else:
# Normal mode
question = internal_data.get("question", "")
agent = processor.build_agent(question)
continuation = None
if not processor.decoded_token:
return make_response(
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
401,
)
helper = _V1AnswerHelper()
usage_error = helper.check_usage(processor.agent_config)
if usage_error:
return usage_error
should_save_conversation = bool(internal_data.get("save_conversation", False))
if is_stream:
return Response(
_stream_response(
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
else:
return _non_stream_response(
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
)
except ValueError as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
except Exception as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)
def _stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Generator[str, None, None]:
"""Generate translated SSE chunks for streaming response."""
completion_id = f"chatcmpl-{int(time.time())}"
internal_stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
for line in internal_stream:
if not line.strip():
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")
if conv_id:
completion_id = f"chatcmpl-{conv_id}"
# Translate to standard format
translated = translate_stream_event(event_data, completion_id, model_name)
for chunk in translated:
yield chunk
def _non_stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Response:
"""Collect full response and return as single JSON."""
stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
result = helper.process_response_stream(stream)
if result["error"]:
return make_response(
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
500,
)
extra = result.get("extra")
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
response = translate_response(
conversation_id=result["conversation_id"],
answer=result["answer"] or "",
sources=result["sources"],
tool_calls=result["tool_calls"],
thought=result["thought"] or "",
model_name=model_name,
pending_tool_calls=pending,
)
return make_response(jsonify(response), 200)
@v1_bp.route("/models", methods=["GET"])
def list_models():
"""Handle GET /v1/models — return agents as models."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
try:
with db_readonly() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.find_by_key(api_key)
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)
created = agent.get("created_at") or agent.get("createdAt")
created_ts = (
int(created.timestamp()) if hasattr(created, "timestamp")
else int(time.time())
)
model_id = str(agent.get("id") or agent.get("_id") or "")
model = {
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": agent.get("name", ""),
"description": agent.get("description", ""),
}
return make_response(
jsonify({"object": "list", "data": [model]}),
200,
)
except Exception as e:
logger.error(f"/v1/models error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)

View File

@@ -0,0 +1,433 @@
"""Translate between standard chat completions format and DocsGPT internals.
This module handles:
- Request translation (chat completions -> DocsGPT internal format)
- Response translation (DocsGPT response -> chat completions format)
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
"""
import json
import time
from typing import Any, Dict, List, Optional
def _get_client_tool_name(tc: Dict) -> str:
"""Return the original tool name for client-facing responses.
For client-side tools the ``tool_name`` field carries the name the
client originally registered. Fall back to ``action_name`` (which
is now the clean LLM-visible name) or ``name``.
"""
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
# ---------------------------------------------------------------------------
# Request translation
# ---------------------------------------------------------------------------
def is_continuation(messages: List[Dict]) -> bool:
"""Check if messages represent a tool-call continuation.
A continuation is detected when the last message(s) have ``role: "tool"``
immediately after an assistant message with ``tool_calls``.
"""
if not messages:
return False
# Walk backwards: if we see tool messages before hitting a non-tool, non-assistant message
# and there's an assistant message with tool_calls, it's a continuation.
i = len(messages) - 1
while i >= 0 and messages[i].get("role") == "tool":
i -= 1
if i < 0:
return False
return (
messages[i].get("role") == "assistant"
and bool(messages[i].get("tool_calls"))
)
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
"""Extract tool results from trailing tool messages for continuation.
Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``.
"""
results = []
for msg in reversed(messages):
if msg.get("role") != "tool":
break
call_id = msg.get("tool_call_id", "")
content = msg.get("content", "")
if isinstance(content, str):
try:
content = json.loads(content)
except (json.JSONDecodeError, TypeError):
pass
results.append({"call_id": call_id, "result": content})
results.reverse()
return results
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
"""Try to extract conversation_id from the assistant message before tool results.
The conversation_id may be stored in a custom field on the assistant message
from a previous response cycle.
"""
for msg in reversed(messages):
if msg.get("role") == "assistant":
# Check docsgpt extension
return msg.get("docsgpt", {}).get("conversation_id")
return None
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
"""Extract the first system message content from the messages array.
Returns None if no system message is present.
"""
for msg in messages:
if msg.get("role") == "system":
return msg.get("content", "")
return None
def convert_history(messages: List[Dict]) -> List[Dict]:
"""Convert chat completions messages array to DocsGPT history format.
DocsGPT history is a list of ``{prompt, response}`` dicts.
Excludes the last user message (that becomes the ``question``).
"""
history = []
i = 0
while i < len(messages):
msg = messages[i]
if msg.get("role") == "system":
i += 1
continue
if msg.get("role") == "user":
# Look ahead for assistant response
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
content = messages[i + 1].get("content") or ""
history.append({
"prompt": msg.get("content", ""),
"response": content,
})
i += 2
continue
# Last user message without response — skip (it's the question)
i += 1
continue
i += 1
return history
def translate_request(
data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""Translate a chat completions request to DocsGPT internal format.
Args:
data: The incoming request body.
api_key: Agent API key from the Authorization header.
Returns:
Dict suitable for passing to ``StreamProcessor``.
"""
messages = data.get("messages", [])
# Check for continuation (tool results after assistant tool_calls)
if is_continuation(messages):
tool_actions = extract_tool_results(messages)
conversation_id = extract_conversation_id(messages)
if not conversation_id:
conversation_id = data.get("conversation_id")
result = {
"conversation_id": conversation_id,
"tool_actions": tool_actions,
"api_key": api_key,
}
# Carry tools forward for next iteration
if data.get("tools"):
result["client_tools"] = data["tools"]
return result
# Normal request — extract question from last user message
question = ""
for msg in reversed(messages):
if msg.get("role") == "user":
question = msg.get("content", "")
break
history = convert_history(messages)
system_prompt_override = extract_system_prompt(messages)
docsgpt = data.get("docsgpt", {})
result = {
"question": question,
"api_key": api_key,
"history": json.dumps(history),
# Conversations are NOT persisted by default on the v1 endpoint.
# Callers opt in via ``docsgpt.save_conversation: true``.
"save_conversation": bool(docsgpt.get("save_conversation", False)),
}
if system_prompt_override is not None:
result["system_prompt_override"] = system_prompt_override
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]
# DocsGPT extensions
if docsgpt.get("attachments"):
result["attachments"] = docsgpt["attachments"]
return result
# ---------------------------------------------------------------------------
# Response translation (non-streaming)
# ---------------------------------------------------------------------------
def translate_response(
conversation_id: str,
answer: str,
sources: Optional[List[Dict]],
tool_calls: Optional[List[Dict]],
thought: str,
model_name: str,
pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
"""Translate DocsGPT response to chat completions format.
Args:
conversation_id: The DocsGPT conversation ID.
answer: The assistant's text response.
sources: RAG retrieval sources.
tool_calls: Completed tool call results.
thought: Reasoning/thinking tokens.
model_name: Model/agent identifier.
pending_tool_calls: Pending client-side tool calls (if paused).
Returns:
Dict in the standard chat completions response format.
"""
created = int(time.time())
completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
# Build message
message: Dict[str, Any] = {"role": "assistant"}
if pending_tool_calls:
# Tool calls pending — return them for client execution
message["content"] = None
message["tool_calls"] = [
{
"id": tc.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc),
"arguments": (
json.dumps(tc["arguments"])
if isinstance(tc.get("arguments"), dict)
else tc.get("arguments", "{}")
),
},
}
for tc in pending_tool_calls
]
finish_reason = "tool_calls"
else:
message["content"] = answer
if thought:
message["reasoning_content"] = thought
finish_reason = "stop"
result: Dict[str, Any] = {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": model_name,
"choices": [
{
"index": 0,
"message": message,
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
# DocsGPT extensions
docsgpt: Dict[str, Any] = {}
if conversation_id:
docsgpt["conversation_id"] = conversation_id
if sources:
docsgpt["sources"] = sources
if tool_calls:
docsgpt["tool_calls"] = tool_calls
if docsgpt:
result["docsgpt"] = docsgpt
return result
# ---------------------------------------------------------------------------
# Streaming event translation
# ---------------------------------------------------------------------------
def _make_chunk(
completion_id: str,
model_name: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> str:
"""Build a single SSE chunk in the standard streaming format."""
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(chunk)}\n\n"
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
"""Build a DocsGPT extension SSE chunk."""
return f"data: {json.dumps({'docsgpt': data})}\n\n"
def translate_stream_event(
event_data: Dict[str, Any],
completion_id: str,
model_name: str,
) -> List[str]:
"""Translate a DocsGPT SSE event dict to standard streaming chunks.
May return 0, 1, or 2 chunks per input event. For example, a completed
tool call produces both a docsgpt extension chunk and nothing on the
standard side (since server-side tool calls aren't surfaced in standard
format).
Args:
event_data: Parsed DocsGPT event dict.
completion_id: The completion ID for this response.
model_name: Model/agent identifier.
Returns:
List of SSE-formatted strings to send to the client.
"""
event_type = event_data.get("type")
chunks: List[str] = []
if event_type == "answer":
chunks.append(
_make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
)
elif event_type == "thought":
chunks.append(
_make_chunk(
completion_id, model_name,
{"reasoning_content": event_data.get("thought", "")},
)
)
elif event_type == "source":
chunks.append(
_make_docsgpt_chunk({
"type": "source",
"sources": event_data.get("source", []),
})
)
elif event_type == "tool_call":
tc_data = event_data.get("data", {})
status = tc_data.get("status")
if status == "requires_client_execution":
# Standard: stream as tool_calls delta
args = tc_data.get("arguments", {})
args_str = json.dumps(args) if isinstance(args, dict) else str(args)
chunks.append(
_make_chunk(completion_id, model_name, {
"tool_calls": [{
"index": 0,
"id": tc_data.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc_data),
"arguments": args_str,
},
}],
})
)
elif status == "awaiting_approval":
# Extension: approval needed
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif status in ("completed", "pending", "error", "denied", "skipped"):
# Extension: tool call progress
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif event_type == "tool_calls_pending":
# Standard: finish_reason = tool_calls
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
)
# Also emit as docsgpt extension
chunks.append(
_make_docsgpt_chunk({
"type": "tool_calls_pending",
"pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
})
)
elif event_type == "end":
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="stop")
)
chunks.append("data: [DONE]\n\n")
elif event_type == "id":
chunks.append(
_make_docsgpt_chunk({
"type": "id",
"conversation_id": event_data.get("id", ""),
})
)
elif event_type == "error":
# Emit as standard error (non-standard but widely supported)
error_data = {
"error": {
"message": event_data.get("error", "An error occurred"),
"type": "server_error",
}
}
chunks.append(f"data: {json.dumps(error_data)}\n\n")
elif event_type == "structured_answer":
chunks.append(
_make_chunk(
completion_id, model_name,
{"content": event_data.get("answer", "")},
)
)
# Skip: tool_calls (redundant), research_plan, research_progress
return chunks

View File

@@ -1,9 +1,10 @@
import logging
import os
import platform
import uuid
import dotenv
from flask import Flask, jsonify, redirect, request
from flask import Flask, Response, jsonify, redirect, request
from jose import jwt
from application.auth import handle_auth
@@ -17,8 +18,14 @@ from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.bootstrap import ensure_database_ready # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
)
if platform.system() == "Windows":
@@ -27,11 +34,23 @@ if platform.system() == "Windows":
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
# repository touches the engine, so the first request can't race the
# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
# (default ON for dev; disable in prod if schema is managed out-of-band).
ensure_database_ready(
settings.POSTGRES_URI,
create_db=settings.AUTO_CREATE_DB,
migrate=settings.AUTO_MIGRATE,
logger=logging.getLogger("application.app"),
)
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -68,6 +87,11 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.route("/api/health")
def health():
return jsonify({"status": "ok"})
@app.route("/api/config")
def get_config():
response = {
@@ -88,10 +112,33 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
return None
if should_reject_stt_request(request.path, request.content_length):
return (
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
return None
@app.before_request
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
# OpenAI-compatible routes authenticate via opaque agent API keys in the
# Authorization header, which the JWT decoder below would reject. Defer
# auth to the route handlers (see application/api/v1/routes.py).
if request.path.startswith("/v1/"):
request.decoded_token = None
return None
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None
@@ -102,12 +149,11 @@ def authenticate_request():
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
return response

33
application/asgi.py Normal file
View File

@@ -0,0 +1,33 @@
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
from __future__ import annotations
from a2wsgi import WSGIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Mount
from application.app import app as flask_app
from application.mcp_server import mcp
_WSGI_THREADPOOL = 32
mcp_app = mcp.http_app(path="/")
asgi_app = Starlette(
routes=[
Mount("/mcp", app=mcp_app),
Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
expose_headers=["Mcp-Session-Id"],
),
],
lifespan=mcp_app.lifespan,
)

View File

@@ -1,6 +1,8 @@
import threading
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
from celery.signals import setup_logging, worker_process_init, worker_ready
def make_celery(app_name=__name__):
@@ -20,5 +22,44 @@ def config_loggers(*args, **kwargs):
setup_logging()
@worker_process_init.connect
def _dispose_db_engine_on_fork(*args, **kwargs):
"""Dispose the SQLAlchemy engine pool in each forked Celery worker.
SQLAlchemy connection pools are not fork-safe: file descriptors shared
between the parent and a forked worker will corrupt the pool. Disposing
on ``worker_process_init`` gives every worker its own fresh pool on
first use.
Imported lazily so Celery workers that don't touch Postgres (or where
``POSTGRES_URI`` is unset) don't fail at startup.
"""
try:
from application.storage.db.engine import dispose_engine
except Exception:
return
dispose_engine()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.
Runs in a daemon thread so a slow endpoint or bad DNS never holds
up the worker becoming ready for tasks. The check itself is
fail-silent (see ``application.updates.version_check.run_check``);
this handler's only job is to launch it and get out of the way.
Import is lazy so the symbol resolution never fires at module
import time — consistent with the ``_dispose_db_engine_on_fork``
pattern above.
"""
try:
from application.updates.version_check import run_check
except Exception:
return
threading.Thread(target=run_check, name="version-check", daemon=True).start()
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -9,3 +9,8 @@ accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)
beat_scheduler = "redbeat.RedBeatScheduler"
redbeat_redis_url = broker_url
redbeat_key_prefix = "redbeat:docsgpt:"
redbeat_lock_timeout = 90

View File

@@ -0,0 +1,89 @@
"""Normalize user-supplied Postgres URIs for different drivers.
DocsGPT has two Postgres connection strings pointing at potentially
different databases:
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
dialect prefix is an invalid URI from its point of view.
The two fields therefore need opposite normalization so operators don't
have to know which driver a given field feeds. Each normalizer also
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
psycopg2 is no longer in the project.
This module is deliberately separate from ``application/core/settings.py``
so the Settings class stays focused on field declarations, and the
URI-rewriting logic can be unit-tested without triggering ``.env``
file loading from importing Settings.
"""
from __future__ import annotations
def _rewrite_uri_prefixes(v, rewrites):
"""Shared URI prefix rewriter used by both normalizers below.
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
applies the first matching rewrite, and passes unrecognised input
through so downstream consumers (SQLAlchemy, libpq) can produce
their own error messages rather than us silently eating a
misconfiguration.
"""
if v is None:
return None
if not isinstance(v, str):
return v
v = v.strip()
if not v or v.lower() == "none":
return None
for prefix, target in rewrites:
if v.startswith(prefix):
return target + v[len(prefix):]
return v
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
("postgresql+psycopg2://", "postgresql+psycopg://"),
("postgresql://", "postgresql+psycopg://"),
("postgres://", "postgresql+psycopg://"),
)
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
("postgresql+psycopg2://", "postgresql://"),
("postgresql+psycopg://", "postgresql://"),
)
def normalize_postgres_uri(v):
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
Accepts the forms operators naturally write (``postgres://``,
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
Unknown schemes pass through unchanged so SQLAlchemy can produce its
own dialect-not-found error.
"""
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
def normalize_pgvector_connection_string(v):
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
Strips the SQLAlchemy dialect prefix if the operator accidentally
copied their POSTGRES_URI value here — libpq can't parse it.
User-friendly forms (``postgres://``, ``postgresql://``) pass
through unchanged since libpq accepts them natively.
"""
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)

View File

@@ -1,11 +1,45 @@
import logging
import os
from logging.config import dictConfig
def setup_logging():
def _otlp_logs_enabled() -> bool:
"""Return True when the user has opted in to OTLP log export.
Gated by the standard OTEL env vars so no project-specific knob is needed:
set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
false) to flip it on. When false, ``setup_logging`` keeps its original
console-only behavior.
"""
exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
return exporter == "otlp" and not disabled
def setup_logging() -> None:
"""Configure the root logger with a stdout console handler.
When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
``LoggingHandler`` to the root logger before this function runs. The
``dictConfig`` call below replaces ``root.handlers`` with the console
handler, which would silently drop the OTEL handler. To make OTLP log
export work without forcing every contributor to opt in, snapshot the
OTEL handlers up front and re-attach them after ``dictConfig``.
"""
preserved_handlers: list[logging.Handler] = []
if _otlp_logs_enabled():
preserved_handlers = [
h
for h in logging.getLogger().handlers
if h.__class__.__module__.startswith("opentelemetry")
]
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
}
},
"handlers": {
@@ -15,8 +49,14 @@ def setup_logging():
"formatter": "default",
}
},
'root': {
'level': 'INFO',
'handlers': ['console'],
"root": {
"level": "INFO",
"handlers": ["console"],
},
})
})
if preserved_handlers:
root = logging.getLogger()
for handler in preserved_handlers:
if handler not in root.handlers:
root.addHandler(handler)

View File

@@ -1,224 +0,0 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)

View File

@@ -0,0 +1,164 @@
"""Layered model registry.
Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.
The ``user_id`` parameter on lookup methods is reserved for the future
end-user BYOM (per-user model records in Postgres). It is currently
ignored — defaulted to ``None`` everywhere — so call sites can be
threaded through without a wide refactor when BYOM lands.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from typing import Dict, List, Optional
from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
BUILTIN_MODELS_DIR,
ProviderCatalog,
load_model_yamls,
)
logger = logging.getLogger(__name__)
class ModelRegistry:
"""Singleton registry of available models."""
_instance: Optional["ModelRegistry"] = None
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
@classmethod
def reset(cls) -> None:
"""Clear the singleton. Intended for test fixtures."""
cls._instance = None
cls._initialized = False
def _load_models(self) -> None:
from pathlib import Path
from application.core.settings import settings
from application.llm.providers import ALL_PROVIDERS
directories = [BUILTIN_MODELS_DIR]
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
if operator_dir:
op_path = Path(operator_dir)
if not op_path.exists():
logger.warning(
"MODELS_CONFIG_DIR=%s does not exist; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
elif not op_path.is_dir():
logger.warning(
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
else:
directories.append(op_path)
catalogs = load_model_yamls(directories)
# Validate every catalog targets a known plugin before doing any
# registry work, so an unknown provider name in YAML aborts boot
# with a clear error.
plugin_names = {p.name for p in ALL_PROVIDERS}
for c in catalogs:
if c.provider not in plugin_names:
raise ValueError(
f"{c.source_path}: YAML declares unknown provider "
f"{c.provider!r}; no Provider plugin is registered "
f"under that name. Known: {sorted(plugin_names)}"
)
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
for c in catalogs:
catalogs_by_provider[c.provider].append(c)
self.models.clear()
for provider in ALL_PROVIDERS:
if not provider.is_enabled(settings):
continue
for model in provider.get_models(
settings, catalogs_by_provider.get(provider.name, [])
):
self.models[model.id] = model
self.default_model_id = self._resolve_default(settings)
logger.info(
"ModelRegistry loaded %d models, default: %s",
len(self.models),
self.default_model_id,
)
def _resolve_default(self, settings) -> Optional[str]:
if settings.LLM_NAME:
for name in self._parse_model_names(settings.LLM_NAME):
if name in self.models:
return name
if settings.LLM_NAME in self.models:
return settings.LLM_NAME
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
return model_id
if self.models:
return next(iter(self.models.keys()))
return None
@staticmethod
def _parse_model_names(llm_name: str) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
# ------------------------------------------------------------------
# Lookup API. ``user_id`` is reserved for the future BYOM and
# is ignored today — but threading it through every call site now
# means BYOM doesn't require a wide refactor when we build it.
# ------------------------------------------------------------------
def get_model(
self, model_id: str, user_id: Optional[str] = None
) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(
self, model_id: str, user_id: Optional[str] = None
) -> bool:
return model_id in self.models

View File

@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Re-exported here so existing call sites (and tests) that do
# ``from application.core.model_settings import ModelRegistry`` keep
# working. The implementation lives in ``application/core/model_registry.py``.
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
# ``model_yaml`` → ``model_settings`` (this file).
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENAI_COMPATIBLE = "openai_compatible"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
@@ -41,11 +48,20 @@ class AvailableModel:
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
# User-facing label distinct from the dispatch ``provider``. Used by
# openai_compatible YAMLs so a Mistral model shows "mistral" in the
# API response while still routing through the OpenAI wire format.
display_provider: Optional[str] = None
# Per-record API key. Operator YAMLs leave this None; populated for
# openai_compatible models (resolved from the YAML's ``api_key_env``)
# and reserved for the future end-user BYOM phase. Never serialized
# into to_dict().
api_key: Optional[str] = field(default=None, repr=False, compare=False)
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"provider": self.display_provider or self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
@@ -60,236 +76,14 @@ class AvailableModel:
return result
class ModelRegistry:
_instance = None
_initialized = False
def __getattr__(name):
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
Done lazily to avoid an import cycle: ``model_registry`` imports
``model_yaml`` which imports the dataclasses from this file.
"""
if name == "ModelRegistry":
from application.core.model_registry import ModelRegistry as _MR
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
return _MR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,27 +1,22 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
from application.core.model_registry import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
"""Get the appropriate API key for a provider.
Delegates to the provider plugin's ``get_api_key``. Falls back to the
generic ``settings.API_KEY`` for unknown providers.
"""
from application.core.settings import settings
from application.llm.providers import PROVIDERS_BY_NAME
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
plugin = PROVIDERS_BY_NAME.get(provider)
if plugin is not None:
key = plugin.get_api_key(settings)
if key:
return key
return settings.API_KEY
@@ -90,3 +85,21 @@ def get_base_url_for_model(model_id: str) -> Optional[str]:
if model:
return model.base_url
return None
def get_api_key_for_model(model_id: str) -> Optional[str]:
"""
Resolve the API key to use when invoking ``model_id``.
Priority:
1. The model record's own ``api_key`` (reserved for future end-user
BYOM where credentials travel with the record).
2. The provider plugin's settings-based key.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model is not None and model.api_key:
return model.api_key
if model is not None:
return get_api_key_for_provider(model.provider.value)
return None

View File

@@ -0,0 +1,325 @@
"""YAML loader for model catalog files under ``application/core/models/``.
Each ``*.yaml`` file declares one provider's static model catalog. Files
are validated with Pydantic at load time; any parse, schema, or alias
error aborts startup with the offending file path in the message.
For most providers, one YAML maps to one catalog. The
``openai_compatible`` provider is special: each YAML file represents a
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
``api_key_env`` and ``base_url``. The loader returns a flat list so the
registry can distinguish multiple files with the same ``provider:`` value.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
logger = logging.getLogger(__name__)
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
DEFAULTS_FILENAME = "_defaults.yaml"
class _DefaultsFile(BaseModel):
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
model_config = ConfigDict(extra="forbid")
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
class _CapabilityFields(BaseModel):
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
All fields are optional so a per-model override can selectively replace
a single field from the provider-level defaults.
"""
model_config = ConfigDict(extra="forbid")
supports_tools: Optional[bool] = None
supports_structured_output: Optional[bool] = None
supports_streaming: Optional[bool] = None
attachments: Optional[List[str]] = None
context_window: Optional[int] = None
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
class _ModelEntry(_CapabilityFields):
"""Schema for one model row inside a YAML's ``models:`` list."""
id: str
display_name: Optional[str] = None
description: str = ""
enabled: bool = True
base_url: Optional[str] = None
aliases: List[str] = Field(default_factory=list)
@field_validator("id")
@classmethod
def _id_nonempty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("model id must be a non-empty string")
return v
class _ProviderFile(BaseModel):
"""Schema for one ``<provider>.yaml`` catalog file."""
model_config = ConfigDict(extra="forbid")
provider: str
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
models: List[_ModelEntry] = Field(default_factory=list)
# openai_compatible metadata. Optional for other providers.
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
class ProviderCatalog(BaseModel):
"""One YAML file's parsed contents, ready for the registry.
For most providers, multiple catalogs with the same ``provider`` get
merged later by the registry. The ``openai_compatible`` provider is
the exception: each catalog is treated as a distinct endpoint, with
its own ``api_key_env`` and ``base_url``.
"""
provider: str
models: List[AvailableModel]
source_path: Optional[Path] = None
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ModelYAMLError(ValueError):
"""Raised when a model YAML fails parsing, schema, or alias validation."""
def _expand_attachments(
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
) -> List[str]:
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
Raw MIME-typed entries (containing ``/``) pass through unchanged.
Unknown aliases raise ``ModelYAMLError``.
"""
expanded: List[str] = []
seen: set = set()
for entry in attachments:
if "/" in entry:
if entry not in seen:
expanded.append(entry)
seen.add(entry)
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"{source}: unknown attachment alias '{entry}'. "
f"Valid aliases: {valid}. "
"(Or use a raw MIME type like 'image/png'.)"
)
for mime in aliases[entry]:
if mime not in seen:
expanded.append(mime)
seen.add(mime)
return expanded
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
path = directory / DEFAULTS_FILENAME
if not path.exists():
return {}
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _DefaultsFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
return parsed.attachment_aliases
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
try:
return ModelProvider(name)
except ValueError as e:
valid = ", ".join(p.value for p in ModelProvider)
raise ModelYAMLError(
f"{source}: unknown provider '{name}'. Valid: {valid}"
) from e
def _build_model(
entry: _ModelEntry,
defaults: _CapabilityFields,
provider: ModelProvider,
aliases: Dict[str, List[str]],
source: Path,
display_provider: Optional[str] = None,
) -> AvailableModel:
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
def pick(field_name: str, fallback):
v = getattr(entry, field_name)
if v is not None:
return v
d = getattr(defaults, field_name)
if d is not None:
return d
return fallback
raw_attachments = entry.attachments
if raw_attachments is None:
raw_attachments = defaults.attachments
if raw_attachments is None:
raw_attachments = []
expanded = _expand_attachments(
raw_attachments, aliases, f"{source} [model={entry.id}]"
)
caps = ModelCapabilities(
supports_tools=pick("supports_tools", False),
supports_structured_output=pick("supports_structured_output", False),
supports_streaming=pick("supports_streaming", True),
supported_attachment_types=expanded,
context_window=pick("context_window", 128000),
input_cost_per_token=pick("input_cost_per_token", None),
output_cost_per_token=pick("output_cost_per_token", None),
)
return AvailableModel(
id=entry.id,
provider=provider,
display_name=entry.display_name or entry.id,
description=entry.description,
capabilities=caps,
enabled=entry.enabled,
base_url=entry.base_url,
display_provider=display_provider,
)
def _load_one_yaml(
path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _ProviderFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
provider_enum = _resolve_provider_enum(parsed.provider, path)
models = [
_build_model(
entry,
parsed.defaults,
provider_enum,
aliases,
path,
display_provider=parsed.display_provider,
)
for entry in parsed.models
]
return ProviderCatalog(
provider=parsed.provider,
models=models,
source_path=path,
display_provider=parsed.display_provider,
api_key_env=parsed.api_key_env,
base_url=parsed.base_url,
)
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
def builtin_attachment_aliases() -> Dict[str, List[str]]:
"""Return the built-in attachment alias map from ``_defaults.yaml``.
Cached after first read so repeat calls are cheap.
"""
global _BUILTIN_ALIASES_CACHE
if _BUILTIN_ALIASES_CACHE is None:
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
return _BUILTIN_ALIASES_CACHE
def resolve_attachment_alias(alias: str) -> List[str]:
"""Resolve a single attachment alias (e.g. ``"image"``) to its
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
"""
aliases = builtin_attachment_aliases()
if alias not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"Unknown attachment alias '{alias}'. Valid: {valid}"
)
return list(aliases[alias])
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
directory in order and return a flat list of catalogs.
Caller is responsible for merging multiple catalogs that target the
same provider plugin. The flat-list shape lets ``openai_compatible``
keep each file separate (one logical endpoint per file).
When the same model ``id`` appears in more than one YAML across the
directory list, a warning is logged. Order in the returned list
preserves load order, so the registry's "later wins" merge gives the
later directory's definition.
"""
catalogs: List[ProviderCatalog] = []
seen_ids: Dict[str, Path] = {}
aliases: Dict[str, List[str]] = {}
for d in directories:
if not d or not d.exists():
continue
aliases.update(_load_defaults(d))
for d in directories:
if not d or not d.exists():
continue
for path in sorted(d.glob("*.yaml")):
if path.name == DEFAULTS_FILENAME:
continue
catalog = _load_one_yaml(path, aliases)
catalogs.append(catalog)
for m in catalog.models:
prior = seen_ids.get(m.id)
if prior is not None and prior != path:
logger.warning(
"Model id %r redefined: %s overrides %s (later wins)",
m.id,
path,
prior,
)
seen_ids[m.id] = path
return catalogs

View File

@@ -0,0 +1,213 @@
# Model catalogs
Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.
To add or edit models, you almost always only touch a YAML here — no
Python code required.
## Add a model to an existing provider
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:
```yaml
models:
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
```
Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:
```yaml
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
context_window: 500000
```
Restart the app. The new model appears in `/api/models`.
> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.
## Add an OpenAI-compatible provider (zero Python)
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:
```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
```
`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.
Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
isn't set, that catalog is silently skipped at boot (logged at INFO) —
no error.
Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.
## Add a provider with its own SDK
For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:
```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM
class MyProvider(Provider):
name = "my_provider"
llm_class = MyLLM
def get_api_key(self, settings):
return settings.MY_PROVIDER_API_KEY
```
Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
## Schema reference
```yaml
provider: <string, required> # matches the Provider plugin's `name`
# openai_compatible only — required for that provider, ignored for others
display_provider: <string> # label shown in /api/models response
api_key_env: <string> # name of the env var carrying the key
base_url: <string> # endpoint URL
defaults: # optional, applied to every model below
supports_tools: bool # default false
supports_structured_output: bool # default false
supports_streaming: bool # default true
attachments: [<alias-or-mime>, ...] # default []
context_window: int # default 128000
input_cost_per_token: float # default null
output_cost_per_token: float # default null
models: # required
- id: <string, required> # the value persisted in agent records
display_name: <string> # default: id
description: <string> # default: ""
enabled: bool # default true; false hides from /api/models
base_url: <string> # optional custom endpoint for this model
# All `defaults:` fields above can be overridden here per-model.
```
### Attachment aliases
The `attachments:` list can mix human-readable aliases with raw MIME
types. Aliases are defined in `_defaults.yaml`:
| Alias | Expands to |
|---|---|
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
| `pdf` | `application/pdf` |
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
Use raw MIME types when you need surgical control:
```yaml
attachments: [image/png, image/webp] # only these two
```
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
path. Every `*.yaml` in that directory is loaded **after** the built-in
catalog under `application/core/models/`. Operators use this to:
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
models under `provider: anthropic` and they show up alongside the
built-ins.
- Override a built-in model's capabilities — declare the same `id`
with different fields (e.g. a higher `context_window`). Later wins;
the override is logged as a `WARNING` so you can audit it.
Things you cannot do via `MODELS_CONFIG_DIR`:
- Add a brand-new non-OpenAI provider — that needs a Python plugin
under `application/llm/providers/` (see "Add a provider with its own
SDK" above). Operator YAMLs may only target a `provider:` value that
already has a registered plugin.
### Example: Docker
Mount your model YAMLs into the container and point the env var at the
mount path:
```yaml
# docker-compose.yml
services:
app:
image: arc53/docsgpt
environment:
MODELS_CONFIG_DIR: /etc/docsgpt/models
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
volumes:
- ./my-models:/etc/docsgpt/models:ro
```
Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.
### Example: Kubernetes
Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start — operators can
ship config drift without taking down the service — but the warning is
loud enough to surface in any reasonable log aggregator.
## Validation
YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:
- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin
This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.
## Reserved fields (not yet implemented)
- `aliases:` on a model — old IDs that resolve to this model. Reserved
for future renames; the schema accepts the field but it is not yet
acted on.

View File

@@ -0,0 +1,18 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".
attachment_aliases:
image:
- image/png
- image/jpeg
- image/jpg
- image/webp
- image/gif
pdf:
- application/pdf
audio:
- audio/mpeg
- audio/wav
- audio/ogg

View File

@@ -0,0 +1,23 @@
provider: anthropic
defaults:
supports_tools: true
attachments: [image]
context_window: 200000
models:
- id: claude-opus-4-7
display_name: Claude Opus 4.7
description: Most capable Claude model for complex reasoning and agentic coding
context_window: 1000000
supports_structured_output: true
- id: claude-sonnet-4-6
display_name: Claude Sonnet 4.6
description: Best balance of speed and intelligence with extended thinking
context_window: 1000000
supports_structured_output: true
- id: claude-haiku-4-5
display_name: Claude Haiku 4.5
description: Fastest Claude model with near-frontier intelligence
supports_structured_output: true

View File

@@ -0,0 +1,31 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: azure-gpt-5.5
display_name: Azure OpenAI GPT-5.5
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: azure-gpt-5.4-mini
display_name: Azure OpenAI GPT-5.4 Mini
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: azure-gpt-5.4-nano
display_name: Azure OpenAI GPT-5.4 Nano
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,7 @@
provider: docsgpt
models:
- id: docsgpt-local
display_name: DocsGPT Model
description: Local model
supports_tools: false

View File

@@ -0,0 +1,31 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
description: Top-tier reasoning model
- id: mistral-small-latest
display_name: Mistral Small
description: Fast, cost-efficient
- id: codestral-latest
display_name: Codestral
description: Code-specialized model

View File

@@ -0,0 +1,17 @@
provider: google
defaults:
supports_tools: true
supports_structured_output: true
attachments: [pdf, image]
context_window: 1048576
models:
- id: gemini-3.1-pro-preview
display_name: Gemini 3.1 Pro
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
- id: gemini-3-flash-preview
display_name: Gemini 3 Flash
description: Frontier-class performance for low-latency, high-volume tasks (preview)
- id: gemini-3.1-flash-lite-preview
display_name: Gemini 3.1 Flash-Lite
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)

View File

@@ -0,0 +1,16 @@
provider: groq
defaults:
supports_tools: true
context_window: 131072
models:
- id: openai/gpt-oss-120b
display_name: GPT-OSS 120B
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
supports_structured_output: true
- id: llama-3.3-70b-versatile
display_name: Llama 3.3 70B Versatile
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
- id: llama-3.1-8b-instant
display_name: Llama 3.1 8B Instant
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use

View File

@@ -0,0 +1,7 @@
provider: huggingface
models:
- id: huggingface-local
display_name: Hugging Face Model
description: Local Hugging Face model
supports_tools: false

View File

@@ -0,0 +1,21 @@
provider: novita
defaults:
supports_tools: true
supports_structured_output: true
models:
- id: deepseek/deepseek-v4-pro
display_name: DeepSeek V4 Pro
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
context_window: 1048576
- id: moonshotai/kimi-k2.6
display_name: Kimi K2.6
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
attachments: [image]
context_window: 262144
- id: zai-org/glm-5
display_name: GLM-5
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
context_window: 202800

View File

@@ -0,0 +1,18 @@
provider: openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: gpt-5.5
display_name: GPT-5.5
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: gpt-5.4-mini
display_name: GPT-5.4 Mini
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: gpt-5.4-nano
display_name: GPT-5.4 Nano
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,25 @@
provider: openrouter
defaults:
supports_tools: true
attachments: [image]
context_window: 128000
models:
- id: qwen/qwen3-coder:free
display_name: Qwen3 Coder (free)
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
context_window: 262000
attachments: []
- id: deepseek/deepseek-v3.2
display_name: DeepSeek V3.2
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
context_window: 131072
attachments: []
supports_structured_output: true
- id: anthropic/claude-sonnet-4.6
display_name: Claude Sonnet 4.6 (via OpenRouter)
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
context_window: 1000000
supports_structured_output: true

Some files were not shown because too many files have changed in this diff Show More