mirror of https://github.com/arc53/DocsGPT.git, synced 2026-02-16 19:20:53 +00:00
Merge branch 'main' of https://github.com/arc53/DocsGPT
@@ -140,28 +140,28 @@ class BaseAgent(ABC):
        tool_id, action_name, call_args = parser.parse_args(call)
        call_id = getattr(call, "id", None) or str(uuid.uuid4())

        # Check if parsing failed
        if tool_id is None or action_name is None:
            error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
            logger.error(error_message)

            tool_call_data = {
                "tool_name": "unknown",
                "call_id": call_id,
                "action_name": getattr(call, 'name', 'unknown'),
                "action_name": getattr(call, "name", "unknown"),
                "arguments": call_args or {},
                "result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
            }
            yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
            self.tool_calls.append(tool_call_data)
            return "Failed to parse tool call.", call_id

        # Check if tool_id exists in available tools
        if tool_id not in tools_dict:
            error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
            logger.error(error_message)

            # Return error result
            tool_call_data = {
                "tool_name": "unknown",
@@ -173,7 +173,7 @@ class BaseAgent(ABC):
            yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
            self.tool_calls.append(tool_call_data)
            return f"Tool with ID {tool_id} not found.", call_id

        tool_call_data = {
            "tool_name": tools_dict[tool_id]["name"],
            "call_id": call_id,
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
                    if tool_data["name"] == "api_tool"
                    else tool_data["config"]
                ),
                user_id=self.user,  # Pass user ID for MCP tools credential decryption
            )
            if tool_data["name"] == "api_tool":
                print(
405 application/agents/tools/mcp_tool.py Normal file
@@ -0,0 +1,405 @@
import json
import time
from typing import Any, Dict, List, Optional

import requests

from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials


_mcp_session_cache = {}


class MCPTool(Tool):
    """
    MCP Tool
    Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
    """

    def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
        """
        Initialize the MCP Tool with configuration.

        Args:
            config: Dictionary containing MCP server configuration:
                - server_url: URL of the remote MCP server
                - auth_type: Type of authentication (api_key, bearer, basic, none)
                - encrypted_credentials: Encrypted credentials (if available)
                - timeout: Request timeout in seconds (default: 30)
            user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
        """
        self.config = config
        self.server_url = config.get("server_url", "")
        self.auth_type = config.get("auth_type", "none")
        self.timeout = config.get("timeout", 30)

        self.auth_credentials = {}
        if config.get("encrypted_credentials") and user_id:
            self.auth_credentials = decrypt_credentials(
                config["encrypted_credentials"], user_id
            )
        else:
            self.auth_credentials = config.get("auth_credentials", {})
        self.available_tools = []
        self._session = requests.Session()
        self._mcp_session_id = None
        self._setup_authentication()
        self._cache_key = self._generate_cache_key()

    def _setup_authentication(self):
        """Setup authentication for the MCP server connection."""
        if self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
            if api_key:
                self._session.headers.update({header_name: api_key})
        elif self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            if token:
                self._session.headers.update({"Authorization": f"Bearer {token}"})
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            password = self.auth_credentials.get("password", "")
            if username and password:
                self._session.auth = (username, password)

    def _generate_cache_key(self) -> str:
        """Generate a unique cache key for this MCP server configuration."""
        auth_key = ""
        if self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
        elif self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            auth_key = f"basic:{username}"
        else:
            auth_key = "none"
        return f"{self.server_url}#{auth_key}"

    def _get_cached_session(self) -> Optional[str]:
        """Get cached session ID if available and not expired."""
        global _mcp_session_cache

        if self._cache_key in _mcp_session_cache:
            session_data = _mcp_session_cache[self._cache_key]
            if time.time() - session_data["created_at"] < 1800:
                return session_data["session_id"]
            else:
                del _mcp_session_cache[self._cache_key]
        return None

    def _cache_session(self, session_id: str):
        """Cache the session ID for reuse."""
        global _mcp_session_cache
        _mcp_session_cache[self._cache_key] = {
            "session_id": session_id,
            "created_at": time.time(),
        }

    def _initialize_mcp_connection(self) -> Dict:
        """
        Initialize MCP connection with the server, using cached session if available.

        Returns:
            Server capabilities and information
        """
        cached_session = self._get_cached_session()
        if cached_session:
            self._mcp_session_id = cached_session
            return {"cached": True}
        try:
            init_params = {
                "protocolVersion": "2024-11-05",
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
            }
            response = self._make_mcp_request("initialize", init_params)
            self._make_mcp_request("notifications/initialized")

            return response
        except Exception as e:
            return {"error": str(e), "fallback": True}

    def _ensure_valid_session(self):
        """Ensure we have a valid MCP session, reinitializing if needed."""
        if not self._mcp_session_id:
            self._initialize_mcp_connection()

    def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
        """
        Make an MCP protocol request to the server with automatic session recovery.

        Args:
            method: MCP method name (e.g., "tools/list", "tools/call")
            params: Parameters for the MCP method

        Returns:
            Response data as dictionary

        Raises:
            Exception: If request fails after retry
        """
        mcp_message = {"jsonrpc": "2.0", "method": method}

        if not method.startswith("notifications/"):
            mcp_message["id"] = 1
        if params:
            mcp_message["params"] = params
        return self._execute_mcp_request(mcp_message, method)

    def _execute_mcp_request(
        self, mcp_message: Dict, method: str, is_retry: bool = False
    ) -> Dict:
        """Execute MCP request with optional retry on session failure."""
        try:
            final_headers = self._session.headers.copy()
            final_headers.update(
                {
                    "Content-Type": "application/json",
                    "Accept": "application/json, text/event-stream",
                }
            )

            if self._mcp_session_id:
                final_headers["Mcp-Session-Id"] = self._mcp_session_id
            response = self._session.post(
                self.server_url.rstrip("/"),
                json=mcp_message,
                headers=final_headers,
                timeout=self.timeout,
            )

            if "mcp-session-id" in response.headers:
                self._mcp_session_id = response.headers["mcp-session-id"]
                self._cache_session(self._mcp_session_id)
            response.raise_for_status()

            if method.startswith("notifications/"):
                return {}
            response_text = response.text.strip()
            if response_text.startswith("event:") and "data:" in response_text:
                lines = response_text.split("\n")
                data_line = None
                for line in lines:
                    if line.startswith("data:"):
                        data_line = line[5:].strip()
                        break
                if data_line:
                    try:
                        result = json.loads(data_line)
                    except json.JSONDecodeError:
                        raise Exception(f"Invalid JSON in SSE data: {data_line}")
                else:
                    raise Exception(f"No data found in SSE response: {response_text}")
            else:
                try:
                    result = response.json()
                except json.JSONDecodeError:
                    raise Exception(f"Invalid JSON response: {response.text}")
            if "error" in result:
                error_msg = result["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                raise Exception(f"MCP server error: {error_msg}")
            return result.get("result", result)
        except requests.exceptions.RequestException as e:
            if not is_retry and self._should_retry_with_new_session(e):
                self._invalidate_and_refresh_session()
                return self._execute_mcp_request(mcp_message, method, is_retry=True)
            raise Exception(f"MCP server request failed: {str(e)}")

    def _should_retry_with_new_session(self, error: Exception) -> bool:
        """Check if error indicates session invalidation and retry is warranted."""
        error_str = str(error).lower()
        return (
            any(
                indicator in error_str
                for indicator in [
                    "invalid session",
                    "session expired",
                    "unauthorized",
                    "401",
                    "403",
                ]
            )
            and self._mcp_session_id is not None
        )

    def _invalidate_and_refresh_session(self) -> None:
        """Invalidate current session and create a new one."""
        global _mcp_session_cache
        if self._cache_key in _mcp_session_cache:
            del _mcp_session_cache[self._cache_key]
        self._mcp_session_id = None
        self._initialize_mcp_connection()

    def discover_tools(self) -> List[Dict]:
        """
        Discover available tools from the MCP server using MCP protocol.

        Returns:
            List of tool definitions from the server
        """
        try:
            self._ensure_valid_session()

            response = self._make_mcp_request("tools/list")

            # Handle both formats: response with 'tools' key or response that IS the tools list
            if isinstance(response, dict):
                if "tools" in response:
                    self.available_tools = response["tools"]
                elif (
                    "result" in response
                    and isinstance(response["result"], dict)
                    and "tools" in response["result"]
                ):
                    self.available_tools = response["result"]["tools"]
                else:
                    self.available_tools = [response] if response else []
            elif isinstance(response, list):
                self.available_tools = response
            else:
                self.available_tools = []
            return self.available_tools
        except Exception as e:
            raise Exception(f"Failed to discover tools from MCP server: {str(e)}")

    def execute_action(self, action_name: str, **kwargs) -> Any:
        """
        Execute an action on the remote MCP server using MCP protocol.

        Args:
            action_name: Name of the action to execute
            **kwargs: Parameters for the action

        Returns:
            Result from the MCP server
        """
        self._ensure_valid_session()

        # Skipping empty/None values - letting the server use defaults
        cleaned_kwargs = {}
        for key, value in kwargs.items():
            if value == "" or value is None:
                continue
            cleaned_kwargs[key] = value
        call_params = {"name": action_name, "arguments": cleaned_kwargs}
        try:
            result = self._make_mcp_request("tools/call", call_params)
            return result
        except Exception as e:
            raise Exception(f"Failed to execute action '{action_name}': {str(e)}")

    def get_actions_metadata(self) -> List[Dict]:
        """
        Get metadata for all available actions.

        Returns:
            List of action metadata dictionaries
        """
        actions = []
        for tool in self.available_tools:
            input_schema = (
                tool.get("inputSchema")
                or tool.get("input_schema")
                or tool.get("schema")
                or tool.get("parameters")
            )

            parameters_schema = {
                "type": "object",
                "properties": {},
                "required": [],
            }

            if input_schema:
                if isinstance(input_schema, dict):
                    if "properties" in input_schema:
                        parameters_schema = {
                            "type": input_schema.get("type", "object"),
                            "properties": input_schema.get("properties", {}),
                            "required": input_schema.get("required", []),
                        }

                        for key in ["additionalProperties", "description"]:
                            if key in input_schema:
                                parameters_schema[key] = input_schema[key]
                    else:
                        parameters_schema["properties"] = input_schema
            action = {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                "parameters": parameters_schema,
            }
            actions.append(action)
        return actions

    def test_connection(self) -> Dict:
        """
        Test the connection to the MCP server and validate functionality.

        Returns:
            Dictionary with connection test results including tool count
        """
        try:
            self._mcp_session_id = None

            init_result = self._initialize_mcp_connection()

            tools = self.discover_tools()

            message = f"Successfully connected to MCP server. Found {len(tools)} tools."
            if init_result.get("cached"):
                message += " (Using cached session)"
            elif init_result.get("fallback"):
                message += " (No formal initialization required)"
            return {
                "success": True,
                "message": message,
                "tools_count": len(tools),
                "session_id": self._mcp_session_id,
                "tools": [tool.get("name", "unknown") for tool in tools[:5]],
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Connection failed: {str(e)}",
                "tools_count": 0,
                "error_type": type(e).__name__,
            }

    def get_config_requirements(self) -> Dict:
        return {
            "server_url": {
                "type": "string",
                "description": "URL of the remote MCP server (e.g., https://api.example.com)",
                "required": True,
            },
            "auth_type": {
                "type": "string",
                "description": "Authentication type",
                "enum": ["none", "api_key", "bearer", "basic"],
                "default": "none",
                "required": True,
            },
            "auth_credentials": {
                "type": "object",
                "description": "Authentication credentials (varies by auth_type)",
                "required": False,
            },
            "timeout": {
                "type": "integer",
                "description": "Request timeout in seconds",
                "default": 30,
                "minimum": 1,
                "maximum": 300,
                "required": False,
            },
        }
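For orientation, a minimal usage sketch of the class above. The server URL, token, and the assumption that the server's tool definitions carry a "name" field are illustrative, not part of the commit:

# Hypothetical MCP server config; encrypted credentials would also need user_id.
config = {
    "server_url": "https://mcp.example.com",   # placeholder URL
    "auth_type": "bearer",
    "auth_credentials": {"bearer_token": "example-token"},  # placeholder
    "timeout": 30,
}
tool = MCPTool(config)
print(tool.test_connection())      # initializes a session, reports tool count
tools = tool.discover_tools()      # MCP "tools/list"
if tools:
    # MCP "tools/call"; empty/None kwargs are stripped before sending
    result = tool.execute_action(tools[0]["name"])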
@@ -23,16 +23,23 @@ class ToolManager:
            tool_config = self.config.get(name, {})
            self.tools[name] = obj(tool_config)

    def load_tool(self, tool_name, tool_config):
    def load_tool(self, tool_name, tool_config, user_id=None):
        self.config[tool_name] = tool_config
        module = importlib.import_module(f"application.agents.tools.{tool_name}")
        for member_name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Tool) and obj is not Tool:
                return obj(tool_config)
                if tool_name == "mcp_tool" and user_id:
                    return obj(tool_config, user_id)
                else:
                    return obj(tool_config)

    def execute_action(self, tool_name, action_name, **kwargs):
    def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not loaded")
        if tool_name == "mcp_tool" and user_id:
            tool_config = self.config.get(tool_name, {})
            tool = self.load_tool(tool_name, tool_config, user_id)
            return tool.execute_action(action_name, **kwargs)
        return self.tools[tool_name].execute_action(action_name, **kwargs)

    def get_all_actions_metadata(self):
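A hedged sketch of how the new user_id plumbing might be exercised; the manager construction and the action name are assumed for illustration:

manager = ToolManager(config={})   # assumed constructor signature
mcp_config = {"server_url": "https://mcp.example.com", "auth_type": "none"}  # placeholder
# Passing user_id routes it into MCPTool so per-user encrypted credentials
# can be decrypted; tools other than "mcp_tool" are constructed as before.
tool = manager.load_tool("mcp_tool", mcp_config, user_id="user-123")
result = manager.execute_action(
    "mcp_tool", "example_action", user_id="user-123"  # hypothetical action
)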
@@ -69,11 +69,8 @@ class StreamProcessor:
            self.decoded_token.get("sub") if self.decoded_token is not None else None
        )
        self.conversation_id = self.data.get("conversation_id")
        self.source = (
            {"active_docs": self.data["active_docs"]}
            if "active_docs" in self.data
            else {}
        )
        self.source = {}
        self.all_sources = []
        self.attachments = []
        self.history = []
        self.agent_config = {}
@@ -85,6 +82,8 @@ class StreamProcessor:

    def initialize(self):
        """Initialize all required components for processing"""
        self._configure_agent()
        self._configure_source()
        self._configure_retriever()
        self._configure_agent()
        self._load_conversation_history()
@@ -171,13 +170,77 @@ class StreamProcessor:
        source = data.get("source")
        if isinstance(source, DBRef):
            source_doc = self.db.dereference(source)
            data["source"] = str(source_doc["_id"])
            data["retriever"] = source_doc.get("retriever", data.get("retriever"))
            data["chunks"] = source_doc.get("chunks", data.get("chunks"))
            if source_doc:
                data["source"] = str(source_doc["_id"])
                data["retriever"] = source_doc.get("retriever", data.get("retriever"))
                data["chunks"] = source_doc.get("chunks", data.get("chunks"))
            else:
                data["source"] = None
        elif source == "default":
            data["source"] = "default"
        else:
            data["source"] = None

        # Handle multiple sources
        sources = data.get("sources", [])
        if sources and isinstance(sources, list):
            sources_list = []
            for i, source_ref in enumerate(sources):
                if source_ref == "default":
                    processed_source = {
                        "id": "default",
                        "retriever": "classic",
                        "chunks": data.get("chunks", "2"),
                    }
                    sources_list.append(processed_source)
                elif isinstance(source_ref, DBRef):
                    source_doc = self.db.dereference(source_ref)
                    if source_doc:
                        processed_source = {
                            "id": str(source_doc["_id"]),
                            "retriever": source_doc.get("retriever", "classic"),
                            "chunks": source_doc.get("chunks", data.get("chunks", "2")),
                        }
                        sources_list.append(processed_source)
            data["sources"] = sources_list
        else:
            data["sources"] = []
        return data

    def _configure_source(self):
        """Configure the source based on agent data"""
        api_key = self.data.get("api_key") or self.agent_key

        if api_key:
            agent_data = self._get_data_from_api_key(api_key)

            if agent_data.get("sources") and len(agent_data["sources"]) > 0:
                source_ids = [
                    source["id"] for source in agent_data["sources"] if source.get("id")
                ]
                if source_ids:
                    self.source = {"active_docs": source_ids}
                else:
                    self.source = {}
                self.all_sources = agent_data["sources"]
            elif agent_data.get("source"):
                self.source = {"active_docs": agent_data["source"]}
                self.all_sources = [
                    {
                        "id": agent_data["source"],
                        "retriever": agent_data.get("retriever", "classic"),
                    }
                ]
            else:
                self.source = {}
                self.all_sources = []
            return
        if "active_docs" in self.data:
            self.source = {"active_docs": self.data["active_docs"]}
            return
        self.source = {}
        self.all_sources = []

    def _configure_agent(self):
        """Configure the agent based on request data"""
        agent_id = self.data.get("agent_id")
@@ -203,7 +266,13 @@ class StreamProcessor:
            if data_key.get("retriever"):
                self.retriever_config["retriever_name"] = data_key["retriever"]
            if data_key.get("chunks") is not None:
                self.retriever_config["chunks"] = data_key["chunks"]
                try:
                    self.retriever_config["chunks"] = int(data_key["chunks"])
                except (ValueError, TypeError):
                    logger.warning(
                        f"Invalid chunks value: {data_key['chunks']}, using default value 2"
                    )
                    self.retriever_config["chunks"] = 2
        elif self.agent_key:
            data_key = self._get_data_from_api_key(self.agent_key)
            self.agent_config.update(
@@ -224,7 +293,13 @@ class StreamProcessor:
            if data_key.get("retriever"):
                self.retriever_config["retriever_name"] = data_key["retriever"]
            if data_key.get("chunks") is not None:
                self.retriever_config["chunks"] = data_key["chunks"]
                try:
                    self.retriever_config["chunks"] = int(data_key["chunks"])
                except (ValueError, TypeError):
                    logger.warning(
                        f"Invalid chunks value: {data_key['chunks']}, using default value 2"
                    )
                    self.retriever_config["chunks"] = 2
        else:
            self.agent_config.update(
                {
@@ -243,7 +318,8 @@ class StreamProcessor:
            "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
        }

        if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
        api_key = self.data.get("api_key") or self.agent_key
        if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
            self.retriever_config["chunks"] = 0

    def create_agent(self):
@@ -1,6 +1,5 @@
import datetime
import json
import logging

from bson.objectid import ObjectId
File diff suppressed because it is too large
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
        "gpt-4o-mini": 128000,
        "gpt-3.5-turbo": 4096,
        "claude-2": 1e5,
        "gemini-2.0-flash-exp": 1e6,
        "gemini-2.5-flash": 1e6,
    }
    UPLOAD_FOLDER: str = "inputs"
    PARSE_PDF_AS_IMAGE: bool = False
@@ -96,7 +96,7 @@ class Settings(BaseSettings):
    QDRANT_HOST: Optional[str] = None
    QDRANT_PATH: Optional[str] = None
    QDRANT_DISTANCE_FUNC: str = "Cosine"

    # PGVector vectorstore config
    PGVECTOR_CONNECTION_STRING: Optional[str] = None
    # Milvus vectorstore config
@@ -116,6 +116,9 @@ class Settings(BaseSettings):

    JWT_SECRET_KEY: str = ""

    # Encryption settings
    ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"


path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
            raise

    def _clean_messages_google(self, messages):
        """Convert OpenAI format messages to Google AI format."""
        cleaned_messages = []
        for message in messages:
            role = message.get("role")
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):

            if role == "assistant":
                role = "model"
            elif role == "tool":
                role = "model"

            parts = []
            if role and content is not None:
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
                else:
                    raise ValueError(f"Unexpected content type: {type(content)}")

            cleaned_messages.append(types.Content(role=role, parts=parts))
            if parts:
                cleaned_messages.append(types.Content(role=role, parts=parts))

        return cleaned_messages

    def _clean_schema(self, schema_obj):
        """
        Recursively remove unsupported fields from schema objects
        and validate required properties.
        """
        if not isinstance(schema_obj, dict):
            return schema_obj
        allowed_fields = {
            "type",
            "description",
            "items",
            "properties",
            "required",
            "enum",
            "pattern",
            "minimum",
            "maximum",
            "nullable",
            "default",
        }

        cleaned = {}
        for key, value in schema_obj.items():
            if key not in allowed_fields:
                continue
            elif key == "type" and isinstance(value, str):
                cleaned[key] = value.upper()
            elif isinstance(value, dict):
                cleaned[key] = self._clean_schema(value)
            elif isinstance(value, list):
                cleaned[key] = [self._clean_schema(item) for item in value]
            else:
                cleaned[key] = value

        # Validate that required properties actually exist in properties
        if "required" in cleaned and "properties" in cleaned:
            valid_required = []
            properties_keys = set(cleaned["properties"].keys())
            for required_prop in cleaned["required"]:
                if required_prop in properties_keys:
                    valid_required.append(required_prop)
            if valid_required:
                cleaned["required"] = valid_required
            else:
                cleaned.pop("required", None)
        elif "required" in cleaned and "properties" not in cleaned:
            cleaned.pop("required", None)

        return cleaned

    def _clean_tools_format(self, tools_list):
        """Convert OpenAI format tools to Google AI format."""
        genai_tools = []
        for tool_data in tools_list:
            if tool_data["type"] == "function":
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
                properties = parameters.get("properties", {})

                if properties:
                    cleaned_properties = {}
                    for k, v in properties.items():
                        cleaned_properties[k] = self._clean_schema(v)

                    genai_function = dict(
                        name=function["name"],
                        description=function["description"],
                        parameters={
                            "type": "OBJECT",
                            "properties": {
                                k: {
                                    **v,
                                    "type": v["type"].upper() if v["type"] else None,
                                }
                                for k, v in properties.items()
                            },
                            "properties": cleaned_properties,
                            "required": (
                                parameters["required"]
                                if "required" in parameters
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
        response_schema=None,
        **kwargs,
    ):
        """Generate content using Google AI API without streaming."""
        client = genai.Client(api_key=self.api_key)
        if formatting == "openai":
            messages = self._clean_messages_google(messages)
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
        response_schema=None,
        **kwargs,
    ):
        """Generate content using Google AI API with streaming."""
        client = genai.Client(api_key=self.api_key)
        if formatting == "openai":
            messages = self._clean_messages_google(messages)
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
                yield chunk.text

    def _supports_tools(self):
        """Return whether this LLM supports function calling."""
        return True

    def _supports_structured_output(self):
        """Return whether this LLM supports structured JSON output."""
        return True

    def prepare_structured_output_format(self, json_schema):
        """Convert JSON schema to Google AI structured output format."""
        if not json_schema:
            return None
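To make the _clean_schema behavior above concrete, a small before/after sketch (the input schema is invented):

schema = {
    "type": "object",
    "title": "SearchArgs",                       # not in allowed_fields: dropped
    "properties": {"query": {"type": "string"}},
    "required": ["query", "limit"],              # "limit" has no property: pruned
}
# self._clean_schema(schema) would return:
# {"type": "OBJECT", "properties": {"query": {"type": "STRING"}}, "required": ["query"]}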
@@ -205,7 +205,6 @@ class LLMHandler(ABC):
                except StopIteration as e:
                    tool_response, call_id = e.value
                    break

                updated_messages.append(
                    {
                        "role": "assistant",
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
                )

                updated_messages.append(self.create_tool_message(call, tool_response))

            except Exception as e:
                logger.error(f"Error executing tool: {str(e)}", exc_info=True)
                updated_messages.append(
                    {
                        "role": "tool",
                        "content": f"Error executing tool: {str(e)}",
                        "tool_call_id": call.id,
                    }
                error_call = ToolCall(
                    id=call.id, name=call.name, arguments=call.arguments
                )
                error_response = f"Error executing tool: {str(e)}"
                error_message = self.create_tool_message(error_call, error_response)
                updated_messages.append(error_message)

                call_parts = call.name.split("_")
                if len(call_parts) >= 2:
                    tool_id = call_parts[-1]  # Last part is tool ID (e.g., "1")
                    action_name = "_".join(call_parts[:-1])
                    tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
                    full_action_name = f"{action_name}_{tool_id}"
                else:
                    tool_name = "unknown_tool"
                    action_name = call.name
                    full_action_name = call.name
                yield {
                    "type": "tool_call",
                    "data": {
                        "tool_name": tool_name,
                        "call_id": call.id,
                        "action_name": full_action_name,
                        "arguments": call.arguments,
                        "error": error_response,
                        "status": "error",
                    },
                }
        return updated_messages

    def handle_non_streaming(
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
            except StopIteration as e:
                messages = e.value
                break

        response = agent.llm.gen(
            model=agent.gpt_model, messages=messages, tools=agent.tools
        )
        parsed = self.parse_response(response)
        self.llm_calls.append(build_stack_data(agent.llm))

        return parsed.content

    def handle_streaming(
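The call-name convention used in the error path above can be illustrated with a short sketch (the tool name is invented):

# Tool-call names follow "<action>_<tool_id>"; the last "_" segment is the ID.
call_name = "search_web_1"
parts = call_name.split("_")
tool_id = parts[-1]                  # "1"
action_name = "_".join(parts[:-1])   # "search_web"
full_action_name = f"{action_name}_{tool_id}"  # "search_web_1"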
@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
                finish_reason="stop",
                raw_response=response,
            )

        if hasattr(response, "candidates"):
            parts = response.candidates[0].content.parts if response.candidates else []
            tool_calls = [
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
                finish_reason="tool_calls" if tool_calls else "stop",
                raw_response=response,
            )

        else:
            tool_calls = []
            if hasattr(response, "function_call"):
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):

    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
        """Create Google-style tool message."""
        from google.genai import types

        return {
            "role": "tool",
            "role": "model",
            "content": [
                types.Part.from_function_response(
                    name=tool_call.name, response={"result": result}
                ).to_json_dict()
                {
                    "function_response": {
                        "name": tool_call.name,
                        "response": {"result": result},
                    }
                }
            ],
        }
@@ -2,6 +2,7 @@ anthropic==0.49.0
boto3==1.38.18
beautifulsoup4==4.13.4
celery==5.4.0
cryptography==42.0.8
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==7.5.2
@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
    def __init__(self):
        pass

    @abstractmethod
    def gen(self, *args, **kwargs):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        pass
@@ -1,4 +1,5 @@
import logging

from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
@@ -20,10 +21,20 @@ class ClassicRAG(BaseRetriever):
        api_key=settings.API_KEY,
        decoded_token=None,
    ):
        self.original_question = ""
        """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
        self.original_question = source.get("question", "")
        self.chat_history = chat_history if chat_history is not None else []
        self.prompt = prompt
        self.chunks = chunks
        if isinstance(chunks, str):
            try:
                self.chunks = int(chunks)
            except ValueError:
                logging.warning(
                    f"Invalid chunks value '{chunks}', using default value 2"
                )
                self.chunks = 2
        else:
            self.chunks = chunks
        self.gpt_model = gpt_model
        self.token_limit = (
            token_limit
@@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever):
            user_api_key=self.user_api_key,
            decoded_token=decoded_token,
        )
        self.vectorstore = source["active_docs"] if "active_docs" in source else None

        if "active_docs" in source and source["active_docs"] is not None:
            if isinstance(source["active_docs"], list):
                self.vectorstores = source["active_docs"]
            else:
                self.vectorstores = [source["active_docs"]]
        else:
            self.vectorstores = []
        self.question = self._rephrase_query()
        self.decoded_token = decoded_token
        self._validate_vectorstore_config()

    def _validate_vectorstore_config(self):
        """Validate vectorstore IDs and remove any empty/invalid entries"""
        if not self.vectorstores:
            logging.warning("No vectorstores configured for retrieval")
            return
        invalid_ids = [
            vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
        ]
        if invalid_ids:
            logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
            self.vectorstores = [
                vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
            ]

    def _rephrase_query(self):
        """Rephrase user query with chat history context for better retrieval"""
        if (
            not self.original_question
            or not self.chat_history
            or self.chat_history == []
            or self.chunks == 0
            or self.vectorstore is None
            or not self.vectorstores
        ):
            return self.original_question

        prompt = f"""Given the following conversation history:

{self.chat_history}

Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""

        messages = [
@@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever):
        return self.original_question

    def _get_data(self):
        if self.chunks == 0 or self.vectorstore is None:
            docs = []
        else:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
            )
            docs_temp = docsearch.search(self.question, k=self.chunks)
            docs = [
                {
                    "title": i.metadata.get(
                        "title", i.metadata.get("post_title", i.page_content)
                    ).split("/")[-1],
                    "text": i.page_content,
                    "source": (
                        i.metadata.get("source")
                        if i.metadata.get("source")
                        else "local"
                    ),
                }
                for i in docs_temp
            ]
        """Retrieve relevant documents from configured vectorstores"""
        if self.chunks == 0 or not self.vectorstores:
            return []
        all_docs = []
        chunks_per_source = max(1, self.chunks // len(self.vectorstores))

        return docs
        for vectorstore_id in self.vectorstores:
            if vectorstore_id:
                try:
                    docsearch = VectorCreator.create_vectorstore(
                        settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
                    )
                    docs_temp = docsearch.search(self.question, k=chunks_per_source)

    def gen():
        pass
                    for doc in docs_temp:
                        if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                            page_content = doc.page_content
                            metadata = doc.metadata
                        else:
                            page_content = doc.get("text", doc.get("page_content", ""))
                            metadata = doc.get("metadata", {})
                        title = metadata.get(
                            "title", metadata.get("post_title", page_content)
                        )
                        if isinstance(title, str):
                            title = title.split("/")[-1]
                        else:
                            title = str(title).split("/")[-1]
                        all_docs.append(
                            {
                                "title": title,
                                "text": page_content,
                                "source": metadata.get("source") or vectorstore_id,
                            }
                        )
                except Exception as e:
                    logging.error(
                        f"Error searching vectorstore {vectorstore_id}: {e}",
                        exc_info=True,
                    )
                    continue
        return all_docs

    def search(self, query: str = ""):
        """Search for documents using optional query override"""
        if query:
            self.original_question = query
            self.question = self._rephrase_query()
        return self._get_data()

    def get_params(self):
        """Return current retriever configuration parameters"""
        return {
            "question": self.original_question,
            "rephrased_question": self.question,
            "source": self.vectorstore,
            "sources": self.vectorstores,
            "chunks": self.chunks,
            "token_limit": self.token_limit,
            "gpt_model": self.gpt_model,
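A quick arithmetic check on the per-source chunk split in _get_data above (the store IDs are placeholders):

chunks = 2
vectorstores = ["store_a", "store_b", "store_c"]
chunks_per_source = max(1, chunks // len(vectorstores))  # max(1, 0) -> 1
# Each store is searched with k=1, so up to three documents come back in total.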
0 application/security/__init__.py Normal file
85 application/security/encryption.py Normal file
@@ -0,0 +1,85 @@
import base64
import json
import os

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from application.core.settings import settings


def _derive_key(user_id: str, salt: bytes) -> bytes:
    app_secret = settings.ENCRYPTION_SECRET_KEY

    password = f"{app_secret}#{user_id}".encode()

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
        backend=default_backend(),
    )

    return kdf.derive(password)


def encrypt_credentials(credentials: dict, user_id: str) -> str:
    if not credentials:
        return ""
    try:
        salt = os.urandom(16)
        iv = os.urandom(16)
        key = _derive_key(user_id, salt)

        json_str = json.dumps(credentials)

        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
        encryptor = cipher.encryptor()

        padded_data = _pad_data(json_str.encode())
        encrypted_data = encryptor.update(padded_data) + encryptor.finalize()

        result = salt + iv + encrypted_data
        return base64.b64encode(result).decode()
    except Exception as e:
        print(f"Warning: Failed to encrypt credentials: {e}")
        return ""


def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
    if not encrypted_data:
        return {}
    try:
        data = base64.b64decode(encrypted_data.encode())

        salt = data[:16]
        iv = data[16:32]
        encrypted_content = data[32:]

        key = _derive_key(user_id, salt)

        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()

        decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
        decrypted_data = _unpad_data(decrypted_padded)

        return json.loads(decrypted_data.decode())
    except Exception as e:
        print(f"Warning: Failed to decrypt credentials: {e}")
        return {}


def _pad_data(data: bytes) -> bytes:
    block_size = 16
    padding_len = block_size - (len(data) % block_size)
    padding = bytes([padding_len]) * padding_len
    return data + padding


def _unpad_data(data: bytes) -> bytes:
    padding_len = data[-1]
    return data[:-padding_len]
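A round-trip sketch of the helpers above; the credentials and user IDs are placeholders. The returned blob is base64(salt || iv || AES-CBC ciphertext), and decryption re-derives the same PBKDF2 key from the embedded salt:

creds = {"api_key": "secret-123"}            # placeholder credentials
blob = encrypt_credentials(creds, "user-1")  # base64(salt + iv + ciphertext)
assert decrypt_credentials(blob, "user-1") == creds
# A different user derives a different key, so in practice decryption or JSON
# parsing fails and the helper falls back to returning {}.
assert decrypt_credentials(blob, "user-2") == {}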
@@ -1,20 +1,28 @@
from abc import ABC, abstractmethod
import os
from sentence_transformers import SentenceTransformer
from abc import ABC, abstractmethod

from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer

from application.core.settings import settings


class EmbeddingsWrapper:
    def __init__(self, model_name, *args, **kwargs):
        self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
        self.model = SentenceTransformer(
            model_name,
            config_kwargs={"allow_dangerous_deserialization": True},
            *args,
            **kwargs
        )
        self.dimension = self.model.get_sentence_embedding_dimension()

    def embed_query(self, query: str):
        return self.model.encode(query).tolist()

    def embed_documents(self, documents: list):
        return self.model.encode(documents).tolist()

    def __call__(self, text):
        if isinstance(text, str):
            return self.embed_query(text)
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
        raise ValueError("Input must be a string or a list of strings")


class EmbeddingsSingleton:
    _instances = {}

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        if embeddings_name not in EmbeddingsSingleton._instances:
            EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
                embeddings_name, *args, **kwargs
            EmbeddingsSingleton._instances[embeddings_name] = (
                EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
            )
        return EmbeddingsSingleton._instances[embeddings_name]

@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
    def _create_instance(embeddings_name, *args, **kwargs):
        embeddings_factory = {
            "openai_text-embedding-ada-002": OpenAIEmbeddings,
            "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
            "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
            "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
            "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
                "sentence-transformers/all-mpnet-base-v2"
            ),
            "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
                "sentence-transformers/all-mpnet-base-v2"
            ),
            "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
                "hkunlp/instructor-large"
            ),
        }

        if embeddings_name in embeddings_factory:
@@ -50,34 +63,63 @@ class EmbeddingsSingleton:
        else:
            return EmbeddingsWrapper(embeddings_name, *args, **kwargs)


class BaseVectorStore(ABC):
    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search for similar documents/chunks in the vectorstore"""
        pass

    @abstractmethod
    def add_texts(self, texts, metadatas=None, *args, **kwargs):
        """Add texts with their embeddings to the vectorstore"""
        pass

    def delete_index(self, *args, **kwargs):
        """Delete the entire index/collection"""
        pass

    def save_local(self, *args, **kwargs):
        """Save vectorstore to local storage"""
        pass

    def get_chunks(self, *args, **kwargs):
        """Get all chunks from the vectorstore"""
        pass

    def add_chunk(self, text, metadata=None, *args, **kwargs):
        """Add a single chunk to the vectorstore"""
        pass

    def delete_chunk(self, chunk_id, *args, **kwargs):
        """Delete a specific chunk from the vectorstore"""
        pass

    def is_azure_configured(self):
        return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
        return (
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        if embeddings_name == "openai_text-embedding-ada-002":
            if self.is_azure_configured():
                os.environ["OPENAI_API_TYPE"] = "azure"
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name,
                    model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                    embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                )
            else:
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name,
                    openai_api_key=embeddings_key
                    embeddings_name, openai_api_key=embeddings_key
                )
        elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
            if os.path.exists("./models/all-mpnet-base-v2"):
                embedding_instance = EmbeddingsSingleton.get_instance(
                    embeddings_name = "./models/all-mpnet-base-v2",
                    embeddings_name="./models/all-mpnet-base-v2",
                )
            else:
                embedding_instance = EmbeddingsSingleton.get_instance(
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
        embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)

        return embedding_instance
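For context, a minimal sketch of the wrapper and singleton above (assumes the sentence-transformers model can be downloaded):

emb = EmbeddingsSingleton.get_instance(
    "huggingface_sentence-transformers/all-mpnet-base-v2"
)
vec = emb.embed_query("What is DocsGPT?")    # list[float], length emb.dimension
docs = emb.embed_documents(["first text", "second text"])
assert len(vec) == emb.dimension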