feat: model registry and capabilities for multi-provider support (#2158)

* feat: Implement model registry and capabilities for multi-provider support

- Added ModelRegistry to manage available models and their capabilities.
- Introduced ModelProvider enum for different LLM providers.
- Created ModelCapabilities dataclass to define model features.
- Implemented methods to load models based on API keys and settings.
- Added utility functions for model management in model_utils.py.
- Updated settings.py to include provider-specific API keys.
- Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry.
- Enhanced utility functions to handle token limits and model validation.
- Improved code structure and logging for better maintainability.

* feat: Add model selection feature with API integration and UI component

* feat: Add model selection and default model functionality in agent management

* test: Update assertions and formatting in stream processing tests

* refactor(llm): Standardize model identifier to model_id

* fix tests

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2025-11-14 16:43:19 +05:30
committed by GitHub
parent fbf7cf874b
commit 3f7de867cc
54 changed files with 1388 additions and 226 deletions

View File

@@ -95,6 +95,8 @@ class GetAgent(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -172,6 +174,8 @@ class GetAgents(Resource):
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if "models" in data:
try:
data["models"] = json.loads(data["models"])
except json.JSONDecodeError:
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
required=False,
description="Request limit for the agent in limited mode",
),
"models": fields.List(
fields.String,
required=False,
description="List of available model IDs for this agent",
),
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
},
)
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
json_fields = ["tools", "sources", "json_schema", "models"]
for field in json_fields:
if field in data and data[field]:
try:
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
"token_limit",
"limited_request_mode",
"request_limit",
"models",
"default_model_id",
]
for field in allowed_fields:

View File

@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
try:
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)

View File

@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)