diff --git a/.env-template b/.env-template
index e93f0363..8b53112c 100644
--- a/.env-template
+++ b/.env-template
@@ -1,6 +1,12 @@
 API_KEY=
 LLM_NAME=docsgpt
 VITE_API_STREAMING=true
+INTERNAL_KEY=
+
+# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
+# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
+EMBEDDINGS_BASE_URL=
+EMBEDDINGS_KEY=
 
 #For Azure (you can delete it if you don't use Azure)
 OPENAI_API_BASE=
diff --git a/application/.env_sample b/application/.env_sample
deleted file mode 100644
index c08b2c1d..00000000
--- a/application/.env_sample
+++ /dev/null
@@ -1,12 +0,0 @@
-API_KEY=your_api_key
-EMBEDDINGS_KEY=your_api_key
-API_URL=http://localhost:7091
-INTERNAL_KEY=your_internal_key
-FLASK_APP=application/app.py
-FLASK_DEBUG=true
-
-#For OPENAI on Azure
-OPENAI_API_BASE=
-OPENAI_API_VERSION=
-AZURE_DEPLOYMENT_NAME=
-AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
\ No newline at end of file
diff --git a/application/celery_init.py b/application/celery_init.py
index 185cc87f..3e9c3c57 100644
--- a/application/celery_init.py
+++ b/application/celery_init.py
@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
 
 
 celery = make_celery()
+celery.config_from_object("application.celeryconfig")
diff --git a/application/celeryconfig.py b/application/celeryconfig.py
index 712b3bfc..5a33ee19 100644
--- a/application/celeryconfig.py
+++ b/application/celeryconfig.py
@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
 task_serializer = 'json'
 result_serializer = 'json'
 accept_content = ['json']
+
+# Autodiscover tasks
+imports = ('application.api.user.tasks',)
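Note on the two Celery changes above: `config_from_object` points the app at
application/celeryconfig.py, and the new `imports` tuple makes the worker import
the listed module at startup so the tasks defined there get registered (the
comment says "autodiscover", but `imports` is explicit registration rather than
Celery's autodiscover_tasks mechanism). A minimal sketch of the resulting
behavior, assuming Celery is installed and the package is importable:

    # The app reads application/celeryconfig.py; its `imports` tuple causes the
    # worker to import application.api.user.tasks at startup, registering any
    # task functions defined there.
    from celery import Celery

    app = Celery("docsgpt")
    app.config_from_object("application.celeryconfig")
    print(app.conf.imports)  # -> ('application.api.user.tasks',)
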
diff --git a/application/core/settings.py b/application/core/settings.py
index 12759d7f..f688df9b 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -2,7 +2,7 @@ import os
 from pathlib import Path
 from typing import Optional
 
-from pydantic_settings import BaseSettings
+from pydantic_settings import BaseSettings, SettingsConfigDict
 
 current_dir = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,12 +10,19 @@ current_dir = os.path.dirname(
 
 
 class Settings(BaseSettings):
+    model_config = SettingsConfigDict(extra="ignore")
+
     AUTH_TYPE: Optional[str] = None  # simple_jwt, session_jwt, or None
     LLM_PROVIDER: str = "docsgpt"
     LLM_NAME: Optional[str] = (
         None  # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
     )
     EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
+    EMBEDDINGS_BASE_URL: Optional[str] = None  # Remote embeddings API URL (OpenAI-compatible)
+    EMBEDDINGS_KEY: Optional[str] = (
+        None  # api key for embeddings (if using openai, just copy API_KEY)
+    )
+
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -73,9 +80,6 @@ class Settings(BaseSettings):
     GROQ_API_KEY: Optional[str] = None
     HUGGINGFACE_API_KEY: Optional[str] = None
 
-    EMBEDDINGS_KEY: Optional[str] = (
-        None  # api key for embeddings (if using openai, just copy API_KEY)
-    )
     OPENAI_API_BASE: Optional[str] = None  # azure openai api base url
     OPENAI_API_VERSION: Optional[str] = None  # azure openai api version
     AZURE_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for answering
@@ -153,5 +157,6 @@ class Settings(BaseSettings):
     COMPRESSION_MAX_HISTORY_POINTS: int = 3  # Keep only last N compression points to prevent DB bloat
 
 
-path = Path(__file__).parent.parent.absolute()
+# Project root is one level above application/
+path = Path(__file__).parent.parent.parent.absolute()
 settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
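Why `extra="ignore"` and the new `path`: settings are now loaded from the
repository-root .env (three `.parent` hops up from application/core/), which is
shared with docker-compose and the frontend and so may contain keys such as
VITE_API_STREAMING that have no matching field. Under pydantic-settings v2,
unmatched keys coming from an explicit `_env_file` fail validation unless extra
inputs are ignored. A small sketch of the intended behavior (the Demo class and
.env contents are illustrative):

    from typing import Optional

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class Demo(BaseSettings):
        # Without extra="ignore", a frontend-only key like
        # VITE_API_STREAMING=true in the same .env raises
        # "Extra inputs are not permitted".
        model_config = SettingsConfigDict(extra="ignore")

        EMBEDDINGS_BASE_URL: Optional[str] = None
        EMBEDDINGS_KEY: Optional[str] = None

    settings = Demo(_env_file=".env", _env_file_encoding="utf-8")
    print(settings.EMBEDDINGS_BASE_URL)
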
diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py
index 84839059..e5c65794 100644
--- a/application/vectorstore/base.py
+++ b/application/vectorstore/base.py
@@ -2,41 +2,79 @@ import logging
 import os
 from abc import ABC, abstractmethod
 
+import requests
 from langchain_openai import OpenAIEmbeddings
-from sentence_transformers import SentenceTransformer
 
 from application.core.settings import settings
 
 
-class EmbeddingsWrapper:
-    def __init__(self, model_name, *args, **kwargs):
-        logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
-        try:
-            kwargs.setdefault("trust_remote_code", True)
-            self.model = SentenceTransformer(
-                model_name,
-                config_kwargs={"allow_dangerous_deserialization": True},
-                *args,
-                **kwargs,
-            )
-            if self.model is None or self.model._first_module() is None:
+class RemoteEmbeddings:
+    """
+    Wrapper for remote embeddings API (OpenAI-compatible).
+    Used when EMBEDDINGS_BASE_URL is configured.
+    """
+
+    def __init__(self, api_url: str, model_name: str, api_key: str = None):
+        self.api_url = api_url.rstrip("/")
+        self.model_name = model_name
+        self.headers = {"Content-Type": "application/json"}
+        if api_key:
+            self.headers["Authorization"] = f"Bearer {api_key}"
+        self.dimension = None
+
+    def _embed(self, inputs):
+        """Send embedding request to remote API."""
+        payload = {"inputs": inputs}
+        if self.model_name:
+            payload["model"] = self.model_name
+
+        response = requests.post(
+            self.api_url, headers=self.headers, json=payload, timeout=180
+        )
+        response.raise_for_status()
+        result = response.json()
+
+        if isinstance(result, list):
+            if result and isinstance(result[0], list):
+                return result
+            elif result and all(isinstance(x, (int, float)) for x in result):
+                return [result]
+            elif not result:
+                return []
+            else:
                 raise ValueError(
-                    f"SentenceTransformer model failed to load properly for: {model_name}"
+                    f"Unexpected list content from remote embeddings API: {result}"
                 )
-            self.dimension = self.model.get_sentence_embedding_dimension()
-            logging.info(f"Successfully loaded model with dimension: {self.dimension}")
-        except Exception as e:
-            logging.error(
-                f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
-                exc_info=True,
+        elif isinstance(result, dict) and "error" in result:
+            raise ValueError(f"Remote embeddings API error: {result['error']}")
+        else:
+            raise ValueError(
+                f"Unexpected response format from remote embeddings API: {result}"
             )
-            raise
 
     def embed_query(self, query: str):
-        return self.model.encode(query).tolist()
+        """Embed a single query string."""
+        embeddings_list = self._embed(query)
+        if (
+            isinstance(embeddings_list, list)
+            and len(embeddings_list) == 1
+            and isinstance(embeddings_list[0], list)
+        ):
+            if self.dimension is None:
+                self.dimension = len(embeddings_list[0])
+            return embeddings_list[0]
+        raise ValueError(
+            f"Unexpected result structure after embedding query: {embeddings_list}"
+        )
 
     def embed_documents(self, documents: list):
-        return self.model.encode(documents).tolist()
+        """Embed a list of documents."""
+        if not documents:
+            return []
+        embeddings_list = self._embed(documents)
+        if self.dimension is None and embeddings_list:
+            self.dimension = len(embeddings_list[0])
+        return embeddings_list
 
     def __call__(self, text):
         if isinstance(text, str):
@@ -47,6 +85,13 @@ class EmbeddingsWrapper:
             raise ValueError("Input must be a string or a list of strings")
 
 
+def _get_embeddings_wrapper():
+    """Lazy import of EmbeddingsWrapper to avoid loading SentenceTransformer when using remote embeddings."""
+    from application.vectorstore.embeddings_local import EmbeddingsWrapper
+
+    return EmbeddingsWrapper
+
+
 class EmbeddingsSingleton:
     _instances = {}
@@ -60,8 +105,13 @@ class EmbeddingsSingleton:
 
     @staticmethod
     def _create_instance(embeddings_name, *args, **kwargs):
+        if embeddings_name == "openai_text-embedding-ada-002":
+            return OpenAIEmbeddings(*args, **kwargs)
+
+        # Lazy import EmbeddingsWrapper only when needed (avoids loading SentenceTransformer)
+        EmbeddingsWrapper = _get_embeddings_wrapper()
+
         embeddings_factory = {
-            "openai_text-embedding-ada-002": OpenAIEmbeddings,
             "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
                 "sentence-transformers/all-mpnet-base-v2"
             ),
@@ -121,6 +171,20 @@ class BaseVectorStore(ABC):
         )
 
     def _get_embeddings(self, embeddings_name, embeddings_key=None):
+        # Check for remote embeddings first
+        if settings.EMBEDDINGS_BASE_URL:
+            logging.info(
+                f"Using remote embeddings API at: {settings.EMBEDDINGS_BASE_URL}"
+            )
+            cache_key = f"remote_{settings.EMBEDDINGS_BASE_URL}_{embeddings_name}"
+            if cache_key not in EmbeddingsSingleton._instances:
+                EmbeddingsSingleton._instances[cache_key] = RemoteEmbeddings(
+                    api_url=settings.EMBEDDINGS_BASE_URL,
+                    model_name=embeddings_name,
+                    api_key=embeddings_key,
+                )
+            return EmbeddingsSingleton._instances[cache_key]
+
         if embeddings_name == "openai_text-embedding-ada-002":
             if self.is_azure_configured():
                 os.environ["OPENAI_API_TYPE"] = "azure"
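Usage sketch for the new RemoteEmbeddings wrapper (the URL below is a
placeholder; in the app it comes from settings.EMBEDDINGS_BASE_URL). Note that
`_embed` posts {"inputs": ...} and expects a bare JSON list of vectors back,
which matches TEI-style embedding servers; a strict OpenAI /v1/embeddings
backend, which takes "input" and wraps vectors in a "data" array, would need an
adapter despite the docstring's "OpenAI-compatible" wording:

    emb = RemoteEmbeddings(
        api_url="http://localhost:8080/embed",  # hypothetical endpoint
        model_name="huggingface_sentence-transformers/all-mpnet-base-v2",
        api_key=None,  # optional; sent as a Bearer token when set
    )
    vector = emb.embed_query("What is DocsGPT?")           # list[float]
    vectors = emb.embed_documents(["doc one", "doc two"])  # list[list[float]]
    print(emb.dimension)  # populated lazily from the first response

Because `_get_embeddings` keys the singleton cache on base URL plus model name,
repeated vector-store constructions reuse one wrapper instead of re-creating it.
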
diff --git a/application/vectorstore/embeddings_local.py b/application/vectorstore/embeddings_local.py
new file mode 100644
index 00000000..fb6e5f4f
--- /dev/null
+++ b/application/vectorstore/embeddings_local.py
@@ -0,0 +1,48 @@
+"""
+Local embeddings using SentenceTransformer.
+This module is only imported when EMBEDDINGS_BASE_URL is not set,
+to avoid loading SentenceTransformer into memory when using remote embeddings.
+"""
+
+import logging
+
+from sentence_transformers import SentenceTransformer
+
+
+class EmbeddingsWrapper:
+    def __init__(self, model_name, *args, **kwargs):
+        logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
+        try:
+            kwargs.setdefault("trust_remote_code", True)
+            self.model = SentenceTransformer(
+                model_name,
+                config_kwargs={"allow_dangerous_deserialization": True},
+                *args,
+                **kwargs,
+            )
+            if self.model is None or self.model._first_module() is None:
+                raise ValueError(
+                    f"SentenceTransformer model failed to load properly for: {model_name}"
+                )
+            self.dimension = self.model.get_sentence_embedding_dimension()
+            logging.info(f"Successfully loaded model with dimension: {self.dimension}")
+        except Exception as e:
+            logging.error(
+                f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
+                exc_info=True,
+            )
+            raise
+
+    def embed_query(self, query: str):
+        return self.model.encode(query).tolist()
+
+    def embed_documents(self, documents: list):
+        return self.model.encode(documents).tolist()
+
+    def __call__(self, text):
+        if isinstance(text, str):
+            return self.embed_query(text)
+        elif isinstance(text, list):
+            return self.embed_documents(text)
+        else:
+            raise ValueError("Input must be a string or a list of strings")
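The new module is a verbatim move of the old EmbeddingsWrapper out of base.py;
the memory win comes purely from import placement. A quick check of the two
paths (assuming this PR's package layout; all-mpnet-base-v2 embeds into 768
dimensions):

    import sys

    # Importing base.py no longer drags in sentence_transformers (and torch);
    # the heavy import only happens on the local-embeddings path.
    import application.vectorstore.base  # noqa: F401
    print("sentence_transformers" in sys.modules)  # False so far

    from application.vectorstore.embeddings_local import EmbeddingsWrapper

    local = EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2")
    print(local.dimension)                 # 768
    print(local.embed_query("hello")[:3])  # first few floats of the vector
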
diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml
index 7ca110aa..91c22409 100644
--- a/deployment/docker-compose.yaml
+++ b/deployment/docker-compose.yaml
@@ -16,9 +16,12 @@ services:
   backend:
     user: root
     build: ../application
+    env_file:
+      - ../.env
     environment:
       - API_KEY=$API_KEY
-      - EMBEDDINGS_KEY=$API_KEY
+      - EMBEDDINGS_KEY=$EMBEDDINGS_KEY
+      - EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
       - LLM_PROVIDER=$LLM_PROVIDER
       - LLM_NAME=$LLM_NAME
       - CELERY_BROKER_URL=redis://redis:6379/0
@@ -41,9 +44,12 @@ services:
     user: root
     build: ../application
     command: celery -A application.app.celery worker -l INFO -B
+    env_file:
+      - ../.env
     environment:
       - API_KEY=$API_KEY
-      - EMBEDDINGS_KEY=$API_KEY
+      - EMBEDDINGS_KEY=$EMBEDDINGS_KEY
+      - EMBEDDINGS_BASE_URL=$EMBEDDINGS_BASE_URL
       - LLM_PROVIDER=$LLM_PROVIDER
       - LLM_NAME=$LLM_NAME
       - CELERY_BROKER_URL=redis://redis:6379/0
diff --git a/frontend/Dockerfile b/frontend/Dockerfile
index 5ca21455..19574bf5 100644
--- a/frontend/Dockerfile
+++ b/frontend/Dockerfile
@@ -1,4 +1,4 @@
-FROM node:20.6.1-bullseye-slim
+FROM node:22-bullseye-slim
 
 WORKDIR /app
 
diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx
index 1c6b194b..c9605a26 100644
--- a/frontend/src/agents/NewAgent.tsx
+++ b/frontend/src/agents/NewAgent.tsx
@@ -28,9 +28,6 @@ import AgentPreview from './AgentPreview';
 import { Agent, ToolSummary } from './types';
 import type { Model } from '../models/types';
 
-const embeddingsName =
-  import.meta.env.VITE_EMBEDDINGS_NAME ||
-  'huggingface_sentence-transformers/all-mpnet-base-v2';
 
 export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
   const { t } = useTranslation();
@@ -548,22 +545,20 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
       } else {
         // Single source selected - maintain backward compatibility
         const selectedSource = selectedSources[0];
-        if (selectedSource?.model === embeddingsName) {
-          if (selectedSource && 'id' in selectedSource) {
-            setAgent((prev) => ({
-              ...prev,
-              source: selectedSource?.id || 'default',
-              sources: [], // Clear sources array for single source
-              retriever: '',
-            }));
-          } else {
-            setAgent((prev) => ({
-              ...prev,
-              source: '',
-              sources: [], // Clear sources array
-              retriever: selectedSource?.retriever || 'classic',
-            }));
-          }
+        if (selectedSource && 'id' in selectedSource) {
+          setAgent((prev) => ({
+            ...prev,
+            source: selectedSource?.id || 'default',
+            sources: [], // Clear sources array for single source
+            retriever: '',
+          }));
+        } else {
+          setAgent((prev) => ({
+            ...prev,
+            source: '',
+            sources: [], // Clear sources array
+            retriever: selectedSource?.retriever || 'classic',
+          }));
         }
       }
     } else {
diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx
index c1b4d045..8ae399ba 100644
--- a/frontend/src/components/SourcesPopup.tsx
+++ b/frontend/src/components/SourcesPopup.tsx
@@ -40,10 +40,6 @@ export default function SourcesPopup({
     showAbove: false,
   });
 
-  const embeddingsName =
-    import.meta.env.VITE_EMBEDDINGS_NAME ||
-    'huggingface_sentence-transformers/all-mpnet-base-v2';
-
   const options = useSelector(selectSourceDocs);
 
   const selectedDocs = useSelector(selectSelectedDocs);
@@ -147,70 +143,65 @@
           {options ? (
             <>
               {filteredOptions?.map((option: any, index: number) => {
-                if (option.model === embeddingsName) {
-                  const isSelected =
-                    selectedDocs &&
-                    Array.isArray(selectedDocs) &&
-                    selectedDocs.length > 0 &&
-                    selectedDocs.some((doc) =>
-                      option.id
-                        ? doc.id === option.id
-                        : doc.date === option.date,
-                    );
-
-                  return (
-                    <div
-                      onClick={() => {
-                        if (isSelected) {
-                          const updatedDocs =
-                            selectedDocs && Array.isArray(selectedDocs)
-                              ? selectedDocs.filter((doc) =>
-                                  option.id
-                                    ? doc.id !== option.id
-                                    : doc.date !== option.date,
-                                )
-                              : [];
-                          dispatch(setSelectedDocs(updatedDocs));
-                          handlePostDocumentSelect(
-                            updatedDocs.length > 0 ? updatedDocs : null,
-                          );
-                        } else {
-                          const updatedDocs =
-                            selectedDocs && Array.isArray(selectedDocs)
-                              ? [...selectedDocs, option]
-                              : [option];
-                          dispatch(setSelectedDocs(updatedDocs));
-                          handlePostDocumentSelect(updatedDocs);
-                        }
-                      }}
-                    >
-                      Source
-                      {option.name}
-                      {isSelected && (
-                        Selected
-                      )}
-                    </div>
-                  );
-                }
-                return null;
+                const isSelected =
+                  selectedDocs &&
+                  Array.isArray(selectedDocs) &&
+                  selectedDocs.length > 0 &&
+                  selectedDocs.some((doc) =>
+                    option.id ? doc.id === option.id : doc.date === option.date,
+                  );
+
+                return (
+                  <div
+                    onClick={() => {
+                      if (isSelected) {
+                        const updatedDocs =
+                          selectedDocs && Array.isArray(selectedDocs)
+                            ? selectedDocs.filter((doc) =>
+                                option.id
+                                  ? doc.id !== option.id
+                                  : doc.date !== option.date,
+                              )
+                            : [];
+                        dispatch(setSelectedDocs(updatedDocs));
+                        handlePostDocumentSelect(
+                          updatedDocs.length > 0 ? updatedDocs : null,
+                        );
+                      } else {
+                        const updatedDocs =
+                          selectedDocs && Array.isArray(selectedDocs)
+                            ? [...selectedDocs, option]
+                            : [option];
+                        dispatch(setSelectedDocs(updatedDocs));
+                        handlePostDocumentSelect(updatedDocs);
+                      }
+                    }}
+                  >
+                    Source
+                    {option.name}
+                    {isSelected && (
+                      Selected
+                    )}
+                  </div>
+                );
               })}
             </>
           ) : (
diff --git a/frontend/src/modals/ShareConversationModal.tsx b/frontend/src/modals/ShareConversationModal.tsx
index 1dddef6f..921fd59f 100644
--- a/frontend/src/modals/ShareConversationModal.tsx
+++ b/frontend/src/modals/ShareConversationModal.tsx
@@ -16,11 +16,6 @@ import {
 } from '../preferences/preferenceSlice';
 import WrapperModal from './WrapperModal';
 
-const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
-const embeddingsName =
-  import.meta.env.VITE_EMBEDDINGS_NAME ||
-  'huggingface_sentence-transformers/all-mpnet-base-v2';
-
 type StatusType = 'loading' | 'idle' | 'fetched' | 'failed';
 
 export const ShareConversationModal = ({
@@ -47,14 +42,12 @@ export const ShareConversationModal = ({
   const extractDocPaths = (docs: Doc[]) =>
     docs
-      ? docs
-          .filter((doc) => doc.model === embeddingsName)
-          .map((doc: Doc) => {
-            return {
-              label: doc.name,
-              value: doc.id ?? 'default',
-            };
-          })
+      ? docs.map((doc: Doc) => {
+          return {
+            label: doc.name,
+            value: doc.id ?? 'default',
+          };
+        })
       : [];
 
   const [sourcePath, setSourcePath] = useState<{
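For an end-to-end smoke test of the remote path, any server that honors the
contract in base.py will do. A hypothetical stub (not part of this PR) that
serves the expected shape:

    # Throwaway test server: accepts {"inputs": str | [str]} and returns a bare
    # JSON list of embedding vectors, as RemoteEmbeddings._embed expects.
    from flask import Flask, jsonify, request
    from sentence_transformers import SentenceTransformer

    app = Flask(__name__)
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    @app.post("/embed")
    def embed():
        inputs = request.get_json()["inputs"]
        texts = [inputs] if isinstance(inputs, str) else inputs
        return jsonify(model.encode(texts).tolist())

    if __name__ == "__main__":
        app.run(port=8080)

With this running, set EMBEDDINGS_BASE_URL=http://localhost:8080/embed and the
backend and worker skip loading SentenceTransformer entirely.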