mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Compare commits
31 Commits
fix/eslint
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7e167d2ec | ||
|
|
9e58eb02b3 | ||
|
|
3f7de867cc | ||
|
|
fbf7cf874b | ||
|
|
ba7278b80f | ||
|
|
9d649de6f9 | ||
|
|
7929afbf58 | ||
|
|
ceaf942e70 | ||
|
|
f355601a44 | ||
|
|
4ff99a1e86 | ||
|
|
129084ba92 | ||
|
|
2288df1293 | ||
|
|
d9dfac55e7 | ||
|
|
404cf4b7c7 | ||
|
|
f1c1fc123b | ||
|
|
9f19c7ee4c | ||
|
|
155e74eca1 | ||
|
|
ea2dc4dbcb | ||
|
|
616edc97de | ||
|
|
b017e99c79 | ||
|
|
f698e9d3e1 | ||
|
|
d366502850 | ||
|
|
3d6757c170 | ||
|
|
cb8302add8 | ||
|
|
9d266e9fad | ||
|
|
ae94c9d31e | ||
|
|
83ab232dcd | ||
|
|
eea85772a3 | ||
|
|
0fe7e223cc | ||
|
|
3789d2eb03 | ||
|
|
d54469532e |
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -13,7 +13,11 @@ updates:
|
||||
directory: "/frontend" # Location of package manifests
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/extensions/react-widget"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
interval: "daily"
|
||||
@@ -147,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
Thank you for considering contributing to DocsGPT! 🙏
|
||||
|
||||
## Questions/collaboration
|
||||
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
# Thank you so much for considering to contributing DocsGPT!🙏
|
||||
|
||||
@@ -32,7 +32,7 @@ Non-Code Contributions:
|
||||
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
|
||||
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
<a href="https://github.com/arc53/DocsGPT"></a>
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://discord.gg/vN7YFfdMpj"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentCreator:
|
||||
@@ -13,4 +16,5 @@ class AgentCreator:
|
||||
agent_class = cls.agents.get(type.lower())
|
||||
if not agent_class:
|
||||
raise ValueError(f"No agent class found for type {type}")
|
||||
|
||||
return agent_class(*args, **kwargs)
|
||||
|
||||
@@ -21,7 +21,7 @@ class BaseAgent(ABC):
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
gpt_model: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
@@ -37,7 +37,7 @@ class BaseAgent(ABC):
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
self.gpt_model = gpt_model
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
@@ -52,6 +52,7 @@ class BaseAgent(ABC):
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
)
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
@@ -316,7 +317,7 @@ class BaseAgent(ABC):
|
||||
return messages
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
gen_kwargs = {"model": self.gpt_model, "messages": messages}
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
|
||||
@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
|
||||
messages = [{"role": "user", "content": plan_prompt}]
|
||||
|
||||
plan_stream = self.llm.gen_stream(
|
||||
model=self.gpt_model,
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
)
|
||||
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
|
||||
messages = [{"role": "user", "content": final_prompt}]
|
||||
|
||||
final_stream = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=None
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
)
|
||||
|
||||
if log_context:
|
||||
|
||||
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
|
||||
@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields, get_gpt_model
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +32,7 @@ class BaseAnswerResource:
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
@@ -54,7 +59,6 @@ class BaseAnswerResource:
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
@@ -62,7 +66,6 @@ class BaseAnswerResource:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
@@ -110,15 +113,12 @@ class BaseAnswerResource:
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
@@ -138,7 +138,6 @@ class BaseAnswerResource:
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
@@ -155,6 +154,7 @@ class BaseAnswerResource:
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -173,6 +173,7 @@ class BaseAnswerResource:
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
@@ -220,7 +221,6 @@ class BaseAnswerResource:
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -230,15 +230,22 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -250,7 +257,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
self.gpt_model,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -280,12 +287,11 @@ class BaseAnswerResource:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
@@ -293,6 +299,7 @@ class BaseAnswerResource:
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
@@ -312,7 +319,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
self.gpt_model,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -369,7 +376,7 @@ class BaseAnswerResource:
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"]
|
||||
return None, None, None, None, event["error"], None
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -377,8 +384,7 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
@@ -390,7 +396,6 @@ class BaseAnswerResource:
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
agent_id=data.get("agent_id"),
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ class ConversationService:
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
gpt_model: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
@@ -66,7 +66,7 @@ class ConversationService:
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
@@ -90,6 +90,7 @@ class ConversationService:
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -120,6 +121,7 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -146,7 +148,7 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||
model=model_id, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
conversation_data = {
|
||||
@@ -162,6 +164,7 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
get_gpt_model,
|
||||
limit_chat_history,
|
||||
)
|
||||
|
||||
@@ -83,7 +88,7 @@ class StreamProcessor:
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
@@ -91,6 +96,7 @@ class StreamProcessor:
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._validate_and_set_model()
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
@@ -112,7 +118,7 @@ class StreamProcessor:
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
@@ -143,6 +149,25 @@ class StreamProcessor:
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
@@ -322,7 +347,7 @@ class StreamProcessor:
|
||||
def _configure_retriever(self):
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
gpt_model=self.gpt_model, history_token_limit=history_token_limit
|
||||
model_id=self.model_id, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
@@ -344,7 +369,7 @@ class StreamProcessor:
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
gpt_model=self.gpt_model,
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
@@ -626,12 +651,19 @@ class StreamProcessor:
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
if self.model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=self.model_id,
|
||||
api_key=system_api_key,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
|
||||
@@ -95,6 +95,8 @@ class GetAgent(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
@@ -172,6 +174,8 @@ class GetAgents(Resource):
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"models": agent.get("models", []),
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
if "models" in data:
|
||||
try:
|
||||
data["models"] = json.loads(data["models"])
|
||||
except json.JSONDecodeError:
|
||||
data["models"] = []
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
"models": data.get("models", []),
|
||||
"default_model_id": data.get("default_model_id", ""),
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
|
||||
required=False,
|
||||
description="Request limit for the agent in limited mode",
|
||||
),
|
||||
"models": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of available model IDs for this agent",
|
||||
),
|
||||
"default_model_id": fields.String(
|
||||
required=False, description="Default model ID for this agent"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
json_fields = ["tools", "sources", "json_schema"]
|
||||
json_fields = ["tools", "sources", "json_schema", "models"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
try:
|
||||
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"models",
|
||||
"default_model_id",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
|
||||
@@ -25,7 +25,7 @@ class StoreAttachment(Resource):
|
||||
api.model(
|
||||
"AttachmentModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="File to upload"),
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key (optional)"
|
||||
),
|
||||
@@ -33,18 +33,24 @@ class StoreAttachment(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
|
||||
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||
file = request.files.get("file")
|
||||
|
||||
if not file or file.filename == "":
|
||||
|
||||
files = request.files.getlist("file")
|
||||
if not files:
|
||||
single_file = request.files.get("file")
|
||||
if single_file:
|
||||
files = [single_file]
|
||||
|
||||
if not files or all(f.filename == "" for f in files):
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "Missing file"}),
|
||||
jsonify({"status": "error", "message": "Missing file(s)"}),
|
||||
400,
|
||||
)
|
||||
|
||||
user = None
|
||||
if decoded_token:
|
||||
user = safe_filename(decoded_token.get("sub"))
|
||||
@@ -59,32 +65,74 @@ class StoreAttachment(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
tasks = []
|
||||
errors = []
|
||||
original_file_count = len(files)
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task.id,
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
tasks.append({
|
||||
"task_id": task.id,
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
})
|
||||
except Exception as file_err:
|
||||
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
|
||||
errors.append({
|
||||
"filename": file.filename,
|
||||
"error": str(file_err)
|
||||
})
|
||||
|
||||
if not tasks:
|
||||
error_msg = "No valid files to upload"
|
||||
if errors:
|
||||
error_msg += f". Errors: {errors}"
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": error_msg, "errors": errors}),
|
||||
400,
|
||||
)
|
||||
|
||||
if original_file_count == 1 and len(tasks) == 1:
|
||||
current_app.logger.info("Returning single task_id response")
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
response_data = {
|
||||
"success": True,
|
||||
"tasks": tasks,
|
||||
"message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
|
||||
}
|
||||
if errors:
|
||||
response_data["errors"] = errors
|
||||
response_data["message"] += f" {len(errors)} file(s) failed."
|
||||
|
||||
return make_response(
|
||||
jsonify(response_data),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
@@ -130,15 +178,11 @@ class TextToSpeech(Resource):
|
||||
@api.expect(tts_model)
|
||||
@api.doc(description="Synthesize audio speech from text")
|
||||
def post(self):
|
||||
from application.utils import clean_text_for_tts
|
||||
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
cleaned_text = clean_text_for_tts(text)
|
||||
|
||||
try:
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(cleaned_text)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
|
||||
3
application/api/user/models/__init__.py
Normal file
3
application/api/user/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .routes import models_ns
|
||||
|
||||
__all__ = ["models_ns"]
|
||||
25
application/api/user/models/routes.py
Normal file
25
application/api/user/models/routes.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
@models_ns.route("/models")
|
||||
class ModelsListResource(Resource):
|
||||
def get(self):
|
||||
"""Get list of available models with their capabilities."""
|
||||
try:
|
||||
registry = ModelRegistry.get_instance()
|
||||
models = registry.get_enabled_models()
|
||||
|
||||
response = {
|
||||
"models": [model.to_dict() for model in models],
|
||||
"default_model_id": registry.default_model_id,
|
||||
"count": len(models),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
return make_response(jsonify(response), 200)
|
||||
@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
|
||||
from .analytics import analytics_ns
|
||||
from .attachments import attachments_ns
|
||||
from .conversations import conversations_ns
|
||||
from .models import models_ns
|
||||
from .prompts import prompts_ns
|
||||
from .sharing import sharing_ns
|
||||
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
|
||||
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
|
||||
# Conversations
|
||||
api.add_namespace(conversations_ns)
|
||||
|
||||
# Models
|
||||
api.add_namespace(models_ns)
|
||||
|
||||
# Agents (main, sharing, webhooks)
|
||||
api.add_namespace(agents_ns)
|
||||
api.add_namespace(agents_sharing_ns)
|
||||
|
||||
@@ -13,7 +13,6 @@ from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
db,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
@@ -97,9 +96,7 @@ class ShareConversation(Resource):
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
@@ -120,10 +117,7 @@ class ShareConversation(Resource):
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
@@ -154,10 +148,7 @@ class ShareConversation(Resource):
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
@@ -175,9 +166,7 @@ class ShareConversation(Resource):
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
@@ -197,10 +186,7 @@ class ShareConversation(Resource):
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
@@ -233,10 +219,12 @@ class GetPubliclySharedConversations(Resource):
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
and isinstance(shared["conversation_id"], DBRef)
|
||||
):
|
||||
conversation_ref = shared["conversation_id"]
|
||||
conversation = db.dereference(conversation_ref)
|
||||
# conversation_id is now stored as an ObjectId, not a DBRef
|
||||
conversation_id = shared["conversation_id"]
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": conversation_id}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
223
application/core/model_configs.py
Normal file
223
application/core/model_configs.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Model configurations for all supported LLM providers.
|
||||
"""
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
OPENAI_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
GOOGLE_ATTACHMENTS = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-4o",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4 Omni",
|
||||
description="Latest and most capable model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-4o-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4 Omni Mini",
|
||||
description="Fast and efficient",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-4-turbo",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4 Turbo",
|
||||
description="Fast GPT-4 with 128k context",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-4",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4",
|
||||
description="Most capable model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-3.5-turbo",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-3.5 Turbo",
|
||||
description="Fast and cost-effective",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=4096,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet-20241022",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet (Latest)",
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet",
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-opus",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Opus",
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-haiku",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Haiku",
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GOOGLE_MODELS = [
|
||||
AvailableModel(
|
||||
id="gemini-flash-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash (Latest)",
|
||||
description="Latest experimental Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-flash-lite-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash Lite (Latest)",
|
||||
description="Fast with huge context window",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-2.5-pro",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 2.5 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=2000000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROQ_MODELS = [
|
||||
AvailableModel(
|
||||
id="llama-3.3-70b-versatile",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.3 70B",
|
||||
description="Latest Llama model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="llama-3.1-8b-instant",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.1 8B",
|
||||
description="Ultra-fast inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="mixtral-8x7b-32768",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Mixtral 8x7B",
|
||||
description="High-speed inference with tools",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=32768,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
provider=ModelProvider.AZURE_OPENAI,
|
||||
display_name="Azure OpenAI GPT-4",
|
||||
description="Azure-hosted GPT model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
]
|
||||
236
application/core/model_settings.py
Normal file
236
application/core/model_settings.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
GROQ = "groq"
|
||||
GOOGLE = "google"
|
||||
HUGGINGFACE = "huggingface"
|
||||
LLAMA_CPP = "llama.cpp"
|
||||
DOCSGPT = "docsgpt"
|
||||
PREMAI = "premai"
|
||||
SAGEMAKER = "sagemaker"
|
||||
NOVITA = "novita"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
supports_tools: bool = False
|
||||
supports_structured_output: bool = False
|
||||
supports_streaming: bool = True
|
||||
supported_attachment_types: List[str] = field(default_factory=list)
|
||||
context_window: int = 128000
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AvailableModel:
|
||||
id: str
|
||||
provider: ModelProvider
|
||||
display_name: str
|
||||
description: str = ""
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
"supports_tools": self.capabilities.supports_tools,
|
||||
"supports_structured_output": self.capabilities.supports_structured_output,
|
||||
"supports_streaming": self.capabilities.supports_streaming,
|
||||
"context_window": self.capabilities.context_window,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
if self.base_url:
|
||||
result["base_url"] = self.base_url
|
||||
return result
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
self._add_docsgpt_models(settings)
|
||||
if settings.OPENAI_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openai" and settings.API_KEY
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
|
||||
if settings.LLM_NAME and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
elif settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
else:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
|
||||
for model in OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
91
application/core/model_utils.py
Normal file
91
application/core/model_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider"""
|
||||
from application.core.settings import settings
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all available models with metadata for API response"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
|
||||
|
||||
|
||||
def validate_model_id(model_id: str) -> bool:
|
||||
"""Check if a model ID exists in registry"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.model_exists(model_id)
|
||||
|
||||
|
||||
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get capabilities for a specific model"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return {
|
||||
"supported_attachment_types": model.capabilities.supported_attachment_types,
|
||||
"supports_tools": model.capabilities.supports_tools,
|
||||
"supports_structured_output": model.capabilities.supports_structured_output,
|
||||
"context_window": model.capabilities.context_window,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def get_default_model_id() -> str:
|
||||
"""Get the system default model ID"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.default_model_id
|
||||
|
||||
|
||||
def get_provider_from_model_id(model_id: str) -> Optional[str]:
|
||||
"""Get the provider name for a given model_id"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.provider.value
|
||||
return None
|
||||
|
||||
|
||||
def get_token_limit(model_id: str) -> int:
|
||||
"""
|
||||
Get context window (token limit) for a model.
|
||||
Returns model's context_window or default 128000 if model not found.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.capabilities.context_window
|
||||
return settings.DEFAULT_LLM_TOKEN_LIMIT
|
||||
|
||||
|
||||
def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the custom base_url for a specific model if configured.
|
||||
Returns None if no custom base_url is set.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
@@ -22,15 +22,7 @@ class Settings(BaseSettings):
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
LLM_TOKEN_LIMITS: dict = {
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"claude-2": int(1e5),
|
||||
"gemini-2.5-flash": int(1e6),
|
||||
}
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||
RESERVED_TOKENS: dict = {
|
||||
"system_prompt": 500,
|
||||
"current_query": 500,
|
||||
@@ -64,14 +56,22 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||
|
||||
API_KEY: Optional[str] = None # LLM api key
|
||||
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
|
||||
|
||||
# Provider-specific API keys (for multi-model support)
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
@@ -138,11 +138,12 @@ class Settings(BaseSettings):
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
# Tool pre-fetch settings
|
||||
ENABLE_TOOL_PREFETCH: bool = True
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -1,30 +1,41 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = (
|
||||
api_key or settings.ANTHROPIC_API_KEY
|
||||
) # If not provided, use a default from settings
|
||||
self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
self.anthropic = Anthropic(api_key=self.api_key)
|
||||
|
||||
# Use custom base_url if provided
|
||||
if base_url:
|
||||
self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
|
||||
else:
|
||||
self.anthropic = Anthropic(api_key=self.api_key)
|
||||
|
||||
self.HUMAN_PROMPT = HUMAN_PROMPT
|
||||
self.AI_PROMPT = AI_PROMPT
|
||||
|
||||
def _raw_gen(
|
||||
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
tools=None,
|
||||
max_tokens=300,
|
||||
**kwargs,
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||
if stream:
|
||||
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
|
||||
|
||||
completion = self.anthropic.completions.create(
|
||||
model=model,
|
||||
max_tokens_to_sample=max_tokens,
|
||||
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
|
||||
return completion.completion
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
tools=None,
|
||||
max_tokens=300,
|
||||
**kwargs,
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
|
||||
for completion in stream_response:
|
||||
yield completion.completion
|
||||
finally:
|
||||
if hasattr(stream_response, 'close'):
|
||||
if hasattr(stream_response, "close"):
|
||||
stream_response.close()
|
||||
|
||||
@@ -13,30 +13,32 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
|
||||
self.fallback_model_name = settings.FALLBACK_LLM_NAME
|
||||
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
|
||||
self._fallback_llm = None
|
||||
self._fallback_sequence_index = 0
|
||||
|
||||
@property
|
||||
def fallback_llm(self):
|
||||
"""Lazy-loaded fallback LLM instance."""
|
||||
if (
|
||||
self._fallback_llm is None
|
||||
and self.fallback_provider
|
||||
and self.fallback_model_name
|
||||
):
|
||||
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
|
||||
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
|
||||
try:
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
self.fallback_provider,
|
||||
self.fallback_llm_api_key,
|
||||
None,
|
||||
self.decoded_token,
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Unified method execution with fallback support.
|
||||
Execute method with fallback support.
|
||||
|
||||
Args:
|
||||
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
|
||||
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
raise
|
||||
logger.warning(
|
||||
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
|
||||
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
|
||||
class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
|
||||
self.api_key = "sk-docsgpt-public"
|
||||
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
|
||||
self.user_api_key = user_api_key
|
||||
self.api_key = api_key
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
cleaned_messages = []
|
||||
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
for line in response:
|
||||
if (
|
||||
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, 'close'):
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
return True
|
||||
|
||||
@@ -13,8 +13,9 @@ from application.storage.storage_creator import StorageCreator
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
@@ -47,21 +48,19 @@ class GoogleLLM(BaseLLM):
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
# Find the user message to attach files to the last one
|
||||
|
||||
user_message_index = None
|
||||
for i in range(len(prepared_messages) - 1, -1, -1):
|
||||
if prepared_messages[i].get("role") == "user":
|
||||
user_message_index = i
|
||||
break
|
||||
|
||||
if user_message_index is None:
|
||||
user_message = {"role": "user", "content": []}
|
||||
prepared_messages.append(user_message)
|
||||
user_message_index = len(prepared_messages) - 1
|
||||
|
||||
if isinstance(prepared_messages[user_message_index].get("content"), str):
|
||||
text_content = prepared_messages[user_message_index]["content"]
|
||||
prepared_messages[user_message_index]["content"] = [
|
||||
@@ -69,7 +68,6 @@ class GoogleLLM(BaseLLM):
|
||||
]
|
||||
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
@@ -92,11 +90,9 @@ class GoogleLLM(BaseLLM):
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _upload_file_to_google(self, attachment):
|
||||
@@ -111,14 +107,11 @@ class GoogleLLM(BaseLLM):
|
||||
"""
|
||||
if "google_file_uri" in attachment:
|
||||
return attachment["google_file_uri"]
|
||||
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
file_uri = self.storage.process_file(
|
||||
file_path,
|
||||
@@ -136,7 +129,6 @@ class GoogleLLM(BaseLLM):
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||
)
|
||||
|
||||
return file_uri
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
@@ -153,7 +145,6 @@ class GoogleLLM(BaseLLM):
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
@@ -164,6 +155,7 @@ class GoogleLLM(BaseLLM):
|
||||
parts.append(types.Part.from_text(text=item["text"]))
|
||||
elif "function_call" in item:
|
||||
# Remove null values from args to avoid API errors
|
||||
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
@@ -194,10 +186,8 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
@@ -233,8 +223,8 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
# Validate that required properties actually exist in properties
|
||||
|
||||
if "required" in cleaned and "properties" in cleaned:
|
||||
valid_required = []
|
||||
properties_keys = set(cleaned["properties"].keys())
|
||||
@@ -247,7 +237,6 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned.pop("required", None)
|
||||
elif "required" in cleaned and "properties" not in cleaned:
|
||||
cleaned.pop("required", None)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_tools_format(self, tools_list):
|
||||
@@ -263,7 +252,6 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned_properties = {}
|
||||
for k, v in properties.items():
|
||||
cleaned_properties[k] = self._clean_schema(v)
|
||||
|
||||
genai_function = dict(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
@@ -282,10 +270,8 @@ class GoogleLLM(BaseLLM):
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
)
|
||||
|
||||
genai_tool = types.Tool(function_declarations=[genai_function])
|
||||
genai_tools.append(genai_tool)
|
||||
|
||||
return genai_tools
|
||||
|
||||
def _raw_gen(
|
||||
@@ -307,16 +293,14 @@ class GoogleLLM(BaseLLM):
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
@@ -347,17 +331,16 @@ class GoogleLLM(BaseLLM):
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
# Check if we have both tools and file attachments
|
||||
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
@@ -366,7 +349,6 @@ class GoogleLLM(BaseLLM):
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
|
||||
logging.info(
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
)
|
||||
@@ -405,7 +387,6 @@ class GoogleLLM(BaseLLM):
|
||||
"""Convert JSON schema to Google AI structured output format."""
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
type_map = {
|
||||
"object": "OBJECT",
|
||||
"array": "ARRAY",
|
||||
@@ -418,12 +399,10 @@ class GoogleLLM(BaseLLM):
|
||||
def convert(schema):
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
result = {}
|
||||
schema_type = schema.get("type")
|
||||
if schema_type:
|
||||
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
|
||||
|
||||
for key in [
|
||||
"description",
|
||||
"nullable",
|
||||
@@ -435,7 +414,6 @@ class GoogleLLM(BaseLLM):
|
||||
]:
|
||||
if key in schema:
|
||||
result[key] = schema[key]
|
||||
|
||||
if "format" in schema:
|
||||
format_value = schema["format"]
|
||||
if schema_type == "string":
|
||||
@@ -445,21 +423,17 @@ class GoogleLLM(BaseLLM):
|
||||
result["format"] = format_value
|
||||
else:
|
||||
result["format"] = format_value
|
||||
|
||||
if "properties" in schema:
|
||||
result["properties"] = {
|
||||
k: convert(v) for k, v in schema["properties"].items()
|
||||
}
|
||||
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
|
||||
result["propertyOrdering"] = list(result["properties"].keys())
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert(schema["items"])
|
||||
|
||||
for field in ["anyOf", "oneOf", "allOf"]:
|
||||
if field in schema:
|
||||
result[field] = [convert(s) for s in schema[field]]
|
||||
|
||||
return result
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class GroqLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
|
||||
self.api_key = api_key
|
||||
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
|
||||
)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
|
||||
@@ -282,7 +282,7 @@ class LLMHandler(ABC):
|
||||
messages = e.value
|
||||
break
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
@@ -337,7 +337,7 @@ class LLMHandler(ABC):
|
||||
tool_calls = {}
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.openai import OpenAILLM, AzureOpenAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
import logging
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
@@ -26,10 +30,26 @@ class LLMCreator:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
|
||||
def create_llm(
|
||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
|
||||
# Extract base_url from model configuration if model_id is provided
|
||||
base_url = None
|
||||
if model_id:
|
||||
base_url = get_base_url_for_model(model_id)
|
||||
|
||||
return llm_class(
|
||||
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
base_url=base_url,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if (
|
||||
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
# Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
|
||||
effective_base_url = None
|
||||
if base_url and isinstance(base_url, str) and base_url.strip():
|
||||
effective_base_url = base_url
|
||||
elif (
|
||||
isinstance(settings.OPENAI_BASE_URL, str)
|
||||
and settings.OPENAI_BASE_URL.strip()
|
||||
):
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
effective_base_url = settings.OPENAI_BASE_URL
|
||||
else:
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
effective_base_url = "https://api.openai.com/v1"
|
||||
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _raw_gen(
|
||||
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
if tools:
|
||||
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
try:
|
||||
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
try:
|
||||
|
||||
def add_additional_properties_false(schema_obj):
|
||||
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
|
||||
if schema_copy.get("type") == "object":
|
||||
schema_copy["additionalProperties"] = False
|
||||
# Ensure 'required' includes all properties for OpenAI strict mode
|
||||
|
||||
if "properties" in schema_copy:
|
||||
schema_copy["required"] = list(
|
||||
schema_copy["properties"].keys()
|
||||
)
|
||||
|
||||
for key, value in schema_copy.items():
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
schema_copy[key] = {
|
||||
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
|
||||
add_additional_properties_false(sub_schema)
|
||||
for sub_schema in value
|
||||
]
|
||||
|
||||
return schema_copy
|
||||
return schema_obj
|
||||
|
||||
@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error preparing structured output format: {e}")
|
||||
return None
|
||||
@@ -277,21 +275,19 @@ class OpenAILLM(BaseLLM):
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
# Find the user message to attach file_id to the last one
|
||||
|
||||
user_message_index = None
|
||||
for i in range(len(prepared_messages) - 1, -1, -1):
|
||||
if prepared_messages[i].get("role") == "user":
|
||||
user_message_index = i
|
||||
break
|
||||
|
||||
if user_message_index is None:
|
||||
user_message = {"role": "user", "content": []}
|
||||
prepared_messages.append(user_message)
|
||||
user_message_index = len(prepared_messages) - 1
|
||||
|
||||
if isinstance(prepared_messages[user_message_index].get("content"), str):
|
||||
text_content = prepared_messages[user_message_index]["content"]
|
||||
prepared_messages[user_message_index]["content"] = [
|
||||
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
|
||||
]
|
||||
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
|
||||
}
|
||||
)
|
||||
# Handle PDFs using the file API
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
try:
|
||||
file_id = self._upload_file_to_openai(attachment)
|
||||
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
|
||||
"text": f"File content:\n\n{attachment['content']}",
|
||||
}
|
||||
)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _get_base64_image(self, attachment):
|
||||
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
try:
|
||||
with self.storage.get_file(file_path) as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if "openai_file_id" in attachment:
|
||||
return attachment["openai_file_id"]
|
||||
|
||||
file_path = attachment.get("path")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
file_id = self.storage.process_file(
|
||||
file_path,
|
||||
@@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM):
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||
)
|
||||
|
||||
return file_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
|
||||
|
||||
@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
|
||||
prompt="",
|
||||
chunks=2,
|
||||
doc_token_limit=50000,
|
||||
gpt_model="docsgpt",
|
||||
model_id="docsgpt-local",
|
||||
user_api_key=None,
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
@@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever):
|
||||
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
|
||||
f"sources={'active_docs' in source and source['active_docs'] is not None}"
|
||||
)
|
||||
self.gpt_model = gpt_model
|
||||
self.model_id = model_id
|
||||
self.doc_token_limit = doc_token_limit
|
||||
self.user_api_key = user_api_key
|
||||
self.llm_name = llm_name
|
||||
@@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever):
|
||||
]
|
||||
|
||||
try:
|
||||
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
|
||||
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
|
||||
print(f"Rephrased query: {rephrased_query}")
|
||||
return rephrased_query if rephrased_query else self.original_question
|
||||
except Exception as e:
|
||||
|
||||
@@ -7,6 +7,8 @@ import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@@ -75,11 +77,9 @@ def count_tokens_docs(docs):
|
||||
|
||||
|
||||
def calculate_doc_token_budget(
|
||||
gpt_model: str = "gpt-4o", history_token_limit: int = 2000
|
||||
model_id: str = "gpt-4o", history_token_limit: int = 2000
|
||||
) -> int:
|
||||
total_context = settings.LLM_TOKEN_LIMITS.get(
|
||||
gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
|
||||
)
|
||||
total_context = get_token_limit(model_id)
|
||||
reserved = sum(settings.RESERVED_TOKENS.values())
|
||||
doc_budget = total_context - history_token_limit - reserved
|
||||
return max(doc_budget, 1000)
|
||||
@@ -144,16 +144,13 @@ def get_hash(data):
|
||||
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
|
||||
"""Limit chat history to fit within token limit."""
|
||||
from application.core.settings import settings
|
||||
|
||||
model_token_limit = get_token_limit(model_id)
|
||||
max_token_limit = (
|
||||
max_token_limit
|
||||
if max_token_limit
|
||||
and max_token_limit
|
||||
< settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
|
||||
else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
|
||||
if max_token_limit and max_token_limit < model_token_limit
|
||||
else model_token_limit
|
||||
)
|
||||
|
||||
if not history:
|
||||
@@ -205,37 +202,44 @@ def clean_text_for_tts(text: str) -> str:
|
||||
clean text for Text-to-Speech processing.
|
||||
"""
|
||||
# Handle code blocks and links
|
||||
text = re.sub(r'```mermaid[\s\S]*?```', ' flowchart, ', text) ## ```mermaid...```
|
||||
text = re.sub(r'```[\s\S]*?```', ' code block, ', text) ## ```code```
|
||||
text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text) ## [text](url)
|
||||
text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text) ## 
|
||||
|
||||
text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text) ## ```mermaid...```
|
||||
text = re.sub(r"```[\s\S]*?```", " code block, ", text) ## ```code```
|
||||
text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) ## [text](url)
|
||||
text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text) ## 
|
||||
|
||||
# Remove markdown formatting
|
||||
text = re.sub(r'`([^`]+)`', r'\1', text) ## `code`
|
||||
text = re.sub(r'\{([^}]*)\}', r' \1 ', text) ## {text}
|
||||
text = re.sub(r'[{}]', ' ', text) ## unmatched {}
|
||||
text = re.sub(r'\[([^\]]+)\]', r' \1 ', text) ## [text]
|
||||
text = re.sub(r'[\[\]]', ' ', text) ## unmatched []
|
||||
text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text) ## **bold** __bold__
|
||||
text = re.sub(r'(\*|_)(.*?)\1', r'\2', text) ## *italic* _italic_
|
||||
text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE) ## # headers
|
||||
text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE) ## > blockquotes
|
||||
text = re.sub(r'^[\s]*[-\*\+]\s+', '', text, flags=re.MULTILINE) ## - * + lists
|
||||
text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE) ## 1. numbered lists
|
||||
text = re.sub(r'^[\*\-_]{3,}\s*$', '', text, flags=re.MULTILINE) ## --- *** ___ rules
|
||||
text = re.sub(r'<[^>]*>', '', text) ## <html> tags
|
||||
|
||||
#Remove non-ASCII (emojis, special Unicode)
|
||||
text = re.sub(r'[^\x20-\x7E\n\r\t]', '', text)
|
||||
text = re.sub(r"`([^`]+)`", r"\1", text) ## `code`
|
||||
text = re.sub(r"\{([^}]*)\}", r" \1 ", text) ## {text}
|
||||
text = re.sub(r"[{}]", " ", text) ## unmatched {}
|
||||
text = re.sub(r"\[([^\]]+)\]", r" \1 ", text) ## [text]
|
||||
text = re.sub(r"[\[\]]", " ", text) ## unmatched []
|
||||
text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text) ## **bold** __bold__
|
||||
text = re.sub(r"(\*|_)(.*?)\1", r"\2", text) ## *italic* _italic_
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) ## # headers
|
||||
text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE) ## > blockquotes
|
||||
text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE) ## - * + lists
|
||||
text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE) ## 1. numbered lists
|
||||
text = re.sub(
|
||||
r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
|
||||
) ## --- *** ___ rules
|
||||
text = re.sub(r"<[^>]*>", "", text) ## <html> tags
|
||||
|
||||
#Replace special sequences
|
||||
text = re.sub(r'-->', ', ', text) ## -->
|
||||
text = re.sub(r'<--', ', ', text) ## <--
|
||||
text = re.sub(r'=>', ', ', text) ## =>
|
||||
text = re.sub(r'::', ' ', text) ## ::
|
||||
# Remove non-ASCII (emojis, special Unicode)
|
||||
|
||||
#Normalize whitespace
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)
|
||||
|
||||
# Replace special sequences
|
||||
|
||||
text = re.sub(r"-->", ", ", text) ## -->
|
||||
text = re.sub(r"<--", ", ", text) ## <--
|
||||
text = re.sub(r"=>", ", ", text) ## =>
|
||||
text = re.sub(r"::", " ", text) ## ::
|
||||
|
||||
# Normalize whitespace
|
||||
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
|
||||
return text
|
||||
|
||||
@@ -165,7 +165,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
agent_type,
|
||||
endpoint="webhook",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=settings.LLM_NAME,
|
||||
model_id=settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
@@ -180,7 +180,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=settings.DEFAULT_MAX_HISTORY,
|
||||
gpt_model=settings.LLM_NAME,
|
||||
model_id=settings.LLM_NAME,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
5997
docs/package-lock.json
generated
5997
docs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -8,9 +8,9 @@
|
||||
"dependencies": {
|
||||
"@vercel/analytics": "^1.1.1",
|
||||
"docsgpt-react": "^0.5.1",
|
||||
"next": "^15.3.3",
|
||||
"nextra": "^2.13.2",
|
||||
"nextra-theme-docs": "^2.13.2",
|
||||
"next": "^15.5.6",
|
||||
"nextra": "^4.6.0",
|
||||
"nextra-theme-docs": "^4.6.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0"
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
|
||||
|
||||
* **4) Connect Cloud API Provider:** This option lets you connect DocsGPT to a commercial Cloud API provider such as OpenAI, Google (Vertex AI/Gemini), Anthropic (Claude), Groq, HuggingFace Inference API, or Azure OpenAI. You will need an API key from your chosen provider. Select this if you prefer to use a powerful cloud-based LLM.
|
||||
|
||||
* **5) Modify DocsGPT's source code and rebuild the Docker images locally. Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.
|
||||
* **5) Modify DocsGPT's source code and rebuild the Docker images locally.** Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.
|
||||
|
||||
After selecting an option and providing any required information (like API keys or model names), the script will configure your `.env` file and start DocsGPT using Docker Compose.
|
||||
|
||||
@@ -119,4 +119,4 @@ If you prefer a more manual approach, you can follow our [Docker Deployment docu
|
||||
|
||||
For more advanced customization of DocsGPT settings, such as configuring vector stores, embedding models, and other parameters, please refer to the [DocsGPT Settings documentation](/Deploying/DocsGPT-Settings). This guide explains how to modify the `.env` file or `settings.py` for deeper configuration.
|
||||
|
||||
Enjoy using DocsGPT!
|
||||
Enjoy using DocsGPT!
|
||||
|
||||
@@ -3,4 +3,4 @@ VITE_BASE_URL=http://localhost:5173
|
||||
VITE_API_HOST=http://127.0.0.1:7091
|
||||
VITE_API_STREAMING=true
|
||||
VITE_NOTIFICATION_TEXT="What's new in 0.14.0 — Changelog"
|
||||
VITE_NOTIFICATION_LINK="#"
|
||||
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-14-agents-automate-integrate-and-innovate/"
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
node_modules/
|
||||
dist/
|
||||
prettier.config.cjs
|
||||
.eslintrc.cjs
|
||||
env.d.ts
|
||||
public/
|
||||
assets/
|
||||
vite-env.d.ts
|
||||
.prettierignore
|
||||
package-lock.json
|
||||
package.json
|
||||
postcss.config.cjs
|
||||
prettier.config.cjs
|
||||
tailwind.config.cjs
|
||||
tsconfig.json
|
||||
tsconfig.node.json
|
||||
vite.config.ts
|
||||
@@ -1,45 +0,0 @@
|
||||
module.exports = {
|
||||
env: {
|
||||
browser: true,
|
||||
es2021: true,
|
||||
node: true,
|
||||
},
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
'plugin:@typescript-eslint/recommended',
|
||||
'plugin:react/recommended',
|
||||
'plugin:prettier/recommended',
|
||||
],
|
||||
overrides: [],
|
||||
parser: '@typescript-eslint/parser',
|
||||
parserOptions: {
|
||||
ecmaVersion: 'latest',
|
||||
sourceType: 'module',
|
||||
},
|
||||
plugins: ['react', 'unused-imports'],
|
||||
rules: {
|
||||
'react/prop-types': 'off',
|
||||
'unused-imports/no-unused-imports': 'error',
|
||||
'react/react-in-jsx-scope': 'off',
|
||||
'prettier/prettier': [
|
||||
'error',
|
||||
{
|
||||
endOfLine: 'auto',
|
||||
},
|
||||
],
|
||||
},
|
||||
settings: {
|
||||
'import/parsers': {
|
||||
'@typescript-eslint/parser': ['.ts', '.tsx'],
|
||||
},
|
||||
react: {
|
||||
version: 'detect',
|
||||
},
|
||||
'import/resolver': {
|
||||
node: {
|
||||
paths: ['src'],
|
||||
extensions: ['.js', '.jsx', '.ts', '.tsx'],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
78
frontend/eslint.config.js
Normal file
78
frontend/eslint.config.js
Normal file
@@ -0,0 +1,78 @@
|
||||
import js from '@eslint/js'
|
||||
import tsParser from '@typescript-eslint/parser'
|
||||
import tsPlugin from '@typescript-eslint/eslint-plugin'
|
||||
import react from 'eslint-plugin-react'
|
||||
import unusedImports from 'eslint-plugin-unused-imports'
|
||||
import prettier from 'eslint-plugin-prettier'
|
||||
import globals from 'globals'
|
||||
|
||||
export default [
|
||||
{
|
||||
ignores: [
|
||||
'node_modules/',
|
||||
'dist/',
|
||||
'prettier.config.cjs',
|
||||
'.eslintrc.cjs',
|
||||
'env.d.ts',
|
||||
'public/',
|
||||
'assets/',
|
||||
'vite-env.d.ts',
|
||||
'.prettierignore',
|
||||
'package-lock.json',
|
||||
'package.json',
|
||||
'postcss.config.cjs',
|
||||
'tailwind.config.cjs',
|
||||
'tsconfig.json',
|
||||
'tsconfig.node.json',
|
||||
'vite.config.ts',
|
||||
],
|
||||
},
|
||||
{
|
||||
files: ['**/*.{js,jsx,ts,tsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 'latest',
|
||||
sourceType: 'module',
|
||||
parser: tsParser,
|
||||
parserOptions: {
|
||||
ecmaFeatures: {
|
||||
jsx: true,
|
||||
},
|
||||
},
|
||||
globals: {
|
||||
...globals.browser,
|
||||
...globals.es2021,
|
||||
...globals.node,
|
||||
},
|
||||
},
|
||||
plugins: {
|
||||
'@typescript-eslint': tsPlugin,
|
||||
react,
|
||||
'unused-imports': unusedImports,
|
||||
prettier,
|
||||
},
|
||||
rules: {
|
||||
...js.configs.recommended.rules,
|
||||
...tsPlugin.configs.recommended.rules,
|
||||
...react.configs.recommended.rules,
|
||||
...prettier.configs.recommended.rules,
|
||||
'react/prop-types': 'off',
|
||||
'unused-imports/no-unused-imports': 'error',
|
||||
'react/react-in-jsx-scope': 'off',
|
||||
'no-undef': 'off',
|
||||
'@typescript-eslint/no-explicit-any': 'warn',
|
||||
'@typescript-eslint/no-unused-vars': 'warn',
|
||||
'@typescript-eslint/no-unused-expressions': 'warn',
|
||||
'prettier/prettier': [
|
||||
'error',
|
||||
{
|
||||
endOfLine: 'auto',
|
||||
},
|
||||
],
|
||||
},
|
||||
settings: {
|
||||
react: {
|
||||
version: 'detect',
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
1480
frontend/package-lock.json
generated
1480
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@
|
||||
]
|
||||
},
|
||||
"dependencies": {
|
||||
"@reduxjs/toolkit": "^2.8.2",
|
||||
"@reduxjs/toolkit": "^2.10.1",
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
@@ -33,7 +33,7 @@
|
||||
"react-dom": "^19.1.1",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-google-drive-picker": "^1.2.2",
|
||||
"react-i18next": "^15.4.0",
|
||||
"react-i18next": "^16.2.4",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-redux": "^9.2.0",
|
||||
"react-router-dom": "^7.6.1",
|
||||
@@ -46,30 +46,28 @@
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4.1.10",
|
||||
"@types/lodash": "^4.17.20",
|
||||
"@types/mermaid": "^9.1.0",
|
||||
"@types/react": "^19.1.8",
|
||||
"@types/react-dom": "^19.1.7",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@typescript-eslint/eslint-plugin": "^6.21.0",
|
||||
"@typescript-eslint/parser": "^6.21.0",
|
||||
"@typescript-eslint/eslint-plugin": "^8.46.3",
|
||||
"@typescript-eslint/parser": "^8.46.3",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint": "^9.39.1",
|
||||
"eslint-config-prettier": "^10.1.5",
|
||||
"eslint-config-standard-with-typescript": "^43.0.1",
|
||||
"eslint-plugin-import": "^2.31.0",
|
||||
"eslint-plugin-n": "^16.6.2",
|
||||
"eslint-plugin-n": "^17.23.1",
|
||||
"eslint-plugin-prettier": "^5.5.4",
|
||||
"eslint-plugin-promise": "^6.6.0",
|
||||
"eslint-plugin-react": "^7.37.5",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"husky": "^8.0.0",
|
||||
"husky": "^9.1.7",
|
||||
"lint-staged": "^15.3.0",
|
||||
"postcss": "^8.4.49",
|
||||
"prettier": "^3.5.3",
|
||||
"prettier-plugin-tailwindcss": "^0.7.1",
|
||||
"tailwindcss": "^4.1.11",
|
||||
"typescript": "^5.8.3",
|
||||
"vite": "^6.3.5",
|
||||
"vite": "^7.2.0",
|
||||
"vite-plugin-svgr": "^4.3.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||
import DropdownModel from './components/DropdownModel';
|
||||
|
||||
export default function Hero({
|
||||
handleQuestion,
|
||||
}: {
|
||||
@@ -26,6 +28,10 @@ export default function Hero({
|
||||
<span className="text-4xl font-semibold">DocsGPT</span>
|
||||
<img className="mb-1 inline w-14" src={DocsGPT3} alt="docsgpt" />
|
||||
</div>
|
||||
{/* Model Selector */}
|
||||
<div className="relative w-72">
|
||||
<DropdownModel />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Demo Buttons Section */}
|
||||
@@ -38,7 +44,7 @@ export default function Hero({
|
||||
<button
|
||||
key={key}
|
||||
onClick={() => handleQuestion({ question: demo.query })}
|
||||
className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
|
||||
className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''}`}
|
||||
>
|
||||
<p className="text-black-1000 dark:text-bright-gray mb-2 font-semibold">
|
||||
{demo.header}
|
||||
|
||||
@@ -567,7 +567,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<div className="flex items-center gap-1 pr-4">
|
||||
<NavLink
|
||||
target="_blank"
|
||||
to={'https://discord.gg/WHJdfbQDR4'}
|
||||
to={'https://discord.gg/vN7YFfdMpj'}
|
||||
className={
|
||||
'rounded-full hover:bg-gray-100 dark:hover:bg-[#28292E]'
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useNavigate, useParams } from 'react-router-dom';
|
||||
|
||||
import modelService from '../api/services/modelService';
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
import SourceIcon from '../assets/source.svg';
|
||||
@@ -26,6 +27,7 @@ import { UserToolType } from '../settings/types';
|
||||
import AgentPreview from './AgentPreview';
|
||||
import { Agent, ToolSummary } from './types';
|
||||
|
||||
import type { Model } from '../models/types';
|
||||
const embeddingsName =
|
||||
import.meta.env.VITE_EMBEDDINGS_NAME ||
|
||||
'huggingface_sentence-transformers/all-mpnet-base-v2';
|
||||
@@ -59,18 +61,25 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
token_limit: undefined,
|
||||
limited_request_mode: false,
|
||||
request_limit: undefined,
|
||||
models: [],
|
||||
default_model_id: '',
|
||||
});
|
||||
const [imageFile, setImageFile] = useState<File | null>(null);
|
||||
const [prompts, setPrompts] = useState<
|
||||
{ name: string; id: string; type: string }[]
|
||||
>([]);
|
||||
const [userTools, setUserTools] = useState<OptionType[]>([]);
|
||||
const [availableModels, setAvailableModels] = useState<Model[]>([]);
|
||||
const [isSourcePopupOpen, setIsSourcePopupOpen] = useState(false);
|
||||
const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
|
||||
const [isModelsPopupOpen, setIsModelsPopupOpen] = useState(false);
|
||||
const [selectedSourceIds, setSelectedSourceIds] = useState<
|
||||
Set<string | number>
|
||||
>(new Set());
|
||||
const [selectedTools, setSelectedTools] = useState<ToolSummary[]>([]);
|
||||
const [selectedModelIds, setSelectedModelIds] = useState<Set<string>>(
|
||||
new Set(),
|
||||
);
|
||||
const [deleteConfirmation, setDeleteConfirmation] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [agentDetails, setAgentDetails] = useState<ActiveState>('INACTIVE');
|
||||
@@ -86,6 +95,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const initialAgentRef = useRef<Agent | null>(null);
|
||||
const sourceAnchorButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const toolAnchorButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const modelAnchorButtonRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const modeConfig = {
|
||||
new: {
|
||||
@@ -224,6 +234,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('json_schema', JSON.stringify(agent.json_schema));
|
||||
}
|
||||
|
||||
if (agent.models && agent.models.length > 0) {
|
||||
formData.append('models', JSON.stringify(agent.models));
|
||||
}
|
||||
if (agent.default_model_id) {
|
||||
formData.append('default_model_id', agent.default_model_id);
|
||||
}
|
||||
|
||||
try {
|
||||
setDraftLoading(true);
|
||||
const response =
|
||||
@@ -320,6 +337,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('request_limit', '0');
|
||||
}
|
||||
|
||||
if (agent.models && agent.models.length > 0) {
|
||||
formData.append('models', JSON.stringify(agent.models));
|
||||
}
|
||||
if (agent.default_model_id) {
|
||||
formData.append('default_model_id', agent.default_model_id);
|
||||
}
|
||||
|
||||
try {
|
||||
setPublishLoading(true);
|
||||
const response =
|
||||
@@ -388,8 +412,16 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const data = await response.json();
|
||||
setPrompts(data);
|
||||
};
|
||||
const getModels = async () => {
|
||||
const response = await modelService.getModels(null);
|
||||
if (!response.ok) throw new Error('Failed to fetch models');
|
||||
const data = await response.json();
|
||||
const transformed = modelService.transformModels(data.models || []);
|
||||
setAvailableModels(transformed);
|
||||
};
|
||||
getTools();
|
||||
getPrompts();
|
||||
getModels();
|
||||
}, [token]);
|
||||
|
||||
// Auto-select default source if none selected
|
||||
@@ -462,6 +494,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}
|
||||
}, [agentId, mode, token]);
|
||||
|
||||
useEffect(() => {
|
||||
if (agent.models && agent.models.length > 0 && availableModels.length > 0) {
|
||||
const agentModelIds = new Set(agent.models);
|
||||
if (agentModelIds.size > 0 && selectedModelIds.size === 0) {
|
||||
setSelectedModelIds(agentModelIds);
|
||||
}
|
||||
}
|
||||
}, [agent.models, availableModels.length]);
|
||||
|
||||
useEffect(() => {
|
||||
const modelsArray = Array.from(selectedModelIds);
|
||||
if (modelsArray.length > 0) {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
models: modelsArray,
|
||||
default_model_id: modelsArray.includes(prev.default_model_id || '')
|
||||
? prev.default_model_id
|
||||
: modelsArray[0],
|
||||
}));
|
||||
} else {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
models: [],
|
||||
default_model_id: '',
|
||||
}));
|
||||
}
|
||||
}, [selectedModelIds]);
|
||||
|
||||
useEffect(() => {
|
||||
const selectedSources = Array.from(selectedSourceIds)
|
||||
.map((id) =>
|
||||
@@ -882,6 +942,82 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">
|
||||
{t('agents.form.sections.models')}
|
||||
</h2>
|
||||
<div className="mt-3 flex flex-col gap-3">
|
||||
<button
|
||||
ref={modelAnchorButtonRef}
|
||||
onClick={() => setIsModelsPopupOpen(!isModelsPopupOpen)}
|
||||
className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
|
||||
selectedModelIds.size > 0
|
||||
? 'text-jet dark:text-bright-gray'
|
||||
: 'dark:text-silver text-gray-400'
|
||||
}`}
|
||||
>
|
||||
{selectedModelIds.size > 0
|
||||
? availableModels
|
||||
.filter((m) => selectedModelIds.has(m.id))
|
||||
.map((m) => m.display_name)
|
||||
.join(', ')
|
||||
: t('agents.form.placeholders.selectModels')}
|
||||
</button>
|
||||
<MultiSelectPopup
|
||||
isOpen={isModelsPopupOpen}
|
||||
onClose={() => setIsModelsPopupOpen(false)}
|
||||
anchorRef={modelAnchorButtonRef}
|
||||
options={availableModels.map((model) => ({
|
||||
id: model.id,
|
||||
label: model.display_name,
|
||||
}))}
|
||||
selectedIds={selectedModelIds}
|
||||
onSelectionChange={(newSelectedIds: Set<string | number>) =>
|
||||
setSelectedModelIds(
|
||||
new Set(Array.from(newSelectedIds).map(String)),
|
||||
)
|
||||
}
|
||||
title={t('agents.form.modelsPopup.title')}
|
||||
searchPlaceholder={t(
|
||||
'agents.form.modelsPopup.searchPlaceholder',
|
||||
)}
|
||||
noOptionsMessage={t('agents.form.modelsPopup.noOptionsMessage')}
|
||||
/>
|
||||
{selectedModelIds.size > 0 && (
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium">
|
||||
{t('agents.form.labels.defaultModel')}
|
||||
</label>
|
||||
<Dropdown
|
||||
options={availableModels
|
||||
.filter((m) => selectedModelIds.has(m.id))
|
||||
.map((m) => ({
|
||||
label: m.display_name,
|
||||
value: m.id,
|
||||
}))}
|
||||
selectedValue={
|
||||
availableModels.find(
|
||||
(m) => m.id === agent.default_model_id,
|
||||
)?.display_name || null
|
||||
}
|
||||
onSelect={(option: { label: string; value: string }) =>
|
||||
setAgent({ ...agent, default_model_id: option.value })
|
||||
}
|
||||
size="w-full"
|
||||
rounded="3xl"
|
||||
border="border"
|
||||
buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
|
||||
optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
|
||||
placeholder={t(
|
||||
'agents.form.placeholders.selectDefaultModel',
|
||||
)}
|
||||
placeholderClassName="text-gray-400 dark:text-silver"
|
||||
contentSize="text-sm"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<button
|
||||
onClick={() =>
|
||||
|
||||
@@ -52,6 +52,10 @@ export const fetchPreviewAnswer = createAsyncThunk<
|
||||
}
|
||||
|
||||
if (state.preference) {
|
||||
const modelId =
|
||||
state.preference.selectedAgent?.default_model_id ||
|
||||
state.preference.selectedModel?.id;
|
||||
|
||||
if (API_STREAMING) {
|
||||
await handleFetchAnswerSteaming(
|
||||
question,
|
||||
@@ -120,22 +124,23 @@ export const fetchPreviewAnswer = createAsyncThunk<
|
||||
indx,
|
||||
state.preference.selectedAgent?.id,
|
||||
attachmentIds,
|
||||
false, // Don't save preview conversations
|
||||
false,
|
||||
modelId,
|
||||
);
|
||||
} else {
|
||||
// Non-streaming implementation
|
||||
const answer = await handleFetchAnswer(
|
||||
question,
|
||||
signal,
|
||||
state.preference.token,
|
||||
state.preference.selectedDocs,
|
||||
null, // No conversation ID for previews
|
||||
null,
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
state.preference.token_limit,
|
||||
state.preference.selectedAgent?.id,
|
||||
attachmentIds,
|
||||
false, // Don't save preview conversations
|
||||
false,
|
||||
modelId,
|
||||
);
|
||||
|
||||
if (answer) {
|
||||
|
||||
@@ -32,4 +32,6 @@ export type Agent = {
|
||||
token_limit?: number;
|
||||
limited_request_mode?: boolean;
|
||||
request_limit?: number;
|
||||
models?: string[];
|
||||
default_model_id?: string;
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ const endpoints = {
|
||||
USER: {
|
||||
CONFIG: '/api/config',
|
||||
NEW_TOKEN: '/api/generate_token',
|
||||
MODELS: '/api/models',
|
||||
DOCS: '/api/sources',
|
||||
DOCS_PAGINATED: '/api/sources/paginated',
|
||||
API_KEYS: '/api/get_api_keys',
|
||||
|
||||
25
frontend/src/api/services/modelService.ts
Normal file
25
frontend/src/api/services/modelService.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import apiClient from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
|
||||
import type { AvailableModel, Model } from '../../models/types';
|
||||
|
||||
const modelService = {
|
||||
getModels: (token: string | null): Promise<Response> =>
|
||||
apiClient.get(endpoints.USER.MODELS, token, {}),
|
||||
|
||||
transformModels: (models: AvailableModel[]): Model[] =>
|
||||
models.map((model) => ({
|
||||
id: model.id,
|
||||
value: model.id,
|
||||
provider: model.provider,
|
||||
display_name: model.display_name,
|
||||
description: model.description,
|
||||
context_window: model.context_window,
|
||||
supported_attachment_types: model.supported_attachment_types,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_structured_output: model.supports_structured_output,
|
||||
supports_streaming: model.supports_streaming,
|
||||
})),
|
||||
};
|
||||
|
||||
export default modelService;
|
||||
3
frontend/src/assets/rounded-tick.svg
Normal file
3
frontend/src/assets/rounded-tick.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10 0.75C4.62391 0.75 0.25 5.12391 0.25 10.5C0.25 15.8761 4.62391 20.25 10 20.25C15.3761 20.25 19.75 15.8761 19.75 10.5C19.75 5.12391 15.3761 0.75 10 0.75ZM15.0742 7.23234L8.77422 14.7323C8.70511 14.8147 8.61912 14.8812 8.52207 14.9273C8.42502 14.9735 8.31918 14.9983 8.21172 15H8.19906C8.09394 15 7.99 14.9778 7.89398 14.935C7.79797 14.8922 7.71202 14.8297 7.64172 14.7516L4.94172 11.7516C4.87315 11.6788 4.81981 11.5931 4.78483 11.4995C4.74986 11.4059 4.73395 11.3062 4.73805 11.2063C4.74215 11.1064 4.76617 11.0084 4.8087 10.9179C4.85124 10.8275 4.91142 10.7464 4.98572 10.6796C5.06002 10.6127 5.14694 10.5614 5.24136 10.5286C5.33579 10.4958 5.43581 10.4822 5.53556 10.4886C5.63531 10.495 5.73277 10.5213 5.82222 10.5659C5.91166 10.6106 5.99128 10.6726 6.05641 10.7484L8.17938 13.1072L13.9258 6.26766C14.0547 6.11863 14.237 6.02631 14.4335 6.01066C14.6299 5.99501 14.8246 6.05728 14.9754 6.18402C15.1263 6.31075 15.2212 6.49176 15.2397 6.68793C15.2582 6.8841 15.1988 7.07966 15.0742 7.23234Z" fill="#B5B5B5"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
138
frontend/src/components/DropdownModel.tsx
Normal file
138
frontend/src/components/DropdownModel.tsx
Normal file
@@ -0,0 +1,138 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import modelService from '../api/services/modelService';
|
||||
import Arrow2 from '../assets/dropdown-arrow.svg';
|
||||
import RoundedTick from '../assets/rounded-tick.svg';
|
||||
import {
|
||||
selectAvailableModels,
|
||||
selectSelectedModel,
|
||||
setAvailableModels,
|
||||
setModelsLoading,
|
||||
setSelectedModel,
|
||||
} from '../preferences/preferenceSlice';
|
||||
|
||||
import type { Model } from '../models/types';
|
||||
|
||||
export default function DropdownModel() {
|
||||
const dispatch = useDispatch();
|
||||
const selectedModel = useSelector(selectSelectedModel);
|
||||
const availableModels = useSelector(selectAvailableModels);
|
||||
const dropdownRef = React.useRef<HTMLDivElement>(null);
|
||||
const [isOpen, setIsOpen] = React.useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const loadModels = async () => {
|
||||
if ((availableModels?.length ?? 0) > 0) {
|
||||
return;
|
||||
}
|
||||
dispatch(setModelsLoading(true));
|
||||
try {
|
||||
const response = await modelService.getModels(null);
|
||||
if (!response.ok) {
|
||||
throw new Error(`API error: ${response.status}`);
|
||||
}
|
||||
const data = await response.json();
|
||||
const models = data.models || [];
|
||||
const transformed = modelService.transformModels(models);
|
||||
|
||||
dispatch(setAvailableModels(transformed));
|
||||
if (!selectedModel && transformed.length > 0) {
|
||||
const defaultModel =
|
||||
transformed.find((m) => m.id === data.default_model_id) ||
|
||||
transformed[0];
|
||||
dispatch(setSelectedModel(defaultModel));
|
||||
} else if (selectedModel && transformed.length > 0) {
|
||||
const isValid = transformed.find((m) => m.id === selectedModel.id);
|
||||
if (!isValid) {
|
||||
const defaultModel =
|
||||
transformed.find((m) => m.id === data.default_model_id) ||
|
||||
transformed[0];
|
||||
dispatch(setSelectedModel(defaultModel));
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load models:', error);
|
||||
} finally {
|
||||
dispatch(setModelsLoading(false));
|
||||
}
|
||||
};
|
||||
|
||||
loadModels();
|
||||
}, [availableModels?.length, dispatch, selectedModel]);
|
||||
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
dropdownRef.current &&
|
||||
!dropdownRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setIsOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
document.addEventListener('mousedown', handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener('mousedown', handleClickOutside);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div ref={dropdownRef}>
|
||||
<div
|
||||
className={`bg-gray-1000 dark:bg-dark-charcoal mx-auto flex w-full cursor-pointer justify-between p-1 dark:text-white ${isOpen ? 'rounded-t-3xl' : 'rounded-3xl'}`}
|
||||
onClick={() => setIsOpen(!isOpen)}
|
||||
>
|
||||
{selectedModel?.display_name ? (
|
||||
<p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
|
||||
{selectedModel.display_name}
|
||||
</p>
|
||||
) : (
|
||||
<p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
|
||||
Select Model
|
||||
</p>
|
||||
)}
|
||||
<img
|
||||
src={Arrow2}
|
||||
alt="arrow"
|
||||
className={`${
|
||||
isOpen ? 'rotate-360' : 'rotate-270'
|
||||
} mr-3 w-3 transition-all select-none`}
|
||||
/>
|
||||
</div>
|
||||
{isOpen && (
|
||||
<div className="no-scrollbar dark:bg-dark-charcoal absolute right-0 left-0 z-20 -mt-1 max-h-52 w-full overflow-y-auto rounded-b-3xl bg-white shadow-md">
|
||||
{availableModels && (availableModels?.length ?? 0) > 0 ? (
|
||||
availableModels.map((model: Model) => (
|
||||
<div
|
||||
key={model.id}
|
||||
onClick={() => {
|
||||
dispatch(setSelectedModel(model));
|
||||
setIsOpen(false);
|
||||
}}
|
||||
className={`border-gray-3000/75 dark:border-purple-taupe/50 hover:bg-gray-3000/75 dark:hover:bg-purple-taupe flex h-10 w-full cursor-pointer items-center justify-between border-t`}
|
||||
>
|
||||
<div className="flex w-full items-center justify-between">
|
||||
<p className="overflow-hidden py-3 pr-2 pl-5 overflow-ellipsis whitespace-nowrap">
|
||||
{model.display_name}
|
||||
</p>
|
||||
{model.id === selectedModel?.id ? (
|
||||
<img
|
||||
src={RoundedTick}
|
||||
alt="selected"
|
||||
className="mr-3.5 h-4 w-4"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="h-10 w-full border-x-2 border-b-2">
|
||||
<p className="ml-5 py-3 text-gray-500">No models available</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -19,8 +19,8 @@ import {
|
||||
removeAttachment,
|
||||
selectAttachments,
|
||||
updateAttachment,
|
||||
reorderAttachments,
|
||||
} from '../upload/uploadSlice';
|
||||
import { reorderAttachments } from '../upload/uploadSlice';
|
||||
|
||||
import { ActiveState } from '../models/misc';
|
||||
import {
|
||||
@@ -77,7 +77,7 @@ export default function MessageInput({
|
||||
(browserOS === 'mac' && event.metaKey && event.key === 'k')
|
||||
) {
|
||||
event.preventDefault();
|
||||
setIsSourcesPopupOpen(!isSourcesPopupOpen);
|
||||
setIsSourcesPopupOpen((s) => !s);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -89,8 +89,198 @@ export default function MessageInput({
|
||||
|
||||
const uploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
if (!files || files.length === 0) return;
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
|
||||
if (files.length > 1) {
|
||||
const formData = new FormData();
|
||||
const indexToUiId: Record<number, string> = {};
|
||||
|
||||
files.forEach((file, i) => {
|
||||
formData.append('file', file);
|
||||
const uiId = crypto.randomUUID();
|
||||
indexToUiId[i] = uiId;
|
||||
dispatch(
|
||||
addAttachment({
|
||||
id: uiId,
|
||||
fileName: file.name,
|
||||
progress: 0,
|
||||
status: 'uploading' as const,
|
||||
taskId: '',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
if (event.lengthComputable) {
|
||||
const progress = Math.round((event.loaded / event.total) * 100);
|
||||
Object.values(indexToUiId).forEach((uiId) =>
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: { progress },
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
const status = xhr.status;
|
||||
if (status === 200) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
|
||||
if (Array.isArray(response?.tasks)) {
|
||||
const tasks = response.tasks as Array<{
|
||||
task_id?: string;
|
||||
filename?: string;
|
||||
attachment_id?: string;
|
||||
path?: string;
|
||||
}>;
|
||||
|
||||
tasks.forEach((t, idx) => {
|
||||
const uiId = indexToUiId[idx];
|
||||
if (!uiId) return;
|
||||
if (t?.task_id) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: t.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
if (tasks.length < files.length) {
|
||||
for (let i = tasks.length; i < files.length; i++) {
|
||||
const uiId = indexToUiId[i];
|
||||
if (uiId) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (response?.task_id) {
|
||||
if (files.length === 1) {
|
||||
const uiId = indexToUiId[0];
|
||||
if (uiId) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
'Server returned a single task_id for multiple files. Update backend to return tasks[].',
|
||||
);
|
||||
const firstUi = indexToUiId[0];
|
||||
if (firstUi) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: firstUi,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
for (let i = 1; i < files.length; i++) {
|
||||
const uiId = indexToUiId[i];
|
||||
if (uiId) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uiId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.error('Unexpected upload response shape', response);
|
||||
Object.values(indexToUiId).forEach((id) =>
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(
|
||||
'Failed to parse upload response',
|
||||
err,
|
||||
xhr.responseText,
|
||||
);
|
||||
Object.values(indexToUiId).forEach((id) =>
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error('Upload failed', status, xhr.responseText);
|
||||
Object.values(indexToUiId).forEach((id) =>
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
console.error('Upload network error');
|
||||
Object.values(indexToUiId).forEach((id) =>
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
|
||||
if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
return;
|
||||
}
|
||||
|
||||
// Single-file path: upload each file individually (original repo behavior)
|
||||
files.forEach((file) => {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
@@ -121,16 +311,54 @@ export default function MessageInput({
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status === 200) {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
if (response.task_id) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
// If backend returned tasks[] for single-file, handle gracefully:
|
||||
if (
|
||||
Array.isArray(response?.tasks) &&
|
||||
response.tasks[0]?.task_id
|
||||
) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.tasks[0].task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(
|
||||
'Failed to parse upload response',
|
||||
err,
|
||||
xhr.responseText,
|
||||
);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
@@ -154,7 +382,7 @@ export default function MessageInput({
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
});
|
||||
},
|
||||
@@ -163,15 +391,13 @@ export default function MessageInput({
|
||||
|
||||
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
if (!e.target.files || e.target.files.length === 0) return;
|
||||
|
||||
const files = Array.from(e.target.files);
|
||||
uploadFiles(files);
|
||||
|
||||
// clear input so same file can be selected again
|
||||
e.target.value = '';
|
||||
};
|
||||
|
||||
// Drag and drop handler
|
||||
// Drag & drop via react-dropzone
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: File[]) => {
|
||||
uploadFiles(acceptedFiles);
|
||||
@@ -321,11 +547,8 @@ export default function MessageInput({
|
||||
handleAbort();
|
||||
};
|
||||
|
||||
// Drag state for reordering
|
||||
const [draggingId, setDraggingId] = useState<string | null>(null);
|
||||
|
||||
// no preview object URLs to revoke (preview removed per reviewer request)
|
||||
|
||||
const findIndexById = (id: string) =>
|
||||
attachments.findIndex((a) => a.id === id);
|
||||
|
||||
@@ -359,7 +582,9 @@ export default function MessageInput({
|
||||
|
||||
return (
|
||||
<div {...getRootProps()} className="flex w-full flex-col">
|
||||
{/* react-dropzone input (for drag/drop) */}
|
||||
<input {...getInputProps()} />
|
||||
|
||||
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
|
||||
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
|
||||
{attachments.map((attachment) => {
|
||||
@@ -374,7 +599,11 @@ export default function MessageInput({
|
||||
attachment.status !== 'completed'
|
||||
? 'opacity-70'
|
||||
: 'opacity-100'
|
||||
} ${draggingId === attachment.id ? 'ring-dashed opacity-60 ring-2 ring-purple-200' : ''}`}
|
||||
} ${
|
||||
draggingId === attachment.id
|
||||
? 'ring-dashed opacity-60 ring-2 ring-purple-200'
|
||||
: ''
|
||||
}`}
|
||||
title={attachment.fileName}
|
||||
>
|
||||
<div className="bg-purple-30 mr-2 flex h-8 w-8 items-center justify-center rounded-md p-1">
|
||||
|
||||
@@ -15,6 +15,7 @@ export function handleFetchAnswer(
|
||||
agentId?: string,
|
||||
attachments?: string[],
|
||||
save_conversation = true,
|
||||
modelId?: string,
|
||||
): Promise<
|
||||
| {
|
||||
result: any;
|
||||
@@ -47,6 +48,10 @@ export function handleFetchAnswer(
|
||||
save_conversation: save_conversation,
|
||||
};
|
||||
|
||||
if (modelId) {
|
||||
payload.model_id = modelId;
|
||||
}
|
||||
|
||||
// Add attachments to payload if they exist
|
||||
if (attachments && attachments.length > 0) {
|
||||
payload.attachments = attachments;
|
||||
@@ -101,6 +106,7 @@ export function handleFetchAnswerSteaming(
|
||||
agentId?: string,
|
||||
attachments?: string[],
|
||||
save_conversation = true,
|
||||
modelId?: string,
|
||||
): Promise<Answer> {
|
||||
const payload: RetrievalPayload = {
|
||||
question: question,
|
||||
@@ -114,6 +120,10 @@ export function handleFetchAnswerSteaming(
|
||||
save_conversation: save_conversation,
|
||||
};
|
||||
|
||||
if (modelId) {
|
||||
payload.model_id = modelId;
|
||||
}
|
||||
|
||||
// Add attachments to payload if they exist
|
||||
if (attachments && attachments.length > 0) {
|
||||
payload.attachments = attachments;
|
||||
|
||||
@@ -65,4 +65,5 @@ export interface RetrievalPayload {
|
||||
agent_id?: string;
|
||||
attachments?: string[];
|
||||
save_conversation?: boolean;
|
||||
model_id?: string;
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ export const fetchAnswer = createAsyncThunk<
|
||||
}
|
||||
|
||||
const currentConversationId = state.conversation.conversationId;
|
||||
const modelId =
|
||||
state.preference.selectedAgent?.default_model_id ||
|
||||
state.preference.selectedModel?.id;
|
||||
|
||||
if (state.preference) {
|
||||
if (API_STREAMING) {
|
||||
@@ -156,7 +159,8 @@ export const fetchAnswer = createAsyncThunk<
|
||||
indx,
|
||||
state.preference.selectedAgent?.id,
|
||||
attachmentIds,
|
||||
true, // Always save conversation
|
||||
true,
|
||||
modelId,
|
||||
);
|
||||
} else {
|
||||
const answer = await handleFetchAnswer(
|
||||
@@ -170,7 +174,8 @@ export const fetchAnswer = createAsyncThunk<
|
||||
state.preference.token_limit,
|
||||
state.preference.selectedAgent?.id,
|
||||
attachmentIds,
|
||||
true, // Always save conversation
|
||||
true,
|
||||
modelId,
|
||||
);
|
||||
if (answer) {
|
||||
let sourcesPrepped = [];
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "Prompt",
|
||||
"tools": "Tools",
|
||||
"agentType": "Agent type",
|
||||
"models": "Models",
|
||||
"advanced": "Advanced",
|
||||
"preview": "Preview"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "Chunks per query",
|
||||
"selectType": "Select type",
|
||||
"selectTools": "Select tools",
|
||||
"selectModels": "Select models for this agent",
|
||||
"selectDefaultModel": "Select default model",
|
||||
"enterTokenLimit": "Enter token limit",
|
||||
"enterRequestLimit": "Enter request limit"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "Search tools...",
|
||||
"noOptionsMessage": "No tools available"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "Select Models",
|
||||
"searchPlaceholder": "Search models...",
|
||||
"noOptionsMessage": "No models available"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "Click to upload",
|
||||
"dragAndDrop": " or drag and drop"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "Classic",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "Default Model"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "JSON response schema",
|
||||
"jsonSchemaDescription": "Define a JSON schema to enforce structured output format",
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "Prompt",
|
||||
"tools": "Herramientas",
|
||||
"agentType": "Tipo de agente",
|
||||
"models": "Modelos",
|
||||
"advanced": "Avanzado",
|
||||
"preview": "Vista previa"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "Fragmentos por consulta",
|
||||
"selectType": "Seleccionar tipo",
|
||||
"selectTools": "Seleccionar herramientas",
|
||||
"selectModels": "Seleccionar modelos para este agente",
|
||||
"selectDefaultModel": "Seleccionar modelo predeterminado",
|
||||
"enterTokenLimit": "Ingresar límite de tokens",
|
||||
"enterRequestLimit": "Ingresar límite de solicitudes"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "Buscar herramientas...",
|
||||
"noOptionsMessage": "No hay herramientas disponibles"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "Seleccionar Modelos",
|
||||
"searchPlaceholder": "Buscar modelos...",
|
||||
"noOptionsMessage": "No hay modelos disponibles"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "Haz clic para subir",
|
||||
"dragAndDrop": " o arrastra y suelta"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "Clásico",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "Modelo Predeterminado"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "Esquema de respuesta JSON",
|
||||
"jsonSchemaDescription": "Define un esquema JSON para aplicar formato de salida estructurado",
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "プロンプト",
|
||||
"tools": "ツール",
|
||||
"agentType": "エージェントタイプ",
|
||||
"models": "モデル",
|
||||
"advanced": "詳細設定",
|
||||
"preview": "プレビュー"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "クエリごとのチャンク数",
|
||||
"selectType": "タイプを選択",
|
||||
"selectTools": "ツールを選択",
|
||||
"selectModels": "このエージェントのモデルを選択",
|
||||
"selectDefaultModel": "デフォルトモデルを選択",
|
||||
"enterTokenLimit": "トークン制限を入力",
|
||||
"enterRequestLimit": "リクエスト制限を入力"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "ツールを検索...",
|
||||
"noOptionsMessage": "利用可能なツールがありません"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "モデルを選択",
|
||||
"searchPlaceholder": "モデルを検索...",
|
||||
"noOptionsMessage": "利用可能なモデルがありません"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "クリックしてアップロード",
|
||||
"dragAndDrop": " またはドラッグ&ドロップ"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "クラシック",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "デフォルトモデル"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "JSON応答スキーマ",
|
||||
"jsonSchemaDescription": "構造化された出力形式を適用するためのJSONスキーマを定義します",
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "Промпт",
|
||||
"tools": "Инструменты",
|
||||
"agentType": "Тип агента",
|
||||
"models": "Модели",
|
||||
"advanced": "Расширенные",
|
||||
"preview": "Предпросмотр"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "Фрагментов на запрос",
|
||||
"selectType": "Выберите тип",
|
||||
"selectTools": "Выберите инструменты",
|
||||
"selectModels": "Выберите модели для этого агента",
|
||||
"selectDefaultModel": "Выберите модель по умолчанию",
|
||||
"enterTokenLimit": "Введите лимит токенов",
|
||||
"enterRequestLimit": "Введите лимит запросов"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "Поиск инструментов...",
|
||||
"noOptionsMessage": "Нет доступных инструментов"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "Выберите Модели",
|
||||
"searchPlaceholder": "Поиск моделей...",
|
||||
"noOptionsMessage": "Нет доступных моделей"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "Нажмите для загрузки",
|
||||
"dragAndDrop": " или перетащите"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "Классический",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "Модель по умолчанию"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "Схема ответа JSON",
|
||||
"jsonSchemaDescription": "Определите схему JSON для применения структурированного формата вывода",
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "提示詞",
|
||||
"tools": "工具",
|
||||
"agentType": "代理類型",
|
||||
"models": "模型",
|
||||
"advanced": "進階",
|
||||
"preview": "預覽"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "每次查詢的區塊數",
|
||||
"selectType": "選擇類型",
|
||||
"selectTools": "選擇工具",
|
||||
"selectModels": "為此代理選擇模型",
|
||||
"selectDefaultModel": "選擇預設模型",
|
||||
"enterTokenLimit": "輸入權杖限制",
|
||||
"enterRequestLimit": "輸入請求限制"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "搜尋工具...",
|
||||
"noOptionsMessage": "沒有可用的工具"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "選擇模型",
|
||||
"searchPlaceholder": "搜尋模型...",
|
||||
"noOptionsMessage": "沒有可用的模型"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "點擊上傳",
|
||||
"dragAndDrop": " 或拖放"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "經典",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "預設模型"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "JSON回應架構",
|
||||
"jsonSchemaDescription": "定義JSON架構以強制執行結構化輸出格式",
|
||||
|
||||
@@ -530,6 +530,7 @@
|
||||
"prompt": "提示词",
|
||||
"tools": "工具",
|
||||
"agentType": "代理类型",
|
||||
"models": "模型",
|
||||
"advanced": "高级",
|
||||
"preview": "预览"
|
||||
},
|
||||
@@ -540,6 +541,8 @@
|
||||
"chunksPerQuery": "每次查询的块数",
|
||||
"selectType": "选择类型",
|
||||
"selectTools": "选择工具",
|
||||
"selectModels": "为此代理选择模型",
|
||||
"selectDefaultModel": "选择默认模型",
|
||||
"enterTokenLimit": "输入令牌限制",
|
||||
"enterRequestLimit": "输入请求限制"
|
||||
},
|
||||
@@ -553,6 +556,11 @@
|
||||
"searchPlaceholder": "搜索工具...",
|
||||
"noOptionsMessage": "没有可用的工具"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "选择模型",
|
||||
"searchPlaceholder": "搜索模型...",
|
||||
"noOptionsMessage": "没有可用的模型"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "点击上传",
|
||||
"dragAndDrop": " 或拖放"
|
||||
@@ -561,6 +569,9 @@
|
||||
"classic": "经典",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "默认模型"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "JSON响应架构",
|
||||
"jsonSchemaDescription": "定义JSON架构以强制执行结构化输出格式",
|
||||
|
||||
25
frontend/src/models/types.ts
Normal file
25
frontend/src/models/types.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
export interface AvailableModel {
|
||||
id: string;
|
||||
provider: string;
|
||||
display_name: string;
|
||||
description?: string;
|
||||
context_window: number;
|
||||
supported_attachment_types: string[];
|
||||
supports_tools: boolean;
|
||||
supports_structured_output: boolean;
|
||||
supports_streaming: boolean;
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export interface Model {
|
||||
id: string;
|
||||
value: string;
|
||||
provider: string;
|
||||
display_name: string;
|
||||
description?: string;
|
||||
context_window: number;
|
||||
supported_attachment_types: string[];
|
||||
supports_tools: boolean;
|
||||
supports_structured_output: boolean;
|
||||
supports_streaming: boolean;
|
||||
}
|
||||
@@ -9,11 +9,12 @@ import { Agent } from '../agents/types';
|
||||
import { ActiveState, Doc } from '../models/misc';
|
||||
import { RootState } from '../store';
|
||||
import {
|
||||
getLocalRecentDocs,
|
||||
setLocalApiKey,
|
||||
setLocalRecentDocs,
|
||||
getLocalRecentDocs,
|
||||
} from './preferenceApi';
|
||||
|
||||
import type { Model } from '../models/types';
|
||||
export interface Preference {
|
||||
apiKey: string;
|
||||
prompt: { name: string; id: string; type: string };
|
||||
@@ -32,6 +33,9 @@ export interface Preference {
|
||||
agents: Agent[] | null;
|
||||
sharedAgents: Agent[] | null;
|
||||
selectedAgent: Agent | null;
|
||||
selectedModel: Model | null;
|
||||
availableModels: Model[];
|
||||
modelsLoading: boolean;
|
||||
}
|
||||
|
||||
const initialState: Preference = {
|
||||
@@ -61,6 +65,9 @@ const initialState: Preference = {
|
||||
agents: null,
|
||||
sharedAgents: null,
|
||||
selectedAgent: null,
|
||||
selectedModel: null,
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
};
|
||||
|
||||
export const prefSlice = createSlice({
|
||||
@@ -109,6 +116,15 @@ export const prefSlice = createSlice({
|
||||
setSelectedAgent: (state, action) => {
|
||||
state.selectedAgent = action.payload;
|
||||
},
|
||||
setSelectedModel: (state, action: PayloadAction<Model | null>) => {
|
||||
state.selectedModel = action.payload;
|
||||
},
|
||||
setAvailableModels: (state, action: PayloadAction<Model[]>) => {
|
||||
state.availableModels = action.payload;
|
||||
},
|
||||
setModelsLoading: (state, action: PayloadAction<boolean>) => {
|
||||
state.modelsLoading = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -127,6 +143,9 @@ export const {
|
||||
setAgents,
|
||||
setSharedAgents,
|
||||
setSelectedAgent,
|
||||
setSelectedModel,
|
||||
setAvailableModels,
|
||||
setModelsLoading,
|
||||
} = prefSlice.actions;
|
||||
export default prefSlice.reducer;
|
||||
|
||||
@@ -198,6 +217,19 @@ prefListenerMiddleware.startListening({
|
||||
},
|
||||
});
|
||||
|
||||
prefListenerMiddleware.startListening({
|
||||
matcher: isAnyOf(setSelectedModel),
|
||||
effect: (action, listenerApi) => {
|
||||
const model = (listenerApi.getState() as RootState).preference
|
||||
.selectedModel;
|
||||
if (model) {
|
||||
localStorage.setItem('DocsGPTSelectedModel', JSON.stringify(model));
|
||||
} else {
|
||||
localStorage.removeItem('DocsGPTSelectedModel');
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
export const selectApiKey = (state: RootState) => state.preference.apiKey;
|
||||
export const selectApiKeyStatus = (state: RootState) =>
|
||||
!!state.preference.apiKey;
|
||||
@@ -227,3 +259,9 @@ export const selectSharedAgents = (state: RootState) =>
|
||||
state.preference.sharedAgents;
|
||||
export const selectSelectedAgent = (state: RootState) =>
|
||||
state.preference.selectedAgent;
|
||||
export const selectSelectedModel = (state: RootState) =>
|
||||
state.preference.selectedModel;
|
||||
export const selectAvailableModels = (state: RootState) =>
|
||||
state.preference.availableModels;
|
||||
export const selectModelsLoading = (state: RootState) =>
|
||||
state.preference.modelsLoading;
|
||||
|
||||
@@ -15,6 +15,7 @@ const prompt = localStorage.getItem('DocsGPTPrompt');
|
||||
const chunks = localStorage.getItem('DocsGPTChunks');
|
||||
const token_limit = localStorage.getItem('DocsGPTTokenLimit');
|
||||
const doc = localStorage.getItem('DocsGPTRecentDocs');
|
||||
const selectedModel = localStorage.getItem('DocsGPTSelectedModel');
|
||||
|
||||
const preloadedState: { preference: Preference } = {
|
||||
preference: {
|
||||
@@ -47,6 +48,9 @@ const preloadedState: { preference: Preference } = {
|
||||
agents: null,
|
||||
sharedAgents: null,
|
||||
selectedAgent: null,
|
||||
selectedModel: selectedModel ? JSON.parse(selectedModel) : null,
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
},
|
||||
};
|
||||
const store = configureStore({
|
||||
|
||||
114
scripts/migrate_conversation_id_dbref_to_objectid.py
Normal file
114
scripts/migrate_conversation_id_dbref_to_objectid.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration script to convert conversation_id from DBRef to ObjectId in shared_conversations collection.
|
||||
"""
|
||||
|
||||
import pymongo
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Configuration
|
||||
MONGO_URI = "mongodb://localhost:27017/"
|
||||
DB_NAME = "docsgpt"
|
||||
|
||||
def backup_collection(collection, backup_collection_name):
|
||||
"""Backup collection before migration."""
|
||||
logger.info(f"Backing up collection {collection.name} to {backup_collection_name}")
|
||||
collection.aggregate([{"$out": backup_collection_name}])
|
||||
logger.info("Backup completed")
|
||||
|
||||
def migrate_conversation_id_dbref_to_objectid():
|
||||
"""Migrate conversation_id from DBRef to ObjectId."""
|
||||
client = pymongo.MongoClient(MONGO_URI)
|
||||
db = client[DB_NAME]
|
||||
shared_conversations_collection = db["shared_conversations"]
|
||||
|
||||
try:
|
||||
# Backup collection before migration
|
||||
backup_collection(shared_conversations_collection, "shared_conversations_backup")
|
||||
|
||||
# Find all documents and filter for DBRef conversation_id in Python
|
||||
all_documents = list(shared_conversations_collection.find({}))
|
||||
documents_with_dbref = []
|
||||
|
||||
for doc in all_documents:
|
||||
conversation_id_field = doc.get("conversation_id")
|
||||
if isinstance(conversation_id_field, DBRef):
|
||||
documents_with_dbref.append(doc)
|
||||
|
||||
if not documents_with_dbref:
|
||||
logger.info("No documents with DBRef conversation_id found. Migration not needed.")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(documents_with_dbref)} documents with DBRef conversation_id")
|
||||
|
||||
# Process each document
|
||||
migrated_count = 0
|
||||
error_count = 0
|
||||
|
||||
for doc in tqdm(documents_with_dbref, desc="Migrating conversation_id"):
|
||||
try:
|
||||
conversation_id_field = doc.get("conversation_id")
|
||||
|
||||
# Extract the ObjectId from the DBRef
|
||||
dbref_id = conversation_id_field.id
|
||||
|
||||
if dbref_id and ObjectId.is_valid(dbref_id):
|
||||
# Update the document to use direct ObjectId
|
||||
result = shared_conversations_collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"conversation_id": dbref_id}}
|
||||
)
|
||||
|
||||
if result.modified_count > 0:
|
||||
migrated_count += 1
|
||||
logger.debug(f"Successfully migrated document {doc['_id']}")
|
||||
else:
|
||||
error_count += 1
|
||||
logger.warning(f"Failed to update document {doc['_id']}")
|
||||
else:
|
||||
error_count += 1
|
||||
logger.warning(f"Invalid ObjectId in DBRef for document {doc['_id']}: {dbref_id}")
|
||||
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"Error migrating document {doc['_id']}: {e}")
|
||||
|
||||
# Final verification
|
||||
all_docs_after = list(shared_conversations_collection.find({}))
|
||||
remaining_dbref = 0
|
||||
for doc in all_docs_after:
|
||||
if isinstance(doc.get("conversation_id"), DBRef):
|
||||
remaining_dbref += 1
|
||||
|
||||
logger.info("Migration completed:")
|
||||
logger.info(f" - Total documents processed: {len(documents_with_dbref)}")
|
||||
logger.info(f" - Successfully migrated: {migrated_count}")
|
||||
logger.info(f" - Errors encountered: {error_count}")
|
||||
logger.info(f" - Remaining DBRef documents: {remaining_dbref}")
|
||||
|
||||
if remaining_dbref == 0:
|
||||
logger.info("✅ Migration successful: All DBRef conversation_id fields have been converted to ObjectId")
|
||||
else:
|
||||
logger.warning(f"⚠️ Migration incomplete: {remaining_dbref} DBRef documents still exist")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
logger.info("Starting conversation_id DBRef to ObjectId migration...")
|
||||
migrate_conversation_id_dbref_to_objectid()
|
||||
logger.info("Migration completed successfully!")
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed due to error: {e}")
|
||||
logger.warning("Please verify database state or restore from backups if necessary.")
|
||||
@@ -12,7 +12,7 @@ class TestAgentCreator:
|
||||
assert isinstance(agent, ClassicAgent)
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
assert agent.model_id == agent_base_params["model_id"]
|
||||
|
||||
def test_create_react_agent(self, agent_base_params):
|
||||
agent = AgentCreator.create_agent("react", **agent_base_params)
|
||||
|
||||
@@ -15,7 +15,7 @@ class TestBaseAgentInitialization:
|
||||
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
assert agent.model_id == agent_base_params["model_id"]
|
||||
assert agent.api_key == agent_base_params["api_key"]
|
||||
assert agent.prompt == agent_base_params["prompt"]
|
||||
assert agent.user == agent_base_params["decoded_token"]["sub"]
|
||||
@@ -480,7 +480,7 @@ class TestBaseAgentLLMGeneration:
|
||||
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
call_args = mock_llm.gen_stream.call_args[1]
|
||||
assert call_args["model"] == agent.gpt_model
|
||||
assert call_args["model"] == agent.model_id
|
||||
assert call_args["messages"] == messages
|
||||
|
||||
def test_llm_gen_with_tools(
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestReActAgent:
|
||||
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.gpt_model == agent_base_params["gpt_model"]
|
||||
assert agent.model_id == agent_base_params["model_id"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@@ -274,8 +274,8 @@ class TestGPTModelRetrieval:
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "gpt_model")
|
||||
assert resource.gpt_model is not None
|
||||
assert hasattr(resource, "default_model_id")
|
||||
assert resource.default_model_id is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -412,7 +412,7 @@ class TestCompleteStreamMethod:
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=True,
|
||||
@@ -500,9 +500,10 @@ class TestProcessResponseStream:
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 5
|
||||
assert len(result) == 6
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
assert result[5] is None
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
@@ -108,7 +108,7 @@ class TestConversationServiceSave:
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
decoded_token={}, # No 'sub' key
|
||||
)
|
||||
|
||||
@@ -136,7 +136,7 @@ class TestConversationServiceSave:
|
||||
sources=sources,
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
@@ -167,7 +167,7 @@ class TestConversationServiceSave:
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
@@ -208,7 +208,7 @@ class TestConversationServiceSave:
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
@@ -237,6 +237,6 @@ class TestConversationServiceSave:
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "hacker_456"},
|
||||
)
|
||||
|
||||
@@ -150,7 +150,7 @@ def agent_base_params(decoded_token):
|
||||
return {
|
||||
"endpoint": "https://api.example.com",
|
||||
"llm_name": "openai",
|
||||
"gpt_model": "gpt-4",
|
||||
"model_id": "gpt-4",
|
||||
"api_key": "test_api_key",
|
||||
"user_api_key": None,
|
||||
"prompt": "You are a helpful assistant.",
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeCompletion:
|
||||
def __init__(self, text):
|
||||
self.completion = text
|
||||
|
||||
|
||||
class _FakeCompletions:
|
||||
def __init__(self):
|
||||
self.last_kwargs = None
|
||||
@@ -17,6 +20,7 @@ class _FakeCompletions:
|
||||
return self._stream
|
||||
return _FakeCompletion("final")
|
||||
|
||||
|
||||
class _FakeAnthropic:
|
||||
def __init__(self, api_key=None):
|
||||
self.api_key = api_key
|
||||
@@ -29,9 +33,19 @@ def patch_anthropic(monkeypatch):
|
||||
fake.Anthropic = _FakeAnthropic
|
||||
fake.HUMAN_PROMPT = "<HUMAN>"
|
||||
fake.AI_PROMPT = "<AI>"
|
||||
|
||||
modules_to_remove = [key for key in sys.modules if key.startswith("anthropic")]
|
||||
for key in modules_to_remove:
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules["anthropic"] = fake
|
||||
|
||||
if "application.llm.anthropic" in sys.modules:
|
||||
del sys.modules["application.llm.anthropic"]
|
||||
yield
|
||||
|
||||
sys.modules.pop("anthropic", None)
|
||||
if "application.llm.anthropic" in sys.modules:
|
||||
del sys.modules["application.llm.anthropic"]
|
||||
|
||||
|
||||
def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
|
||||
@@ -42,7 +56,9 @@ def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
|
||||
{"content": "ctx"},
|
||||
{"content": "q"},
|
||||
]
|
||||
out = llm._raw_gen(llm, model="claude-2", messages=msgs, stream=False, max_tokens=55)
|
||||
out = llm._raw_gen(
|
||||
llm, model="claude-2", messages=msgs, stream=False, max_tokens=55
|
||||
)
|
||||
assert out == "final"
|
||||
last = llm.anthropic.completions.last_kwargs
|
||||
assert last["model"] == "claude-2"
|
||||
@@ -59,7 +75,8 @@ def test_anthropic_raw_gen_stream_yields_chunks():
|
||||
{"content": "ctx"},
|
||||
{"content": "q"},
|
||||
]
|
||||
gen = llm._raw_gen_stream(llm, model="claude", messages=msgs, stream=True, max_tokens=10)
|
||||
gen = llm._raw_gen_stream(
|
||||
llm, model="claude", messages=msgs, stream=True, max_tokens=10
|
||||
)
|
||||
chunks = list(gen)
|
||||
assert chunks == ["s1", "s2"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user