mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@@ -9,13 +10,27 @@ from application.core.settings import settings
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
|
||||
try:
|
||||
kwargs.setdefault("trust_remote_code", True)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if self.model is None or self.model._first_module() is None:
|
||||
raise ValueError(
|
||||
f"SentenceTransformer model failed to load properly for: {model_name}"
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
@@ -117,15 +132,29 @@ class BaseVectorStore(ABC):
|
||||
embeddings_name, openai_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
||||
possible_paths = [
|
||||
"/app/models/all-mpnet-base-v2", # Docker absolute path
|
||||
"./models/all-mpnet-base-v2", # Relative path
|
||||
]
|
||||
local_model_path = None
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
local_model_path = path
|
||||
logging.info(f"Found local model at path: {path}")
|
||||
break
|
||||
else:
|
||||
logging.info(f"Path does not exist: {path}")
|
||||
if local_model_path:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name="./models/all-mpnet-base-v2",
|
||||
local_model_path,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download."
|
||||
)
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||
|
||||
return embedding_instance
|
||||
|
||||
Reference in New Issue
Block a user