mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 14:00:55 +00:00
Feat/small optimisation (#2182)
* optimised ram use + celery * Remove VITE_EMBEDDINGS_NAME * fix: timeout on remote embeds
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
API_KEY=<LLM api key (for example, open ai key)>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
EMBEDDINGS_KEY=
|
||||
|
||||
#For Azure (you can delete it if you don't use Azure)
|
||||
OPENAI_API_BASE=
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
API_KEY=your_api_key
|
||||
EMBEDDINGS_KEY=your_api_key
|
||||
API_URL=http://localhost:7091
|
||||
INTERNAL_KEY=your_internal_key
|
||||
FLASK_APP=application/app.py
|
||||
FLASK_DEBUG=true
|
||||
|
||||
#For OPENAI on Azure
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
|
||||
|
||||
|
||||
celery = make_celery()
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
|
||||
@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
accept_content = ['json']
|
||||
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -10,12 +10,19 @@ current_dir = os.path.dirname(
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(extra="ignore")
|
||||
|
||||
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||
LLM_PROVIDER: str = "docsgpt"
|
||||
LLM_NAME: Optional[str] = (
|
||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
)
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
@@ -73,9 +80,6 @@ class Settings(BaseSettings):
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
@@ -153,5 +157,6 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
# Project root is one level above application/
|
||||
path = Path(__file__).parent.parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -2,41 +2,79 @@ import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import requests
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
|
||||
try:
|
||||
kwargs.setdefault("trust_remote_code", True)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if self.model is None or self.model._first_module() is None:
|
||||
class RemoteEmbeddings:
|
||||
"""
|
||||
Wrapper for remote embeddings API (OpenAI-compatible).
|
||||
Used when EMBEDDINGS_BASE_URL is configured.
|
||||
"""
|
||||
|
||||
def __init__(self, api_url: str, model_name: str, api_key: str = None):
|
||||
self.api_url = api_url.rstrip("/")
|
||||
self.model_name = model_name
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
self.headers["Authorization"] = f"Bearer {api_key}"
|
||||
self.dimension = None
|
||||
|
||||
def _embed(self, inputs):
|
||||
"""Send embedding request to remote API."""
|
||||
payload = {"inputs": inputs}
|
||||
if self.model_name:
|
||||
payload["model"] = self.model_name
|
||||
|
||||
response = requests.post(
|
||||
self.api_url, headers=self.headers, json=payload, timeout=180
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if isinstance(result, list):
|
||||
if result and isinstance(result[0], list):
|
||||
return result
|
||||
elif result and all(isinstance(x, (int, float)) for x in result):
|
||||
return [result]
|
||||
elif not result:
|
||||
return []
|
||||
else:
|
||||
raise ValueError(
|
||||
f"SentenceTransformer model failed to load properly for: {model_name}"
|
||||
f"Unexpected list content from remote embeddings API: {result}"
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
|
||||
exc_info=True,
|
||||
elif isinstance(result, dict) and "error" in result:
|
||||
raise ValueError(f"Remote embeddings API error: {result['error']}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected response format from remote embeddings API: {result}"
|
||||
)
|
||||
raise
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
"""Embed a single query string."""
|
||||
embeddings_list = self._embed(query)
|
||||
if (
|
||||
isinstance(embeddings_list, list)
|
||||
and len(embeddings_list) == 1
|
||||
and isinstance(embeddings_list[0], list)
|
||||
):
|
||||
if self.dimension is None:
|
||||
self.dimension = len(embeddings_list[0])
|
||||
return embeddings_list[0]
|
||||
raise ValueError(
|
||||
f"Unexpected result structure after embedding query: {embeddings_list}"
|
||||
)
|
||||
|
||||
def embed_documents(self, documents: list):
|
||||
return self.model.encode(documents).tolist()
|
||||
"""Embed a list of documents."""
|
||||
if not documents:
|
||||
return []
|
||||
embeddings_list = self._embed(documents)
|
||||
if self.dimension is None and embeddings_list:
|
||||
self.dimension = len(embeddings_list[0])
|
||||
return embeddings_list
|
||||
|
||||
def __call__(self, text):
|
||||
if isinstance(text, str):
|
||||
@@ -47,6 +85,13 @@ class EmbeddingsWrapper:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
|
||||
|
||||
def _get_embeddings_wrapper():
|
||||
"""Lazy import of EmbeddingsWrapper to avoid loading SentenceTransformer when using remote embeddings."""
|
||||
from application.vectorstore.embeddings_local import EmbeddingsWrapper
|
||||
|
||||
return EmbeddingsWrapper
|
||||
|
||||
|
||||
class EmbeddingsSingleton:
|
||||
_instances = {}
|
||||
|
||||
@@ -60,8 +105,13 @@ class EmbeddingsSingleton:
|
||||
|
||||
@staticmethod
|
||||
def _create_instance(embeddings_name, *args, **kwargs):
|
||||
if embeddings_name == "openai_text-embedding-ada-002":
|
||||
return OpenAIEmbeddings(*args, **kwargs)
|
||||
|
||||
# Lazy import EmbeddingsWrapper only when needed (avoids loading SentenceTransformer)
|
||||
EmbeddingsWrapper = _get_embeddings_wrapper()
|
||||
|
||||
embeddings_factory = {
|
||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
@@ -121,6 +171,20 @@ class BaseVectorStore(ABC):
|
||||
)
|
||||
|
||||
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
||||
# Check for remote embeddings first
|
||||
if settings.EMBEDDINGS_BASE_URL:
|
||||
logging.info(
|
||||
f"Using remote embeddings API at: {settings.EMBEDDINGS_BASE_URL}"
|
||||
)
|
||||
cache_key = f"remote_{settings.EMBEDDINGS_BASE_URL}_{embeddings_name}"
|
||||
if cache_key not in EmbeddingsSingleton._instances:
|
||||
EmbeddingsSingleton._instances[cache_key] = RemoteEmbeddings(
|
||||
api_url=settings.EMBEDDINGS_BASE_URL,
|
||||
model_name=embeddings_name,
|
||||
api_key=embeddings_key,
|
||||
)
|
||||
return EmbeddingsSingleton._instances[cache_key]
|
||||
|
||||
if embeddings_name == "openai_text-embedding-ada-002":
|
||||
if self.is_azure_configured():
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
|
||||
48
application/vectorstore/embeddings_local.py
Normal file
48
application/vectorstore/embeddings_local.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Local embeddings using SentenceTransformer.
|
||||
This module is only imported when EMBEDDINGS_BASE_URL is not set,
|
||||
to avoid loading SentenceTransformer into memory when using remote embeddings.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
|
||||
try:
|
||||
kwargs.setdefault("trust_remote_code", True)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if self.model is None or self.model._first_module() is None:
|
||||
raise ValueError(
|
||||
f"SentenceTransformer model failed to load properly for: {model_name}"
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
|
||||
def embed_documents(self, documents: list):
|
||||
return self.model.encode(documents).tolist()
|
||||
|
||||
def __call__(self, text):
|
||||
if isinstance(text, str):
|
||||
return self.embed_query(text)
|
||||
elif isinstance(text, list):
|
||||
return self.embed_documents(text)
|
||||
else:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
@@ -16,9 +16,12 @@ services:
|
||||
backend:
|
||||
user: root
|
||||
build: ../application
|
||||
env_file:
|
||||
- ../.env
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$EMBEDDINGS_KEY
|
||||
- EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
@@ -41,9 +44,12 @@ services:
|
||||
user: root
|
||||
build: ../application
|
||||
command: celery -A application.app.celery worker -l INFO -B
|
||||
env_file:
|
||||
- ../.env
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$EMBEDDINGS_KEY
|
||||
- EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM node:20.6.1-bullseye-slim
|
||||
FROM node:22-bullseye-slim
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -28,9 +28,6 @@ import AgentPreview from './AgentPreview';
|
||||
import { Agent, ToolSummary } from './types';
|
||||
|
||||
import type { Model } from '../models/types';
|
||||
const embeddingsName =
|
||||
import.meta.env.VITE_EMBEDDINGS_NAME ||
|
||||
'huggingface_sentence-transformers/all-mpnet-base-v2';
|
||||
|
||||
export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const { t } = useTranslation();
|
||||
@@ -548,22 +545,20 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
} else {
|
||||
// Single source selected - maintain backward compatibility
|
||||
const selectedSource = selectedSources[0];
|
||||
if (selectedSource?.model === embeddingsName) {
|
||||
if (selectedSource && 'id' in selectedSource) {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: selectedSource?.id || 'default',
|
||||
sources: [], // Clear sources array for single source
|
||||
retriever: '',
|
||||
}));
|
||||
} else {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: '',
|
||||
sources: [], // Clear sources array
|
||||
retriever: selectedSource?.retriever || 'classic',
|
||||
}));
|
||||
}
|
||||
if (selectedSource && 'id' in selectedSource) {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: selectedSource?.id || 'default',
|
||||
sources: [], // Clear sources array for single source
|
||||
retriever: '',
|
||||
}));
|
||||
} else {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: '',
|
||||
sources: [], // Clear sources array
|
||||
retriever: selectedSource?.retriever || 'classic',
|
||||
}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -40,10 +40,6 @@ export default function SourcesPopup({
|
||||
showAbove: false,
|
||||
});
|
||||
|
||||
const embeddingsName =
|
||||
import.meta.env.VITE_EMBEDDINGS_NAME ||
|
||||
'huggingface_sentence-transformers/all-mpnet-base-v2';
|
||||
|
||||
const options = useSelector(selectSourceDocs);
|
||||
const selectedDocs = useSelector(selectSelectedDocs);
|
||||
|
||||
@@ -147,70 +143,65 @@ export default function SourcesPopup({
|
||||
{options ? (
|
||||
<>
|
||||
{filteredOptions?.map((option: any, index: number) => {
|
||||
if (option.model === embeddingsName) {
|
||||
const isSelected =
|
||||
selectedDocs &&
|
||||
Array.isArray(selectedDocs) &&
|
||||
selectedDocs.length > 0 &&
|
||||
selectedDocs.some((doc) =>
|
||||
option.id
|
||||
? doc.id === option.id
|
||||
: doc.date === option.date,
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? selectedDocs.filter((doc) =>
|
||||
option.id
|
||||
? doc.id !== option.id
|
||||
: doc.date !== option.date,
|
||||
)
|
||||
: [];
|
||||
dispatch(setSelectedDocs(updatedDocs));
|
||||
handlePostDocumentSelect(
|
||||
updatedDocs.length > 0 ? updatedDocs : null,
|
||||
);
|
||||
} else {
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? [...selectedDocs, option]
|
||||
: [option];
|
||||
dispatch(setSelectedDocs(updatedDocs));
|
||||
handlePostDocumentSelect(updatedDocs);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<img
|
||||
src={SourceIcon}
|
||||
alt="Source"
|
||||
width={14}
|
||||
height={14}
|
||||
className="mr-3 shrink-0"
|
||||
/>
|
||||
<span className="dark:text-bright-gray mr-3 grow overflow-hidden font-medium text-ellipsis whitespace-nowrap text-[#5D5D5D]">
|
||||
{option.name}
|
||||
</span>
|
||||
<div
|
||||
className={`flex h-4 w-4 shrink-0 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
||||
>
|
||||
{isSelected && (
|
||||
<img
|
||||
src={CheckIcon}
|
||||
alt="Selected"
|
||||
className="h-3 w-3"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
const isSelected =
|
||||
selectedDocs &&
|
||||
Array.isArray(selectedDocs) &&
|
||||
selectedDocs.length > 0 &&
|
||||
selectedDocs.some((doc) =>
|
||||
option.id ? doc.id === option.id : doc.date === option.date,
|
||||
);
|
||||
}
|
||||
return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? selectedDocs.filter((doc) =>
|
||||
option.id
|
||||
? doc.id !== option.id
|
||||
: doc.date !== option.date,
|
||||
)
|
||||
: [];
|
||||
dispatch(setSelectedDocs(updatedDocs));
|
||||
handlePostDocumentSelect(
|
||||
updatedDocs.length > 0 ? updatedDocs : null,
|
||||
);
|
||||
} else {
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? [...selectedDocs, option]
|
||||
: [option];
|
||||
dispatch(setSelectedDocs(updatedDocs));
|
||||
handlePostDocumentSelect(updatedDocs);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<img
|
||||
src={SourceIcon}
|
||||
alt="Source"
|
||||
width={14}
|
||||
height={14}
|
||||
className="mr-3 shrink-0"
|
||||
/>
|
||||
<span className="dark:text-bright-gray mr-3 grow overflow-hidden font-medium text-ellipsis whitespace-nowrap text-[#5D5D5D]">
|
||||
{option.name}
|
||||
</span>
|
||||
<div
|
||||
className={`flex h-4 w-4 shrink-0 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
||||
>
|
||||
{isSelected && (
|
||||
<img
|
||||
src={CheckIcon}
|
||||
alt="Selected"
|
||||
className="h-3 w-3"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
) : (
|
||||
|
||||
@@ -16,11 +16,6 @@ import {
|
||||
} from '../preferences/preferenceSlice';
|
||||
import WrapperModal from './WrapperModal';
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
|
||||
const embeddingsName =
|
||||
import.meta.env.VITE_EMBEDDINGS_NAME ||
|
||||
'huggingface_sentence-transformers/all-mpnet-base-v2';
|
||||
|
||||
type StatusType = 'loading' | 'idle' | 'fetched' | 'failed';
|
||||
|
||||
export const ShareConversationModal = ({
|
||||
@@ -47,14 +42,12 @@ export const ShareConversationModal = ({
|
||||
|
||||
const extractDocPaths = (docs: Doc[]) =>
|
||||
docs
|
||||
? docs
|
||||
.filter((doc) => doc.model === embeddingsName)
|
||||
.map((doc: Doc) => {
|
||||
return {
|
||||
label: doc.name,
|
||||
value: doc.id ?? 'default',
|
||||
};
|
||||
})
|
||||
? docs.map((doc: Doc) => {
|
||||
return {
|
||||
label: doc.name,
|
||||
value: doc.id ?? 'default',
|
||||
};
|
||||
})
|
||||
: [];
|
||||
|
||||
const [sourcePath, setSourcePath] = useState<{
|
||||
|
||||
Reference in New Issue
Block a user