feat: model registry and capabilities for multi-provider support (#2158)

* feat: Implement model registry and capabilities for multi-provider support

- Added ModelRegistry to manage available models and their capabilities.
- Introduced ModelProvider enum for different LLM providers.
- Created ModelCapabilities dataclass to define model features.
- Implemented methods to load models based on API keys and settings.
- Added utility functions for model management in model_utils.py.
- Updated settings.py to include provider-specific API keys.
- Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry.
- Enhanced utility functions to handle token limits and model validation.
- Improved code structure and logging for better maintainability.

* feat: Add model selection feature with API integration and UI component

* feat: Add model selection and default model functionality in agent management

* test: Update assertions and formatting in stream processing tests

* refactor(llm): Standardize model identifier to model_id

* fix tests

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2025-11-14 16:43:19 +05:30
committed by GitHub
parent fbf7cf874b
commit 3f7de867cc
54 changed files with 1388 additions and 226 deletions

View File

@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)

View File

@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.gpt_model = get_gpt_model()
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
@@ -54,7 +59,6 @@ class BaseAnswerResource:
api_key = agent_config.get("user_api_key")
if not api_key:
return None
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +66,6 @@ class BaseAnswerResource:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +113,12 @@ class BaseAnswerResource:
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
@@ -138,7 +138,6 @@ class BaseAnswerResource:
),
429,
)
return None
def complete_stream(
@@ -155,6 +154,7 @@ class BaseAnswerResource:
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@ class BaseAnswerResource:
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
@@ -220,7 +221,6 @@ class BaseAnswerResource:
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -230,15 +230,22 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
if should_save_conversation:
@@ -250,7 +257,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -280,12 +287,11 @@ class BaseAnswerResource:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# Clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
data = json.dumps({"type": "end"})
@@ -293,6 +299,7 @@ class BaseAnswerResource:
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -312,7 +319,7 @@ class BaseAnswerResource:
source_log_docs,
tool_calls,
llm,
self.gpt_model,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
@@ -369,7 +376,7 @@ class BaseAnswerResource:
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"]
return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +384,7 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
@@ -390,7 +396,6 @@ class BaseAnswerResource:
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):

View File

@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
),
mimetype="text/event-stream",
)

View File

@@ -52,7 +52,7 @@ class ConversationService:
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
gpt_model: str,
model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
@@ -66,7 +66,7 @@ class ConversationService:
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -90,6 +90,7 @@ class ConversationService:
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
}
},
)
@@ -120,6 +121,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
}
},
@@ -146,7 +148,7 @@ class ConversationService:
]
completion = llm.gen(
model=gpt_model, messages=messages_summary, max_tokens=30
model=model_id, messages=messages_summary, max_tokens=30
)
conversation_data = {
@@ -162,6 +164,7 @@ class ConversationService:
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}

View File

@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
get_gpt_model,
limit_chat_history,
)
@@ -83,7 +88,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
@@ -91,6 +96,7 @@ class StreamProcessor:
def initialize(self):
"""Initialize all required components for processing"""
self._validate_and_set_model()
self._configure_agent()
self._configure_source()
self._configure_retriever()
@@ -112,7 +118,7 @@ class StreamProcessor:
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _process_attachments(self):
@@ -143,6 +149,25 @@ class StreamProcessor:
)
return attachments
def _validate_and_set_model(self):
"""Validate and set model_id from request"""
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
if requested_model:
if not validate_model_id(requested_model):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
)
self.model_id = requested_model
else:
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
@@ -322,7 +347,7 @@ class StreamProcessor:
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
gpt_model=self.gpt_model, history_token_limit=history_token_limit
model_id=self.model_id, history_token_limit=history_token_limit
)
self.retriever_config = {
@@ -344,7 +369,7 @@ class StreamProcessor:
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
gpt_model=self.gpt_model,
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)
@@ -626,12 +651,19 @@ class StreamProcessor:
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
if self.model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,

View File

@@ -95,6 +95,8 @@ class GetAgent(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -172,6 +174,8 @@ class GetAgents(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if "models" in data:
try:
data["models"] = json.loads(data["models"])
except json.JSONDecodeError:
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
json_fields = ["tools", "sources", "json_schema", "models"]
for field in json_fields:
if field in data and data[field]:
try:
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
"token_limit",
"limited_request_mode",
"request_limit",
"models",
"default_model_id",
]
for field in allowed_fields:

View File

@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
try:
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)

View File

@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)