mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -95,6 +95,8 @@ class GetAgent(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
@@ -172,6 +174,8 @@ class GetAgents(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
if "models" in data:
|
||||
try:
|
||||
data["models"] = json.loads(data["models"])
|
||||
except json.JSONDecodeError:
|
||||
data["models"] = []
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
"models": data.get("models", []),
|
||||
"default_model_id": data.get("default_model_id", ""),
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
json_fields = ["tools", "sources", "json_schema"]
|
||||
json_fields = ["tools", "sources", "json_schema", "models"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
try:
|
||||
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"models",
|
||||
"default_model_id",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
|
||||
Reference in New Issue
Block a user