mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-01 09:33:14 +00:00
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
|
||||
@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields, get_gpt_model
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +32,7 @@ class BaseAnswerResource:
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
@@ -54,7 +59,6 @@ class BaseAnswerResource:
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
@@ -62,7 +66,6 @@ class BaseAnswerResource:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
@@ -110,15 +113,12 @@ class BaseAnswerResource:
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
@@ -138,7 +138,6 @@ class BaseAnswerResource:
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
@@ -155,6 +154,7 @@ class BaseAnswerResource:
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -173,6 +173,7 @@ class BaseAnswerResource:
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
@@ -220,7 +221,6 @@ class BaseAnswerResource:
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -230,15 +230,22 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -250,7 +257,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
self.gpt_model,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -280,12 +287,11 @@ class BaseAnswerResource:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
@@ -293,6 +299,7 @@ class BaseAnswerResource:
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
@@ -312,7 +319,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
self.gpt_model,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -369,7 +376,7 @@ class BaseAnswerResource:
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"]
|
||||
return None, None, None, None, event["error"], None
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -377,8 +384,7 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
@@ -390,7 +396,6 @@ class BaseAnswerResource:
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
agent_id=data.get("agent_id"),
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ class ConversationService:
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
gpt_model: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
@@ -66,7 +66,7 @@ class ConversationService:
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
@@ -90,6 +90,7 @@ class ConversationService:
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -120,6 +121,7 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -146,7 +148,7 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||
model=model_id, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
conversation_data = {
|
||||
@@ -162,6 +164,7 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
get_gpt_model,
|
||||
limit_chat_history,
|
||||
)
|
||||
|
||||
@@ -83,7 +88,7 @@ class StreamProcessor:
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
@@ -91,6 +96,7 @@ class StreamProcessor:
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._validate_and_set_model()
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
@@ -112,7 +118,7 @@ class StreamProcessor:
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
@@ -143,6 +149,25 @@ class StreamProcessor:
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
@@ -322,7 +347,7 @@ class StreamProcessor:
|
||||
def _configure_retriever(self):
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
gpt_model=self.gpt_model, history_token_limit=history_token_limit
|
||||
model_id=self.model_id, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
@@ -344,7 +369,7 @@ class StreamProcessor:
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
gpt_model=self.gpt_model,
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
@@ -626,12 +651,19 @@ class StreamProcessor:
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
if self.model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=self.model_id,
|
||||
api_key=system_api_key,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
|
||||
@@ -95,6 +95,8 @@ class GetAgent(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
@@ -172,6 +174,8 @@ class GetAgents(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
if "models" in data:
|
||||
try:
|
||||
data["models"] = json.loads(data["models"])
|
||||
except json.JSONDecodeError:
|
||||
data["models"] = []
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
"models": data.get("models", []),
|
||||
"default_model_id": data.get("default_model_id", ""),
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
json_fields = ["tools", "sources", "json_schema"]
|
||||
json_fields = ["tools", "sources", "json_schema", "models"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
try:
|
||||
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"models",
|
||||
"default_model_id",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
|
||||
3
application/api/user/models/__init__.py
Normal file
3
application/api/user/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .routes import models_ns
|
||||
|
||||
__all__ = ["models_ns"]
|
||||
25
application/api/user/models/routes.py
Normal file
25
application/api/user/models/routes.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
@models_ns.route("/models")
|
||||
class ModelsListResource(Resource):
|
||||
def get(self):
|
||||
"""Get list of available models with their capabilities."""
|
||||
try:
|
||||
registry = ModelRegistry.get_instance()
|
||||
models = registry.get_enabled_models()
|
||||
|
||||
response = {
|
||||
"models": [model.to_dict() for model in models],
|
||||
"default_model_id": registry.default_model_id,
|
||||
"count": len(models),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
return make_response(jsonify(response), 200)
|
||||
@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
|
||||
from .analytics import analytics_ns
|
||||
from .attachments import attachments_ns
|
||||
from .conversations import conversations_ns
|
||||
from .models import models_ns
|
||||
from .prompts import prompts_ns
|
||||
from .sharing import sharing_ns
|
||||
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
|
||||
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
|
||||
# Conversations
|
||||
api.add_namespace(conversations_ns)
|
||||
|
||||
# Models
|
||||
api.add_namespace(models_ns)
|
||||
|
||||
# Agents (main, sharing, webhooks)
|
||||
api.add_namespace(agents_ns)
|
||||
api.add_namespace(agents_sharing_ns)
|
||||
|
||||
Reference in New Issue
Block a user