Feat/small optimisation (#2182)

* optimised RAM use + Celery

* Remove VITE_EMBEDDINGS_NAME

* fix: timeout on remote embeds
This commit is contained in:
Alex
2025-12-05 18:57:39 +00:00
committed by GitHub
parent e68da34c13
commit 9a937d2686
12 changed files with 243 additions and 143 deletions

View File

@@ -1,6 +1,12 @@
API_KEY=<LLM API key (for example, an OpenAI key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=

View File

@@ -1,12 +0,0 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
INTERNAL_KEY=your_internal_key
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)

View File

@@ -2,7 +2,7 @@ import os
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,12 +10,19 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -73,9 +80,6 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
@@ -153,5 +157,6 @@ class Settings(BaseSettings):
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
path = Path(__file__).parent.parent.absolute()
# Project root is one level above application/
path = Path(__file__).parent.parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -2,41 +2,79 @@ import logging
import os
from abc import ABC, abstractmethod
import requests
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
try:
kwargs.setdefault("trust_remote_code", True)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs,
)
if self.model is None or self.model._first_module() is None:
class RemoteEmbeddings:
"""
Wrapper for remote embeddings API (OpenAI-compatible).
Used when EMBEDDINGS_BASE_URL is configured.
"""
def __init__(self, api_url: str, model_name: str, api_key: str = None):
self.api_url = api_url.rstrip("/")
self.model_name = model_name
self.headers = {"Content-Type": "application/json"}
if api_key:
self.headers["Authorization"] = f"Bearer {api_key}"
self.dimension = None
def _embed(self, inputs):
"""Send embedding request to remote API."""
payload = {"inputs": inputs}
if self.model_name:
payload["model"] = self.model_name
response = requests.post(
self.api_url, headers=self.headers, json=payload, timeout=180
)
response.raise_for_status()
result = response.json()
if isinstance(result, list):
if result and isinstance(result[0], list):
return result
elif result and all(isinstance(x, (int, float)) for x in result):
return [result]
elif not result:
return []
else:
raise ValueError(
f"SentenceTransformer model failed to load properly for: {model_name}"
f"Unexpected list content from remote embeddings API: {result}"
)
self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
except Exception as e:
logging.error(
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
exc_info=True,
elif isinstance(result, dict) and "error" in result:
raise ValueError(f"Remote embeddings API error: {result['error']}")
else:
raise ValueError(
f"Unexpected response format from remote embeddings API: {result}"
)
raise
def embed_query(self, query: str):
return self.model.encode(query).tolist()
"""Embed a single query string."""
embeddings_list = self._embed(query)
if (
isinstance(embeddings_list, list)
and len(embeddings_list) == 1
and isinstance(embeddings_list[0], list)
):
if self.dimension is None:
self.dimension = len(embeddings_list[0])
return embeddings_list[0]
raise ValueError(
f"Unexpected result structure after embedding query: {embeddings_list}"
)
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
"""Embed a list of documents."""
if not documents:
return []
embeddings_list = self._embed(documents)
if self.dimension is None and embeddings_list:
self.dimension = len(embeddings_list[0])
return embeddings_list
def __call__(self, text):
if isinstance(text, str):
@@ -47,6 +85,13 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
def _get_embeddings_wrapper():
"""Lazy import of EmbeddingsWrapper to avoid loading SentenceTransformer when using remote embeddings."""
from application.vectorstore.embeddings_local import EmbeddingsWrapper
return EmbeddingsWrapper
class EmbeddingsSingleton:
_instances = {}
@@ -60,8 +105,13 @@ class EmbeddingsSingleton:
@staticmethod
def _create_instance(embeddings_name, *args, **kwargs):
if embeddings_name == "openai_text-embedding-ada-002":
return OpenAIEmbeddings(*args, **kwargs)
# Lazy import EmbeddingsWrapper only when needed (avoids loading SentenceTransformer)
EmbeddingsWrapper = _get_embeddings_wrapper()
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
@@ -121,6 +171,20 @@ class BaseVectorStore(ABC):
)
def _get_embeddings(self, embeddings_name, embeddings_key=None):
# Check for remote embeddings first
if settings.EMBEDDINGS_BASE_URL:
logging.info(
f"Using remote embeddings API at: {settings.EMBEDDINGS_BASE_URL}"
)
cache_key = f"remote_{settings.EMBEDDINGS_BASE_URL}_{embeddings_name}"
if cache_key not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[cache_key] = RemoteEmbeddings(
api_url=settings.EMBEDDINGS_BASE_URL,
model_name=embeddings_name,
api_key=embeddings_key,
)
return EmbeddingsSingleton._instances[cache_key]
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"

View File

@@ -0,0 +1,48 @@
"""
Local embeddings using SentenceTransformer.
This module is only imported when EMBEDDINGS_BASE_URL is not set,
to avoid loading SentenceTransformer into memory when using remote embeddings.
"""
import logging
from sentence_transformers import SentenceTransformer
class EmbeddingsWrapper:
    """Thin adapter around a local SentenceTransformer model.

    Exposes the LangChain-style embeddings interface (``embed_query`` /
    ``embed_documents``) plus a callable shortcut that dispatches on the
    input type. Instantiating this class loads model weights into memory,
    which is why the module is only imported when remote embeddings are
    not configured.
    """

    def __init__(self, model_name, *args, **kwargs):
        """Load the SentenceTransformer named ``model_name``.

        Extra positional/keyword arguments are forwarded to the
        SentenceTransformer constructor. Raises if the model fails to
        load or comes back without any modules.
        """
        logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
        try:
            # Allow custom model code unless the caller explicitly opted out.
            kwargs.setdefault("trust_remote_code", True)
            self.model = SentenceTransformer(
                model_name,
                config_kwargs={"allow_dangerous_deserialization": True},
                *args,
                **kwargs,
            )
            # A model object with no first module is effectively unusable,
            # so treat it as a load failure rather than deferring the error.
            if self.model is None or self.model._first_module() is None:
                raise ValueError(
                    f"SentenceTransformer model failed to load properly for: {model_name}"
                )
            self.dimension = self.model.get_sentence_embedding_dimension()
            logging.info(f"Successfully loaded model with dimension: {self.dimension}")
        except Exception as e:
            logging.error(
                f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
                exc_info=True,
            )
            raise

    def embed_query(self, query: str):
        """Return the embedding of a single query string as a plain list."""
        return self.model.encode(query).tolist()

    def embed_documents(self, documents: list):
        """Return embeddings for a list of documents as nested plain lists."""
        return self.model.encode(documents).tolist()

    def __call__(self, text):
        """Dispatch to embed_query for a str, embed_documents for a list."""
        if isinstance(text, list):
            return self.embed_documents(text)
        if isinstance(text, str):
            return self.embed_query(text)
        raise ValueError("Input must be a string or a list of strings")

View File

@@ -16,9 +16,12 @@ services:
backend:
user: root
build: ../application
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- EMBEDDINGS_KEY=$EMBEDDINGS_KEY
- EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0
@@ -41,9 +44,12 @@ services:
user: root
build: ../application
command: celery -A application.app.celery worker -l INFO -B
env_file:
- ../.env
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- EMBEDDINGS_KEY=$EMBEDDINGS_KEY
- EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0

View File

@@ -1,4 +1,4 @@
FROM node:20.6.1-bullseye-slim
FROM node:22-bullseye-slim
WORKDIR /app

View File

@@ -28,9 +28,6 @@ import AgentPreview from './AgentPreview';
import { Agent, ToolSummary } from './types';
import type { Model } from '../models/types';
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
'huggingface_sentence-transformers/all-mpnet-base-v2';
export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const { t } = useTranslation();
@@ -548,22 +545,20 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
} else {
// Single source selected - maintain backward compatibility
const selectedSource = selectedSources[0];
if (selectedSource?.model === embeddingsName) {
if (selectedSource && 'id' in selectedSource) {
setAgent((prev) => ({
...prev,
source: selectedSource?.id || 'default',
sources: [], // Clear sources array for single source
retriever: '',
}));
} else {
setAgent((prev) => ({
...prev,
source: '',
sources: [], // Clear sources array
retriever: selectedSource?.retriever || 'classic',
}));
}
if (selectedSource && 'id' in selectedSource) {
setAgent((prev) => ({
...prev,
source: selectedSource?.id || 'default',
sources: [], // Clear sources array for single source
retriever: '',
}));
} else {
setAgent((prev) => ({
...prev,
source: '',
sources: [], // Clear sources array
retriever: selectedSource?.retriever || 'classic',
}));
}
}
} else {

View File

@@ -40,10 +40,6 @@ export default function SourcesPopup({
showAbove: false,
});
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
'huggingface_sentence-transformers/all-mpnet-base-v2';
const options = useSelector(selectSourceDocs);
const selectedDocs = useSelector(selectSelectedDocs);
@@ -147,70 +143,65 @@ export default function SourcesPopup({
{options ? (
<>
{filteredOptions?.map((option: any, index: number) => {
if (option.model === embeddingsName) {
const isSelected =
selectedDocs &&
Array.isArray(selectedDocs) &&
selectedDocs.length > 0 &&
selectedDocs.some((doc) =>
option.id
? doc.id === option.id
: doc.date === option.date,
);
return (
<div
key={index}
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
onClick={() => {
if (isSelected) {
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? selectedDocs.filter((doc) =>
option.id
? doc.id !== option.id
: doc.date !== option.date,
)
: [];
dispatch(setSelectedDocs(updatedDocs));
handlePostDocumentSelect(
updatedDocs.length > 0 ? updatedDocs : null,
);
} else {
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? [...selectedDocs, option]
: [option];
dispatch(setSelectedDocs(updatedDocs));
handlePostDocumentSelect(updatedDocs);
}
}}
>
<img
src={SourceIcon}
alt="Source"
width={14}
height={14}
className="mr-3 shrink-0"
/>
<span className="dark:text-bright-gray mr-3 grow overflow-hidden font-medium text-ellipsis whitespace-nowrap text-[#5D5D5D]">
{option.name}
</span>
<div
className={`flex h-4 w-4 shrink-0 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
>
{isSelected && (
<img
src={CheckIcon}
alt="Selected"
className="h-3 w-3"
/>
)}
</div>
</div>
const isSelected =
selectedDocs &&
Array.isArray(selectedDocs) &&
selectedDocs.length > 0 &&
selectedDocs.some((doc) =>
option.id ? doc.id === option.id : doc.date === option.date,
);
}
return null;
return (
<div
key={index}
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
onClick={() => {
if (isSelected) {
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? selectedDocs.filter((doc) =>
option.id
? doc.id !== option.id
: doc.date !== option.date,
)
: [];
dispatch(setSelectedDocs(updatedDocs));
handlePostDocumentSelect(
updatedDocs.length > 0 ? updatedDocs : null,
);
} else {
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? [...selectedDocs, option]
: [option];
dispatch(setSelectedDocs(updatedDocs));
handlePostDocumentSelect(updatedDocs);
}
}}
>
<img
src={SourceIcon}
alt="Source"
width={14}
height={14}
className="mr-3 shrink-0"
/>
<span className="dark:text-bright-gray mr-3 grow overflow-hidden font-medium text-ellipsis whitespace-nowrap text-[#5D5D5D]">
{option.name}
</span>
<div
className={`flex h-4 w-4 shrink-0 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
>
{isSelected && (
<img
src={CheckIcon}
alt="Selected"
className="h-3 w-3"
/>
)}
</div>
</div>
);
})}
</>
) : (

View File

@@ -16,11 +16,6 @@ import {
} from '../preferences/preferenceSlice';
import WrapperModal from './WrapperModal';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
'huggingface_sentence-transformers/all-mpnet-base-v2';
type StatusType = 'loading' | 'idle' | 'fetched' | 'failed';
export const ShareConversationModal = ({
@@ -47,14 +42,12 @@ export const ShareConversationModal = ({
const extractDocPaths = (docs: Doc[]) =>
docs
? docs
.filter((doc) => doc.model === embeddingsName)
.map((doc: Doc) => {
return {
label: doc.name,
value: doc.id ?? 'default',
};
})
? docs.map((doc: Doc) => {
return {
label: doc.name,
value: doc.id ?? 'default',
};
})
: [];
const [sourcePath, setSourcePath] = useState<{