Mirror of https://github.com/arc53/DocsGPT.git (synced 2025-11-29 16:43:16 +00:00)

Compare commits: fix/eslint ... fix-tool-n

36 Commits
Commits (SHA1):
17bc22224c, 899b30da5e, dc2faf7a7e, 67e0d222d1, 17698ce774, 7d1c8c008b, 9e58eb02b3, 3f7de867cc, fbf7cf874b, ba7278b80f, 9d649de6f9, 7929afbf58, ceaf942e70, f355601a44, 4ff99a1e86, 129084ba92, 2288df1293, d9dfac55e7, 404cf4b7c7, f1c1fc123b, 9f19c7ee4c, 155e74eca1, ea2dc4dbcb, 616edc97de, b017e99c79, f698e9d3e1, d366502850, 3d6757c170, cb8302add8, 9d266e9fad, ae94c9d31e, 83ab232dcd, eea85772a3, 0fe7e223cc, 3789d2eb03, d54469532e
.github/dependabot.yml (vendored) · 6 changed lines
@@ -13,7 +13,11 @@ updates:
     directory: "/frontend" # Location of package manifests
     schedule:
       interval: "daily"
+  - package-ecosystem: "npm"
+    directory: "/extensions/react-widget"
+    schedule:
+      interval: "daily"
   - package-ecosystem: "github-actions"
     directory: "/"
     schedule:
       interval: "daily"
.gitignore (vendored) · 1 changed line
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+experiments/
 
 experiments
 # C extensions
@@ -147,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
 Thank you for considering contributing to DocsGPT! 🙏
 
 ## Questions/collaboration
-Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
+Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
 # Thank you so much for considering to contributing DocsGPT!🙏

@@ -32,7 +32,7 @@ Non-Code Contributions:
 - Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
 - Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
 - Refer to the [Documentation](https://docs.docsgpt.cloud/).
-- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
+- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
 
 Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
README.md · 11 changed lines
@@ -16,23 +16,16 @@
 <a href="https://github.com/arc53/DocsGPT"></a>
 <a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
 <a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
-<a href="https://discord.gg/n5BX8dh8rU"></a>
+<a href="https://discord.gg/vN7YFfdMpj"></a>
 <a href="https://x.com/docsgptai"></a>
 
-<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
+<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
 <br>
 <a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
 <br>
 
 </div>
 
-<div align="center">
-<br>
-🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
-<br>
-<br>
-</div>
-
 <div align="center">
 <br>
@@ -1,5 +1,8 @@
 from application.agents.classic_agent import ClassicAgent
 from application.agents.react_agent import ReActAgent
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
         agent_class = cls.agents.get(type.lower())
         if not agent_class:
             raise ValueError(f"No agent class found for type {type}")
+
         return agent_class(*args, **kwargs)
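For readers skimming the diff, `AgentCreator` is a registry factory: a class-level dict maps a type string to an agent class, and unknown types raise. A minimal standalone sketch of the same pattern (the `EchoAgent` class and the `create_agent` method name are illustrative assumptions; only the lookup-and-raise logic is taken from the hunk above):

```python
# Registry-factory sketch mirroring the lookup shown in the hunk above.
# EchoAgent and the method name create_agent are illustrative, not from the repo.
class EchoAgent:
    def __init__(self, model_id: str):
        self.model_id = model_id


class AgentFactory:
    agents = {"echo": EchoAgent}  # type string -> agent class

    @classmethod
    def create_agent(cls, type: str, *args, **kwargs):
        agent_class = cls.agents.get(type.lower())
        if not agent_class:
            raise ValueError(f"No agent class found for type {type}")
        return agent_class(*args, **kwargs)


agent = AgentFactory.create_agent("Echo", model_id="gpt-4o")
print(agent.model_id)  # -> gpt-4o
```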
@@ -21,7 +21,7 @@ class BaseAgent(ABC):
         self,
         endpoint: str,
         llm_name: str,
-        gpt_model: str,
+        model_id: str,
         api_key: str,
         user_api_key: Optional[str] = None,
         prompt: str = "",
@@ -34,10 +34,11 @@ class BaseAgent(ABC):
         token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
         limited_request_mode: Optional[bool] = False,
         request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
+        compressed_summary: Optional[str] = None,
     ):
         self.endpoint = endpoint
         self.llm_name = llm_name
-        self.gpt_model = gpt_model
+        self.model_id = model_id
         self.api_key = api_key
         self.user_api_key = user_api_key
         self.prompt = prompt
@@ -52,6 +53,7 @@ class BaseAgent(ABC):
             api_key=api_key,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
+            model_id=model_id,
         )
         self.retrieved_docs = retrieved_docs or []
         self.llm_handler = LLMHandlerCreator.create_handler(
@@ -63,6 +65,9 @@ class BaseAgent(ABC):
         self.token_limit = token_limit
         self.limited_request_mode = limited_request_mode
         self.request_limit = request_limit
+        self.compressed_summary = compressed_summary
+        self.current_token_count = 0
+        self.context_limit_reached = False
 
     @log_activity()
     def gen(
@@ -275,12 +280,77 @@ class BaseAgent(ABC):
             for tool_call in self.tool_calls
         ]
 
+    def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
+        """
+        Calculate total tokens in current context (messages).
+
+        Args:
+            messages: List of message dicts
+
+        Returns:
+            Total token count
+        """
+        from application.api.answer.services.compression.token_counter import (
+            TokenCounter,
+        )
+
+        return TokenCounter.count_message_tokens(messages)
+
+    def _check_context_limit(self, messages: List[Dict]) -> bool:
+        """
+        Check if we're approaching context limit (80%).
+
+        Args:
+            messages: Current message list
+
+        Returns:
+            True if at or above 80% of context limit
+        """
+        from application.core.model_utils import get_token_limit
+        from application.core.settings import settings
+
+        try:
+            # Calculate current tokens
+            current_tokens = self._calculate_current_context_tokens(messages)
+            self.current_token_count = current_tokens
+
+            # Get context limit for model
+            context_limit = get_token_limit(self.model_id)
+
+            # Calculate threshold (80%)
+            threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
+
+            # Check if we've reached the limit
+            if current_tokens >= threshold:
+                logger.warning(
+                    f"Context limit approaching: {current_tokens}/{context_limit} tokens "
+                    f"({(current_tokens/context_limit)*100:.1f}%)"
+                )
+                return True
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
+            return False
+
     def _build_messages(
         self,
         system_prompt: str,
         query: str,
     ) -> List[Dict]:
         """Build messages using pre-rendered system prompt"""
+        # Append compression summary to system prompt if present
+        if self.compressed_summary:
+            compression_context = (
+                "\n\n---\n\n"
+                "This session is being continued from a previous conversation that "
+                "has been compressed to fit within context limits. "
+                "The conversation is summarized below:\n\n"
+                f"{self.compressed_summary}"
+            )
+            system_prompt = system_prompt + compression_context
+
         messages = [{"role": "system", "content": system_prompt}]
 
         for i in self.chat_history:
@@ -316,7 +386,7 @@ class BaseAgent(ABC):
         return messages
 
     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
-        gen_kwargs = {"model": self.gpt_model, "messages": messages}
+        gen_kwargs = {"model": self.model_id, "messages": messages}
 
         if (
             hasattr(self.llm, "_supports_tools")
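The threshold arithmetic in `_check_context_limit` is worth seeing in isolation. A minimal sketch, assuming a 128k-token context window and a `COMPRESSION_THRESHOLD_PERCENTAGE` of 0.8 (the 80% figure the docstring mentions; the concrete numbers are assumptions):

```python
# Standalone version of the 80% check added above; the window size is assumed.
CONTEXT_LIMIT = 128_000
COMPRESSION_THRESHOLD_PERCENTAGE = 0.8

def check_context_limit(current_tokens: int) -> bool:
    threshold = int(CONTEXT_LIMIT * COMPRESSION_THRESHOLD_PERCENTAGE)  # 102_400
    return current_tokens >= threshold

print(check_context_limit(90_000))   # False (70.3% of the window)
print(check_context_limit(110_000))  # True  (85.9%, past the 80% threshold)
```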
@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": plan_prompt}]
 
         plan_stream = self.llm.gen_stream(
-            model=self.gpt_model,
+            model=self.model_id,
             messages=messages,
             tools=self.tools if self.tools else None,
         )
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": final_prompt}]
 
         final_stream = self.llm.gen_stream(
-            model=self.gpt_model, messages=messages, tools=None
+            model=self.model_id, messages=messages, tools=None
         )
 
         if log_context:
@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
                 default=True,
                 description="Whether to save the conversation",
             ),
+            "model_id": fields.String(
+                required=False,
+                description="Model ID to use for this request",
+            ),
             "passthrough": fields.Raw(
                 required=False,
                 description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
             isNoneDoc=data.get("isNoneDoc"),
             index=None,
             should_save_conversation=data.get("save_conversation", True),
+            model_id=processor.model_id,
         )
         stream_result = self.process_response_stream(stream)
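With the new `model_id` field, a client can pick the model per request instead of relying on the server-wide default. A hypothetical call (the endpoint path, port, and model-id value are assumptions for illustration, not confirmed by this diff):

```python
# Hypothetical client request using the new per-request model_id field.
import requests

payload = {
    "question": "Summarize the compression module.",
    "save_conversation": True,
    "model_id": "openai/gpt-4o",  # assumed id format; server default applies if omitted
}
resp = requests.post("http://localhost:7091/api/answer", json=payload, timeout=60)
print(resp.status_code, resp.json())
```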
@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
 from flask_restx import Namespace
 
 from application.api.answer.services.conversation_service import ConversationService
+from application.core.model_utils import (
+    get_api_key_for_provider,
+    get_default_model_id,
+    get_provider_from_model_id,
+)
 
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import check_required_fields, get_gpt_model
+from application.utils import check_required_fields
 
 logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
         db = mongo[settings.MONGO_DB_NAME]
         self.db = db
         self.user_logs_collection = db["user_logs"]
-        self.gpt_model = get_gpt_model()
+        self.default_model_id = get_default_model_id()
         self.conversation_service = ConversationService()
 
     def validate_request(
@@ -54,7 +59,6 @@ class BaseAnswerResource:
         api_key = agent_config.get("user_api_key")
         if not api_key:
             return None
-
         agents_collection = self.db["agents"]
         agent = agents_collection.find_one({"key": api_key})
 
@@ -62,7 +66,6 @@ class BaseAnswerResource:
             return make_response(
                 jsonify({"success": False, "message": "Invalid API key."}), 401
             )
-
         limited_token_mode_raw = agent.get("limited_token_mode", False)
         limited_request_mode_raw = agent.get("limited_request_mode", False)
 
@@ -110,15 +113,12 @@ class BaseAnswerResource:
             daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
         else:
             daily_token_usage = 0
 
         if limited_request_mode:
             daily_request_usage = token_usage_collection.count_documents(match_query)
         else:
             daily_request_usage = 0
 
-        if not limited_token_mode and not limited_request_mode:
-            return None
-
         token_exceeded = (
             limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
         )
@@ -138,7 +138,6 @@ class BaseAnswerResource:
                 ),
                 429,
             )
-
         return None
 
     def complete_stream(
@@ -155,6 +154,7 @@ class BaseAnswerResource:
         agent_id: Optional[str] = None,
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
+        model_id: Optional[str] = None,
     ) -> Generator[str, None, None]:
         """
         Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@ class BaseAnswerResource:
             agent_id: ID of agent used
             is_shared_usage: Flag for shared agent usage
             shared_token: Token for shared agent
+            model_id: Model ID used for the request
             retrieved_docs: Pre-fetched documents for sources (optional)
 
         Yields:
@@ -220,7 +221,6 @@ class BaseAnswerResource:
                 elif "type" in line:
                     data = json.dumps(line)
                     yield f"data: {data}\n\n"
-
             if is_structured and structured_chunks:
                 structured_data = {
                     "type": "structured_answer",
@@ -230,15 +230,22 @@ class BaseAnswerResource:
                 }
                 data = json.dumps(structured_data)
                 yield f"data: {data}\n\n"
 
             if isNoneDoc:
                 for doc in source_log_docs:
                     doc["source"] = "None"
+            provider = (
+                get_provider_from_model_id(model_id)
+                if model_id
+                else settings.LLM_PROVIDER
+            )
+            system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
+
             llm = LLMCreator.create_llm(
-                settings.LLM_PROVIDER,
-                api_key=settings.API_KEY,
+                provider or settings.LLM_PROVIDER,
+                api_key=system_api_key,
                 user_api_key=user_api_key,
                 decoded_token=decoded_token,
+                model_id=model_id,
             )
 
             if should_save_conversation:
@@ -250,7 +257,7 @@ class BaseAnswerResource:
                     source_log_docs,
                     tool_calls,
                     llm,
-                    self.gpt_model,
+                    model_id or self.default_model_id,
                     decoded_token,
                     index=index,
                     api_key=user_api_key,
@@ -259,6 +266,26 @@ class BaseAnswerResource:
                     shared_token=shared_token,
                     attachment_ids=attachment_ids,
                 )
+                # Persist compression metadata/summary if it exists and wasn't saved mid-execution
+                compression_meta = getattr(agent, "compression_metadata", None)
+                compression_saved = getattr(agent, "compression_saved", False)
+                if conversation_id and compression_meta and not compression_saved:
+                    try:
+                        self.conversation_service.update_compression_metadata(
+                            conversation_id, compression_meta
+                        )
+                        self.conversation_service.append_compression_message(
+                            conversation_id, compression_meta
+                        )
+                        agent.compression_saved = True
+                        logger.info(
+                            f"Persisted compression metadata for conversation {conversation_id}"
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to persist compression metadata: {str(e)}",
+                            exc_info=True,
+                        )
             else:
                 conversation_id = None
             id_data = {"type": "id", "id": str(conversation_id)}
@@ -280,12 +307,11 @@ class BaseAnswerResource:
                 log_data["structured_output"] = True
                 if schema_info:
                     log_data["schema"] = schema_info
 
             # Clean up text fields to be no longer than 10000 characters
-
             for key, value in log_data.items():
                 if isinstance(value, str) and len(value) > 10000:
                     log_data[key] = value[:10000]
 
             self.user_logs_collection.insert_one(log_data)
 
             data = json.dumps({"type": "end"})
@@ -293,6 +319,7 @@ class BaseAnswerResource:
         except GeneratorExit:
             logger.info(f"Stream aborted by client for question: {question[:50]}... ")
             # Save partial response
+
             if should_save_conversation and response_full:
                 try:
                     if isNoneDoc:
@@ -312,7 +339,7 @@ class BaseAnswerResource:
                             source_log_docs,
                             tool_calls,
                             llm,
-                            self.gpt_model,
+                            model_id or self.default_model_id,
                             decoded_token,
                             index=index,
                             api_key=user_api_key,
@@ -321,6 +348,25 @@ class BaseAnswerResource:
                             shared_token=shared_token,
                             attachment_ids=attachment_ids,
                         )
+                        compression_meta = getattr(agent, "compression_metadata", None)
+                        compression_saved = getattr(agent, "compression_saved", False)
+                        if conversation_id and compression_meta and not compression_saved:
+                            try:
+                                self.conversation_service.update_compression_metadata(
+                                    conversation_id, compression_meta
+                                )
+                                self.conversation_service.append_compression_message(
+                                    conversation_id, compression_meta
+                                )
+                                agent.compression_saved = True
+                                logger.info(
+                                    f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
+                                )
+                            except Exception as e:
+                                logger.error(
+                                    f"Failed to persist compression metadata (partial stream): {str(e)}",
+                                    exc_info=True,
+                                )
                 except Exception as e:
                     logger.error(
                         f"Error saving partial response: {str(e)}", exc_info=True
@@ -369,7 +415,7 @@ class BaseAnswerResource:
                     thought = event["thought"]
                 elif event["type"] == "error":
                     logger.error(f"Error from stream: {event['error']}")
-                    return None, None, None, None, event["error"]
+                    return None, None, None, None, event["error"], None
                 elif event["type"] == "end":
                     stream_ended = True
             except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +423,7 @@ class BaseAnswerResource:
                 continue
         if not stream_ended:
             logger.error("Stream ended unexpectedly without an 'end' event.")
-            return None, None, None, None, "Stream ended unexpectedly"
-
+            return None, None, None, None, "Stream ended unexpectedly", None
         result = (
             conversation_id,
             response_full,
@@ -390,7 +435,6 @@ class BaseAnswerResource:
 
         if is_structured:
             result = result + ({"structured": True, "schema": schema_info},)
-
         return result
 
     def error_stream_generate(self, err_response):
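The most consequential change in this file is that the streaming LLM is no longer always built from `settings.LLM_PROVIDER` / `settings.API_KEY`: the provider is derived from the request's `model_id` and the key is looked up per provider. A standalone sketch of that resolution (the `provider/model` id format and env-var naming are assumptions; the real helpers live in `application.core.model_utils`):

```python
import os

# Assumed helpers; the repo's real implementations are in application.core.model_utils.
def get_provider_from_model_id(model_id):
    if not model_id or "/" not in model_id:
        return None
    return model_id.split("/", 1)[0]  # e.g. "anthropic/claude-x" -> "anthropic"

def get_api_key_for_provider(provider):
    return os.environ.get(f"{provider.upper()}_API_KEY")

DEFAULT_PROVIDER = "openai"  # stand-in for settings.LLM_PROVIDER
provider = get_provider_from_model_id("anthropic/claude-sonnet") or DEFAULT_PROVIDER
system_api_key = get_api_key_for_provider(provider)
print(provider, system_api_key is not None)
```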
@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
                 default=True,
                 description="Whether to save the conversation",
             ),
+            "model_id": fields.String(
+                required=False,
+                description="Model ID to use for this request",
+            ),
             "attachments": fields.List(
                 fields.String, required=False, description="List of attachment IDs"
             ),
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
                 agent_id=data.get("agent_id"),
                 is_shared_usage=processor.is_shared_usage,
                 shared_token=processor.shared_token,
+                model_id=processor.model_id,
             ),
             mimetype="text/event-stream",
         )
application/api/answer/services/compression/__init__.py · new file · 20 lines
@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.

"""

from application.api.answer.services.compression.orchestrator import (
    CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
    CompressionResult,
    CompressionMetadata,
)

__all__ = [
    "CompressionOrchestrator",
    "CompressionService",
    "CompressionResult",
    "CompressionMetadata",
]
application/api/answer/services/compression/message_builder.py · new file · 234 lines
@@ -0,0 +1,234 @@
"""Message reconstruction utilities for compression."""

import logging
import uuid
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


class MessageBuilder:
    """Builds message arrays from compressed context."""

    @staticmethod
    def build_from_compressed_context(
        system_prompt: str,
        compressed_summary: Optional[str],
        recent_queries: List[Dict],
        include_tool_calls: bool = False,
        context_type: str = "pre_request",
    ) -> List[Dict]:
        """
        Build messages from compressed context.

        Args:
            system_prompt: Original system prompt
            compressed_summary: Compressed summary (if any)
            recent_queries: Recent uncompressed queries
            include_tool_calls: Whether to include tool calls from history
            context_type: Type of context ('pre_request' or 'mid_execution')

        Returns:
            List of message dicts ready for LLM
        """
        # Append compression summary to system prompt if present
        if compressed_summary:
            system_prompt = MessageBuilder._append_compression_context(
                system_prompt, compressed_summary, context_type
            )

        messages = [{"role": "system", "content": system_prompt}]

        # Add recent history
        for query in recent_queries:
            if "prompt" in query and "response" in query:
                messages.append({"role": "user", "content": query["prompt"]})
                messages.append({"role": "assistant", "content": query["response"]})

                # Add tool calls from history if present
                if include_tool_calls and "tool_calls" in query:
                    for tool_call in query["tool_calls"]:
                        call_id = tool_call.get("call_id") or str(uuid.uuid4())

                        function_call_dict = {
                            "function_call": {
                                "name": tool_call.get("action_name"),
                                "args": tool_call.get("arguments"),
                                "call_id": call_id,
                            }
                        }
                        function_response_dict = {
                            "function_response": {
                                "name": tool_call.get("action_name"),
                                "response": {"result": tool_call.get("result")},
                                "call_id": call_id,
                            }
                        }

                        messages.append(
                            {"role": "assistant", "content": [function_call_dict]}
                        )
                        messages.append(
                            {"role": "tool", "content": [function_response_dict]}
                        )

        # If no recent queries (everything was compressed), add a continuation user message
        if len(recent_queries) == 0 and compressed_summary:
            messages.append({
                "role": "user",
                "content": "Please continue with the remaining tasks based on the context above."
            })
            logger.info("Added continuation user message to maintain proper turn-taking after full compression")

        return messages

    @staticmethod
    def _append_compression_context(
        system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
    ) -> str:
        """
        Append compression context to system prompt.

        Args:
            system_prompt: Original system prompt
            compressed_summary: Summary to append
            context_type: Type of compression context

        Returns:
            Updated system prompt
        """
        # Remove existing compression context if present
        if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
            parts = system_prompt.split("\n\n---\n\n")
            system_prompt = parts[0]

        # Build appropriate context message based on type
        if context_type == "mid_execution":
            context_message = (
                "\n\n---\n\n"
                "Context window limit reached during execution. "
                "Previous conversation has been compressed to fit within limits. "
                "The conversation is summarized below:\n\n"
                f"{compressed_summary}"
            )
        else:  # pre_request
            context_message = (
                "\n\n---\n\n"
                "This session is being continued from a previous conversation that "
                "has been compressed to fit within context limits. "
                "The conversation is summarized below:\n\n"
                f"{compressed_summary}"
            )

        return system_prompt + context_message

    @staticmethod
    def rebuild_messages_after_compression(
        messages: List[Dict],
        compressed_summary: Optional[str],
        recent_queries: List[Dict],
        include_current_execution: bool = False,
        include_tool_calls: bool = False,
    ) -> Optional[List[Dict]]:
        """
        Rebuild the message list after compression so tool execution can continue.

        Args:
            messages: Original message list
            compressed_summary: Compressed summary
            recent_queries: Recent uncompressed queries
            include_current_execution: Whether to preserve current execution messages
            include_tool_calls: Whether to include tool calls from history

        Returns:
            Rebuilt message list or None if failed
        """
        # Find the system message
        system_message = next(
            (msg for msg in messages if msg.get("role") == "system"), None
        )
        if not system_message:
            logger.warning("No system message found in messages list")
            return None

        # Update system message with compressed summary
        if compressed_summary:
            content = system_message.get("content", "")
            system_message["content"] = MessageBuilder._append_compression_context(
                content, compressed_summary, "mid_execution"
            )
            logger.info(
                "Appended compression summary to system prompt (truncated): %s",
                (
                    compressed_summary[:500] + "..."
                    if len(compressed_summary) > 500
                    else compressed_summary
                ),
            )

        rebuilt_messages = [system_message]

        # Add recent history from compressed context
        for query in recent_queries:
            if "prompt" in query and "response" in query:
                rebuilt_messages.append({"role": "user", "content": query["prompt"]})
                rebuilt_messages.append(
                    {"role": "assistant", "content": query["response"]}
                )

                # Add tool calls from history if present
                if include_tool_calls and "tool_calls" in query:
                    for tool_call in query["tool_calls"]:
                        call_id = tool_call.get("call_id") or str(uuid.uuid4())

                        function_call_dict = {
                            "function_call": {
                                "name": tool_call.get("action_name"),
                                "args": tool_call.get("arguments"),
                                "call_id": call_id,
                            }
                        }
                        function_response_dict = {
                            "function_response": {
                                "name": tool_call.get("action_name"),
                                "response": {"result": tool_call.get("result")},
                                "call_id": call_id,
                            }
                        }

                        rebuilt_messages.append(
                            {"role": "assistant", "content": [function_call_dict]}
                        )
                        rebuilt_messages.append(
                            {"role": "tool", "content": [function_response_dict]}
                        )

        # If no recent queries (everything was compressed), add a continuation user message
        if len(recent_queries) == 0 and compressed_summary:
            rebuilt_messages.append({
                "role": "user",
                "content": "Please continue with the remaining tasks based on the context above."
            })
            logger.info("Added continuation user message to maintain proper turn-taking after full compression")

        if include_current_execution:
            # Preserve any messages that were added during the current execution cycle
            recent_msg_count = 1  # system message
            for query in recent_queries:
                if "prompt" in query and "response" in query:
                    recent_msg_count += 2
                    if "tool_calls" in query:
                        recent_msg_count += len(query["tool_calls"]) * 2

            if len(messages) > recent_msg_count:
                current_execution_messages = messages[recent_msg_count:]
                rebuilt_messages.extend(current_execution_messages)
                logger.info(
                    f"Preserved {len(current_execution_messages)} messages from current execution cycle"
                )

        logger.info(
            f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
            f"Ready to continue tool execution."
        )
        return rebuilt_messages
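The continuation message matters because some chat APIs reject transcripts that end on a system or assistant turn. When everything has been compressed, `build_from_compressed_context` produces a two-message list shaped like this (summary text invented for illustration):

```python
# Shape of the output when recent_queries is empty and a summary exists.
summary = "User asked about X; the agent ran two tool calls and answered Z."
messages = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant."
            "\n\n---\n\n"
            "This session is being continued from a previous conversation that "
            "has been compressed to fit within context limits. "
            "The conversation is summarized below:\n\n" + summary
        ),
    },
    # Added so the transcript still starts with a user turn:
    {"role": "user", "content": "Please continue with the remaining tasks based on the context above."},
]
```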
application/api/answer/services/compression/orchestrator.py · new file · 232 lines
@@ -0,0 +1,232 @@
"""High-level compression orchestration."""

import logging
from typing import Any, Dict, Optional

from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
    CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
    get_api_key_for_provider,
    get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator

logger = logging.getLogger(__name__)


class CompressionOrchestrator:
    """
    Facade for compression operations.

    Coordinates between all compression components and provides
    a simple interface for callers.
    """

    def __init__(
        self,
        conversation_service: ConversationService,
        threshold_checker: Optional[CompressionThresholdChecker] = None,
    ):
        """
        Initialize orchestrator.

        Args:
            conversation_service: Service for DB operations
            threshold_checker: Custom threshold checker (optional)
        """
        self.conversation_service = conversation_service
        self.threshold_checker = threshold_checker or CompressionThresholdChecker()

    def compress_if_needed(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_query_tokens: int = 500,
    ) -> CompressionResult:
        """
        Check if compression is needed and perform it if so.

        This is the main entry point for compression operations.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model being used for conversation
            decoded_token: User's decoded JWT token
            current_query_tokens: Estimated tokens for current query

        Returns:
            CompressionResult with summary and recent queries
        """
        try:
            # Load conversation
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id
            )

            if not conversation:
                logger.warning(
                    f"Conversation {conversation_id} not found for user {user_id}"
                )
                return CompressionResult.failure("Conversation not found")

            # Check if compression is needed
            if not self.threshold_checker.should_compress(
                conversation, model_id, current_query_tokens
            ):
                # No compression needed, return full history
                queries = conversation.get("queries", [])
                return CompressionResult.success_no_compression(queries)

            # Perform compression
            return self._perform_compression(
                conversation_id, conversation, model_id, decoded_token
            )

        except Exception as e:
            logger.error(
                f"Error in compress_if_needed: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))

    def _perform_compression(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        model_id: str,
        decoded_token: Dict[str, Any],
    ) -> CompressionResult:
        """
        Perform the actual compression operation.

        Args:
            conversation_id: Conversation ID
            conversation: Conversation document
            model_id: Model ID for conversation
            decoded_token: User token

        Returns:
            CompressionResult
        """
        try:
            # Determine which model to use for compression
            compression_model = (
                settings.COMPRESSION_MODEL_OVERRIDE
                if settings.COMPRESSION_MODEL_OVERRIDE
                else model_id
            )

            # Get provider and API key for compression model
            provider = get_provider_from_model_id(compression_model)
            api_key = get_api_key_for_provider(provider)

            # Create compression LLM
            compression_llm = LLMCreator.create_llm(
                provider,
                api_key=api_key,
                user_api_key=None,
                decoded_token=decoded_token,
                model_id=compression_model,
            )

            # Create compression service with DB update capability
            compression_service = CompressionService(
                llm=compression_llm,
                model_id=compression_model,
                conversation_service=self.conversation_service,
            )

            # Compress all queries up to the latest
            queries_count = len(conversation.get("queries", []))
            compress_up_to = queries_count - 1

            if compress_up_to < 0:
                logger.warning("No queries to compress")
                return CompressionResult.success_no_compression([])

            logger.info(
                f"Initiating compression for conversation {conversation_id}: "
                f"compressing all {queries_count} queries (0-{compress_up_to})"
            )

            # Perform compression and save to DB
            metadata = compression_service.compress_and_save(
                conversation_id, conversation, compress_up_to
            )

            logger.info(
                f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
                f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
            )

            # Reload conversation with updated metadata
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id=decoded_token.get("sub")
            )

            # Get compressed context
            compressed_summary, recent_queries = (
                compression_service.get_compressed_context(conversation)
            )

            return CompressionResult.success_with_compression(
                compressed_summary, recent_queries, metadata
            )

        except Exception as e:
            logger.error(f"Error performing compression: {str(e)}", exc_info=True)
            return CompressionResult.failure(str(e))

    def compress_mid_execution(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_conversation: Optional[Dict[str, Any]] = None,
    ) -> CompressionResult:
        """
        Perform compression during tool execution.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model ID
            decoded_token: User token
            current_conversation: Pre-loaded conversation (optional)

        Returns:
            CompressionResult
        """
        try:
            # Load conversation if not provided
            if current_conversation:
                conversation = current_conversation
            else:
                conversation = self.conversation_service.get_conversation(
                    conversation_id, user_id
                )

            if not conversation:
                logger.warning(
                    f"Could not load conversation {conversation_id} for mid-execution compression"
                )
                return CompressionResult.failure("Conversation not found")

            # Perform compression
            return self._perform_compression(
                conversation_id, conversation, model_id, decoded_token
            )

        except Exception as e:
            logger.error(
                f"Error in mid-execution compression: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))
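Taken together, the orchestrator's happy path is: load the conversation, ask the threshold checker whether history plus the incoming query crosses the limit, and only then build a dedicated compression LLM (honoring `COMPRESSION_MODEL_OVERRIDE`), compress everything, persist, reload, and hand back the summary plus any uncompressed tail. A condensed, stub-based sketch of that control flow:

```python
# Control-flow sketch of compress_if_needed with the service calls stubbed out.
def compress_if_needed(load, should_compress, perform):
    conversation = load()
    if not conversation:
        return ("failure", "Conversation not found")
    if not should_compress(conversation):
        return ("no_compression", conversation.get("queries", []))
    return perform(conversation)  # compress, save, reload, return summary + tail

result = compress_if_needed(
    load=lambda: {"queries": ["q1", "q2"]},
    should_compress=lambda c: False,
    perform=lambda c: ("compressed", "summary..."),
)
print(result)  # ('no_compression', ['q1', 'q2'])
```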
application/api/answer/services/compression/prompt_builder.py · new file · 149 lines
@@ -0,0 +1,149 @@
"""Compression prompt building logic."""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class CompressionPromptBuilder:
    """Builds prompts for LLM compression calls."""

    def __init__(self, version: str = "v1.0"):
        """
        Initialize prompt builder.

        Args:
            version: Prompt template version to use
        """
        self.version = version
        self.system_prompt = self._load_prompt(version)

    def _load_prompt(self, version: str) -> str:
        """
        Load prompt template from file.

        Args:
            version: Version string (e.g., 'v1.0')

        Returns:
            Prompt template content

        Raises:
            FileNotFoundError: If prompt template file doesn't exist
        """
        current_dir = Path(__file__).resolve().parents[4]
        prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"

        try:
            with open(prompt_path, "r") as f:
                return f.read()
        except FileNotFoundError:
            logger.error(f"Compression prompt template not found: {prompt_path}")
            raise FileNotFoundError(
                f"Compression prompt template '{version}' not found at {prompt_path}. "
                f"Please ensure the template file exists."
            )

    def build_prompt(
        self,
        queries: List[Dict[str, Any]],
        existing_compressions: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, str]]:
        """
        Build messages for compression LLM call.

        Args:
            queries: List of query objects to compress
            existing_compressions: List of previous compression points

        Returns:
            List of message dicts for LLM
        """
        # Build conversation text
        conversation_text = self._format_conversation(queries)

        # Add existing compression context if present
        existing_compression_context = ""
        if existing_compressions and len(existing_compressions) > 0:
            existing_compression_context = (
                "\n\nIMPORTANT: This conversation has been compressed before. "
                "Previous compression summaries:\n\n"
            )
            for i, comp in enumerate(existing_compressions):
                existing_compression_context += (
                    f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
                    f"{comp.get('compressed_summary', '')}\n\n"
                )
            existing_compression_context += (
                "Your task is to create a NEW summary that incorporates the context from "
                "previous compressions AND the new messages below. The final summary should "
                "be comprehensive and include all important information from both previous "
                "compressions and new messages.\n\n"
            )

        user_prompt = (
            f"{existing_compression_context}"
            f"Here is the conversation to summarize:\n\n"
            f"{conversation_text}"
        )

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        return messages

    def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
        """
        Format conversation queries into readable text for compression.

        Args:
            queries: List of query objects

        Returns:
            Formatted conversation text
        """
        conversation_lines = []

        for i, query in enumerate(queries):
            conversation_lines.append(f"--- Message {i + 1} ---")
            conversation_lines.append(f"User: {query.get('prompt', '')}")

            # Add tool calls if present
            tool_calls = query.get("tool_calls", [])
            if tool_calls:
                conversation_lines.append("\nTool Calls:")
                for tc in tool_calls:
                    tool_name = tc.get("tool_name", "unknown")
                    action_name = tc.get("action_name", "unknown")
                    arguments = tc.get("arguments", {})
                    result = tc.get("result", "")
                    if result is None:
                        result = ""
                    status = tc.get("status", "unknown")

                    # Include full tool result for complete compression context
                    conversation_lines.append(
                        f"  - {tool_name}.{action_name}({arguments}) "
                        f"[{status}] → {result}"
                    )

            # Add agent thought if present
            thought = query.get("thought", "")
            if thought:
                conversation_lines.append(f"\nAgent Thought: {thought}")

            # Add assistant response
            conversation_lines.append(f"\nAssistant: {query.get('response', '')}")

            # Add sources if present
            sources = query.get("sources", [])
            if sources:
                conversation_lines.append(f"\nSources Used: {len(sources)} documents")

            conversation_lines.append("")  # Empty line between messages

        return "\n".join(conversation_lines)
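To make the prompt concrete, here is roughly what `_format_conversation` emits for a single query with one tool call (values invented; the blank lines come from the `\nTool Calls:` style prefixes in the code above):

```python
# Approximate transcript text for one query with one succeeded tool call.
example = (
    "--- Message 1 ---\n"
    "User: What's the weather in Paris?\n"
    "\nTool Calls:\n"
    "  - weather.get_current({'city': 'Paris'}) [succeeded] → 18°C, cloudy\n"
    "\nAssistant: It is currently 18°C and cloudy in Paris.\n"
)
print(example)
```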
application/api/answer/services/compression/service.py · new file · 306 lines
@@ -0,0 +1,306 @@
"""Core compression service with simplified responsibilities."""

import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from application.api.answer.services.compression.prompt_builder import (
    CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
    CompressionMetadata,
)
from application.core.settings import settings

logger = logging.getLogger(__name__)


class CompressionService:
    """
    Service for compressing conversation history.

    Handles DB updates.
    """

    def __init__(
        self,
        llm,
        model_id: str,
        conversation_service=None,
        prompt_builder: Optional[CompressionPromptBuilder] = None,
    ):
        """
        Initialize compression service.

        Args:
            llm: LLM instance to use for compression
            model_id: Model ID for compression
            conversation_service: Service for DB operations (optional, for DB updates)
            prompt_builder: Custom prompt builder (optional)
        """
        self.llm = llm
        self.model_id = model_id
        self.conversation_service = conversation_service
        self.prompt_builder = prompt_builder or CompressionPromptBuilder(
            version=settings.COMPRESSION_PROMPT_VERSION
        )

    def compress_conversation(
        self,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation history up to specified index.

        Args:
            conversation: Full conversation document
            compress_up_to_index: Last query index to include in compression

        Returns:
            CompressionMetadata with compression details

        Raises:
            ValueError: If compress_up_to_index is invalid
        """
        try:
            queries = conversation.get("queries", [])

            if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
                raise ValueError(
                    f"Invalid compress_up_to_index: {compress_up_to_index} "
                    f"(conversation has {len(queries)} queries)"
                )

            # Get queries to compress
            queries_to_compress = queries[: compress_up_to_index + 1]

            # Check if there are existing compressions
            existing_compressions = conversation.get("compression_metadata", {}).get(
                "compression_points", []
            )

            if existing_compressions:
                logger.info(
                    f"Found {len(existing_compressions)} previous compression(s) - "
                    f"will incorporate into new summary"
                )

            # Calculate original token count
            original_tokens = TokenCounter.count_query_tokens(queries_to_compress)

            # Log tool call stats
            self._log_tool_call_stats(queries_to_compress)

            # Build compression prompt
            messages = self.prompt_builder.build_prompt(
                queries_to_compress, existing_compressions
            )

            # Call LLM to generate compression
            logger.info(
                f"Starting compression: {len(queries_to_compress)} queries "
                f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
                f"using model {self.model_id}"
            )

            response = self.llm.gen(
                model=self.model_id, messages=messages, max_tokens=4000
            )

            # Extract summary from response
            compressed_summary = self._extract_summary(response)

            # Calculate compressed token count
            compressed_tokens = TokenCounter.count_message_tokens(
                [{"content": compressed_summary}]
            )

            # Calculate compression ratio
            compression_ratio = (
                original_tokens / compressed_tokens if compressed_tokens > 0 else 0
            )

            logger.info(
                f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
                f"({compression_ratio:.1f}x compression)"
            )

            # Build compression metadata
            compression_metadata = CompressionMetadata(
                timestamp=datetime.now(timezone.utc),
                query_index=compress_up_to_index,
                compressed_summary=compressed_summary,
                original_token_count=original_tokens,
                compressed_token_count=compressed_tokens,
                compression_ratio=compression_ratio,
                model_used=self.model_id,
                compression_prompt_version=self.prompt_builder.version,
            )

            return compression_metadata

        except Exception as e:
            logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
            raise

    def compress_and_save(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation and save to database.

        Args:
            conversation_id: Conversation ID
            conversation: Full conversation document
            compress_up_to_index: Last query index to include

        Returns:
            CompressionMetadata

        Raises:
            ValueError: If conversation_service not provided or invalid index
        """
        if not self.conversation_service:
            raise ValueError(
                "conversation_service required for compress_and_save operation"
            )

        # Perform compression
        metadata = self.compress_conversation(conversation, compress_up_to_index)

        # Save to database
        self.conversation_service.update_compression_metadata(
            conversation_id, metadata.to_dict()
        )

        logger.info(f"Compression metadata saved to database for {conversation_id}")

        return metadata

    def get_compressed_context(
        self, conversation: Dict[str, Any]
    ) -> tuple[Optional[str], List[Dict[str, Any]]]:
        """
        Get compressed summary + recent uncompressed messages.

        Args:
            conversation: Full conversation document

        Returns:
            (compressed_summary, recent_messages)
        """
        try:
            compression_metadata = conversation.get("compression_metadata", {})

            if not compression_metadata.get("is_compressed"):
                logger.debug("No compression metadata found - using full history")
                queries = conversation.get("queries", [])
                if queries is None:
                    logger.error("Conversation queries is None - returning empty list")
                    return None, []
                return None, queries

            compression_points = compression_metadata.get("compression_points", [])

            if not compression_points:
                logger.debug("No compression points found - using full history")
                queries = conversation.get("queries", [])
                if queries is None:
                    logger.error("Conversation queries is None - returning empty list")
                    return None, []
                return None, queries

            # Get the most recent compression point
            latest_compression = compression_points[-1]
            compressed_summary = latest_compression.get("compressed_summary")
            last_compressed_index = latest_compression.get("query_index")
            compressed_tokens = latest_compression.get("compressed_token_count", 0)
            original_tokens = latest_compression.get("original_token_count", 0)

            # Get only messages after compression point
            queries = conversation.get("queries", [])
            total_queries = len(queries)
            recent_queries = queries[last_compressed_index + 1 :]

            logger.info(
                f"Using compressed context: summary ({compressed_tokens} tokens, "
                f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
                f"(messages {last_compressed_index + 1}-{total_queries - 1})"
            )

            return compressed_summary, recent_queries

        except Exception as e:
            logger.error(
                f"Error getting compressed context: {str(e)}", exc_info=True
            )
            queries = conversation.get("queries", [])
            if queries is None:
                return None, []
            return None, queries

    def _extract_summary(self, llm_response: str) -> str:
        """
        Extract clean summary from LLM response.

        Args:
            llm_response: Raw LLM response

        Returns:
            Cleaned summary text
        """
        try:
            # Try to extract content within <summary> tags
            summary_match = re.search(
                r"<summary>(.*?)</summary>", llm_response, re.DOTALL
            )

            if summary_match:
                summary = summary_match.group(1).strip()
            else:
                # If no summary tags, remove analysis tags and use the rest
                summary = re.sub(
                    r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
                ).strip()

            return summary

        except Exception as e:
            logger.warning(f"Error extracting summary: {str(e)}, using full response")
            return llm_response

    def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
        """Log statistics about tool calls in queries."""
        total_tool_calls = 0
        total_tool_result_chars = 0
        tool_call_breakdown = {}

        for q in queries:
            for tc in q.get("tool_calls", []):
                total_tool_calls += 1
                tool_name = tc.get("tool_name", "unknown")
                action_name = tc.get("action_name", "unknown")
                key = f"{tool_name}.{action_name}"
                tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1

                # Track total tool result size
                result = tc.get("result", "")
                if result:
                    total_tool_result_chars += len(str(result))

        if total_tool_calls > 0:
            tool_breakdown_str = ", ".join(
                f"{tool}({count})"
                for tool, count in sorted(tool_call_breakdown.items())
            )
            tool_result_kb = total_tool_result_chars / 1024
            logger.info(
                f"Tool call breakdown: {tool_breakdown_str} "
                f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
            )
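The `_extract_summary` logic assumes the compression prompt asks the model to wrap its output in `<summary>` tags, with optional `<analysis>` scratch work. Its behavior in isolation:

```python
import re

resp = "<analysis>scratch work...</analysis><summary>User built a Flask API; tests pass.</summary>"
m = re.search(r"<summary>(.*?)</summary>", resp, re.DOTALL)
if m:
    summary = m.group(1).strip()
else:
    # No <summary> tags: strip <analysis> blocks and keep the rest.
    summary = re.sub(r"<analysis>.*?</analysis>", "", resp, flags=re.DOTALL).strip()
print(summary)  # -> User built a Flask API; tests pass.
```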
application/api/answer/services/compression/threshold_checker.py · new file · 103 lines
@@ -0,0 +1,103 @@
"""Compression threshold checking logic."""

import logging
from typing import Any, Dict

from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter

logger = logging.getLogger(__name__)


class CompressionThresholdChecker:
    """Determines if compression is needed based on token thresholds."""

    def __init__(self, threshold_percentage: float = None):
        """
        Initialize threshold checker.

        Args:
            threshold_percentage: Percentage of context to use as threshold
                (defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
        """
        self.threshold_percentage = (
            threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
        )

    def should_compress(
        self,
        conversation: Dict[str, Any],
        model_id: str,
        current_query_tokens: int = 500,
    ) -> bool:
        """
        Determine if compression is needed.

        Args:
            conversation: Full conversation document
            model_id: Target model for this request
            current_query_tokens: Estimated tokens for current query

        Returns:
            True if tokens >= threshold% of context window
        """
        try:
            # Calculate total tokens in conversation
            total_tokens = TokenCounter.count_conversation_tokens(conversation)
            total_tokens += current_query_tokens

            # Get context window limit for model
            context_limit = get_token_limit(model_id)

            # Calculate threshold
            threshold = int(context_limit * self.threshold_percentage)

            compression_needed = total_tokens >= threshold
            percentage_used = (total_tokens / context_limit) * 100

            if compression_needed:
                logger.warning(
                    f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
                )
            else:
                logger.info(
                    f"Compression check: {total_tokens}/{context_limit} tokens "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
                )

            return compression_needed

        except Exception as e:
            logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
            return False

    def check_message_tokens(self, messages: list, model_id: str) -> bool:
        """
        Check if message list exceeds threshold.

        Args:
            messages: List of message dicts
            model_id: Target model

        Returns:
            True if at or above threshold
        """
        try:
            current_tokens = TokenCounter.count_message_tokens(messages)
            context_limit = get_token_limit(model_id)
            threshold = int(context_limit * self.threshold_percentage)

            if current_tokens >= threshold:
                logger.warning(
                    f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
                    f"({(current_tokens/context_limit)*100:.1f}%)"
                )
                return True

            return False

        except Exception as e:
            logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
            return False
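Note that, unlike the mid-execution check in `BaseAgent`, `should_compress` adds an estimate for the incoming query (default 500 tokens) before comparing against the threshold, so compression can trigger slightly before the history alone would cross it:

```python
# Pre-request check: history tokens plus an estimate for the incoming query.
# The 128k context window is an assumed example value.
def should_compress(history_tokens, context_limit=128_000,
                    threshold_pct=0.8, current_query_tokens=500):
    return (history_tokens + current_query_tokens) >= int(context_limit * threshold_pct)

print(should_compress(102_000))  # True:  102_500 >= 102_400
print(should_compress(101_000))  # False: 101_500 <  102_400
```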
103
application/api/answer/services/compression/token_counter.py
Normal file
103
application/api/answer/services/compression/token_counter.py
Normal file
@@ -0,0 +1,103 @@
"""Token counting utilities for compression."""

import logging
from typing import Any, Dict, List

from application.utils import num_tokens_from_string
from application.core.settings import settings

logger = logging.getLogger(__name__)


class TokenCounter:
    """Centralized token counting for conversations and messages."""

    @staticmethod
    def count_message_tokens(messages: List[Dict]) -> int:
        """
        Calculate total tokens in a list of messages.

        Args:
            messages: List of message dicts with 'content' field

        Returns:
            Total token count
        """
        total_tokens = 0
        for message in messages:
            content = message.get("content", "")
            if isinstance(content, str):
                total_tokens += num_tokens_from_string(content)
            elif isinstance(content, list):
                # Handle structured content (tool calls, etc.)
                for item in content:
                    if isinstance(item, dict):
                        total_tokens += num_tokens_from_string(str(item))
        return total_tokens

    @staticmethod
    def count_query_tokens(
        queries: List[Dict[str, Any]], include_tool_calls: bool = True
    ) -> int:
        """
        Count tokens across multiple query objects.

        Args:
            queries: List of query objects from conversation
            include_tool_calls: Whether to count tool call tokens

        Returns:
            Total token count
        """
        total_tokens = 0

        for query in queries:
            # Count prompt and response tokens
            if "prompt" in query:
                total_tokens += num_tokens_from_string(query["prompt"])
            if "response" in query:
                total_tokens += num_tokens_from_string(query["response"])
            if "thought" in query:
                total_tokens += num_tokens_from_string(query.get("thought", ""))

            # Count tool call tokens
            if include_tool_calls and "tool_calls" in query:
                for tool_call in query["tool_calls"]:
                    tool_call_string = (
                        f"Tool: {tool_call.get('tool_name')} | "
                        f"Action: {tool_call.get('action_name')} | "
                        f"Args: {tool_call.get('arguments')} | "
                        f"Response: {tool_call.get('result')}"
                    )
                    total_tokens += num_tokens_from_string(tool_call_string)

        return total_tokens

    @staticmethod
    def count_conversation_tokens(
        conversation: Dict[str, Any], include_system_prompt: bool = False
    ) -> int:
        """
        Calculate total tokens in a conversation.

        Args:
            conversation: Conversation document
            include_system_prompt: Whether to include system prompt in count

        Returns:
            Total token count
        """
        try:
            queries = conversation.get("queries", [])
            total_tokens = TokenCounter.count_query_tokens(queries)

            # Add system prompt tokens if requested
            if include_system_prompt:
                # Rough estimate for system prompt
                total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)

            return total_tokens

        except Exception as e:
            logger.error(f"Error calculating conversation tokens: {str(e)}")
            return 0
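A hedged usage sketch of the class above (assumes a DocsGPT app context so that application.utils.num_tokens_from_string resolves; the conversation dict shape mirrors the query documents used throughout this PR):

from application.api.answer.services.compression.token_counter import TokenCounter

conversation = {
    "queries": [
        {"prompt": "What is DocsGPT?", "response": "An open-source RAG assistant."},
        {
            "prompt": "List sources.",
            "response": "See the docs.",
            "tool_calls": [
                {"tool_name": "search", "action_name": "query",
                 "arguments": {"q": "sources"}, "result": "3 hits"}
            ],
        },
    ]
}
print(TokenCounter.count_conversation_tokens(conversation))   # token total incl. tool calls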
83 application/api/answer/services/compression/types.py Normal file
@@ -0,0 +1,83 @@
"""Type definitions for compression module."""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional


@dataclass
class CompressionMetadata:
    """Metadata about a compression operation."""

    timestamp: datetime
    query_index: int
    compressed_summary: str
    original_token_count: int
    compressed_token_count: int
    compression_ratio: float
    model_used: str
    compression_prompt_version: str

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for DB storage."""
        return {
            "timestamp": self.timestamp,
            "query_index": self.query_index,
            "compressed_summary": self.compressed_summary,
            "original_token_count": self.original_token_count,
            "compressed_token_count": self.compressed_token_count,
            "compression_ratio": self.compression_ratio,
            "model_used": self.model_used,
            "compression_prompt_version": self.compression_prompt_version,
        }


@dataclass
class CompressionResult:
    """Result of a compression operation."""

    success: bool
    compressed_summary: Optional[str] = None
    recent_queries: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Optional[CompressionMetadata] = None
    error: Optional[str] = None
    compression_performed: bool = False

    @classmethod
    def success_with_compression(
        cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
    ) -> "CompressionResult":
        """Create a successful result with compression."""
        return cls(
            success=True,
            compressed_summary=summary,
            recent_queries=queries,
            metadata=metadata,
            compression_performed=True,
        )

    @classmethod
    def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
        """Create a successful result without compression needed."""
        return cls(
            success=True,
            recent_queries=queries,
            compression_performed=False,
        )

    @classmethod
    def failure(cls, error: str) -> "CompressionResult":
        """Create a failure result."""
        return cls(success=False, error=error, compression_performed=False)

    def as_history(self) -> List[Dict[str, str]]:
        """
        Convert recent queries to history format.

        Returns:
            List of prompt/response dicts
        """
        return [
            {"prompt": q["prompt"], "response": q["response"]}
            for q in self.recent_queries
        ]
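A hedged sketch of constructing these types by hand (all field values are illustrative; in the PR they are produced by the compression orchestrator):

from datetime import datetime, timezone

meta = CompressionMetadata(
    timestamp=datetime.now(timezone.utc),
    query_index=5,
    compressed_summary="User asked about setup; agent explained the Docker install.",
    original_token_count=90000,
    compressed_token_count=1200,
    compression_ratio=1200 / 90000,
    model_used="gpt-5.1",
    compression_prompt_version="v1.0",
)
result = CompressionResult.success_with_compression(
    summary=meta.compressed_summary,
    queries=[{"prompt": "And on Windows?", "response": "Use WSL2."}],
    metadata=meta,
)
print(result.as_history())   # [{'prompt': 'And on Windows?', 'response': 'Use WSL2.'}]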
@@ -52,7 +52,7 @@ class ConversationService:
        sources: List[Dict[str, Any]],
        tool_calls: List[Dict[str, Any]],
        llm: Any,
        gpt_model: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        index: Optional[int] = None,
        api_key: Optional[str] = None,

@@ -66,7 +66,7 @@ class ConversationService:
        if not user_id:
            raise ValueError("User ID not found in token")
        current_time = datetime.now(timezone.utc)


        # clean up in sources array such that we save max 1k characters for text part
        for source in sources:
            if "text" in source and isinstance(source["text"], str):

@@ -90,6 +90,7 @@ class ConversationService:
                    f"queries.{index}.tool_calls": tool_calls,
                    f"queries.{index}.timestamp": current_time,
                    f"queries.{index}.attachments": attachment_ids,
                    f"queries.{index}.model_id": model_id,
                }
            },
        )

@@ -120,6 +121,7 @@ class ConversationService:
                        "tool_calls": tool_calls,
                        "timestamp": current_time,
                        "attachments": attachment_ids,
                        "model_id": model_id,
                    }
                }
            },

@@ -146,7 +148,7 @@ class ConversationService:
            ]

            completion = llm.gen(
                model=gpt_model, messages=messages_summary, max_tokens=30
                model=model_id, messages=messages_summary, max_tokens=30
            )

            conversation_data = {

@@ -162,6 +164,7 @@ class ConversationService:
                        "tool_calls": tool_calls,
                        "timestamp": current_time,
                        "attachments": attachment_ids,
                        "model_id": model_id,
                    }
                ],
            }

@@ -177,3 +180,103 @@ class ConversationService:
            conversation_data["api_key"] = agent["key"]
        result = self.conversations_collection.insert_one(conversation_data)
        return str(result.inserted_id)

    def update_compression_metadata(
        self, conversation_id: str, compression_metadata: Dict[str, Any]
    ) -> None:
        """
        Update conversation with compression metadata.

        Uses $push with $slice to keep only the most recent compression points,
        preventing unbounded array growth. Since each compression incorporates
        previous compressions, older points become redundant.

        Args:
            conversation_id: Conversation ID
            compression_metadata: Compression point data
        """
        try:
            self.conversations_collection.update_one(
                {"_id": ObjectId(conversation_id)},
                {
                    "$set": {
                        "compression_metadata.is_compressed": True,
                        "compression_metadata.last_compression_at": compression_metadata.get(
                            "timestamp"
                        ),
                    },
                    "$push": {
                        "compression_metadata.compression_points": {
                            "$each": [compression_metadata],
                            "$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
                        }
                    },
                },
            )
            logger.info(
                f"Updated compression metadata for conversation {conversation_id}"
            )
        except Exception as e:
            logger.error(
                f"Error updating compression metadata: {str(e)}", exc_info=True
            )
            raise

    def append_compression_message(
        self, conversation_id: str, compression_metadata: Dict[str, Any]
    ) -> None:
        """
        Append a synthetic compression summary entry into the conversation history.
        This makes the summary visible in the DB alongside normal queries.
        """
        try:
            summary = compression_metadata.get("compressed_summary", "")
            if not summary:
                return
            timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))

            self.conversations_collection.update_one(
                {"_id": ObjectId(conversation_id)},
                {
                    "$push": {
                        "queries": {
                            "prompt": "[Context Compression Summary]",
                            "response": summary,
                            "thought": "",
                            "sources": [],
                            "tool_calls": [],
                            "timestamp": timestamp,
                            "attachments": [],
                            "model_id": compression_metadata.get("model_used"),
                        }
                    }
                },
            )
            logger.info(f"Appended compression summary to conversation {conversation_id}")
        except Exception as e:
            logger.error(
                f"Error appending compression summary: {str(e)}", exc_info=True
            )

    def get_compression_metadata(
        self, conversation_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get compression metadata for a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Compression metadata dict or None
        """
        try:
            conversation = self.conversations_collection.find_one(
                {"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
            )
            return conversation.get("compression_metadata") if conversation else None
        except Exception as e:
            logger.error(
                f"Error getting compression metadata: {str(e)}", exc_info=True
            )
            return None
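A standalone pymongo sketch of the capped-array pattern used in update_compression_metadata above (the client, database name, and metadata shape are illustrative; only the $push/$each/$slice mechanics come from the diff):

from bson.objectid import ObjectId
from pymongo import MongoClient

coll = MongoClient()["docsgpt"]["conversations"]
conv_id = ObjectId()
point = {"timestamp": "2025-01-01T00:00:00Z", "compressed_summary": "..."}

coll.update_one(
    {"_id": conv_id},
    {"$push": {"compression_metadata.compression_points": {
        "$each": [point],
        "$slice": -3,   # keep only the 3 most recent points (COMPRESSION_MAX_HISTORY_POINTS)
    }}},
)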
@@ -10,14 +10,21 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId

from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
    get_api_key_for_provider,
    get_default_model_id,
    get_provider_from_model_id,
    validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
    calculate_doc_token_budget,
    get_gpt_model,
    limit_chat_history,
)

@@ -83,18 +90,23 @@ class StreamProcessor:
        self.retriever_config = {}
        self.is_shared_usage = False
        self.shared_token = None
        self.gpt_model = get_gpt_model()
        self.model_id: Optional[str] = None
        self.conversation_service = ConversationService()
        self.compression_orchestrator = CompressionOrchestrator(
            self.conversation_service
        )
        self.prompt_renderer = PromptRenderer()
        self._prompt_content: Optional[str] = None
        self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
        self.compressed_summary: Optional[str] = None
        self.compressed_summary_tokens: int = 0

    def initialize(self):
        """Initialize all required components for processing"""
        self._configure_agent()
        self._validate_and_set_model()
        self._configure_source()
        self._configure_retriever()
        self._configure_agent()
        self._load_conversation_history()
        self._process_attachments()

@@ -106,14 +118,71 @@ class StreamProcessor:
            )
            if not conversation:
                raise ValueError("Conversation not found or unauthorized")

            # Check if compression is enabled and needed
            if settings.ENABLE_CONVERSATION_COMPRESSION:
                self._handle_compression(conversation)
            else:
                # Original behavior - load all history
                self.history = [
                    {"prompt": query["prompt"], "response": query["response"]}
                    for query in conversation.get("queries", [])
                ]
        else:
            self.history = limit_chat_history(
                json.loads(self.data.get("history", "[]")), model_id=self.model_id
            )

    def _handle_compression(self, conversation: Dict[str, Any]):
        """
        Handle conversation compression logic using orchestrator.

        Args:
            conversation: Full conversation document
        """
        try:
            # Use orchestrator to handle all compression logic
            result = self.compression_orchestrator.compress_if_needed(
                conversation_id=self.conversation_id,
                user_id=self.initial_user_id,
                model_id=self.model_id,
                decoded_token=self.decoded_token,
            )

            if not result.success:
                logger.error(
                    f"Compression failed: {result.error}, using full history"
                )
                self.history = [
                    {"prompt": query["prompt"], "response": query["response"]}
                    for query in conversation.get("queries", [])
                ]
                return

            # Set compressed summary if compression was performed
            if result.compression_performed and result.compressed_summary:
                self.compressed_summary = result.compressed_summary
                self.compressed_summary_tokens = TokenCounter.count_message_tokens(
                    [{"content": result.compressed_summary}]
                )
                logger.info(
                    f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
                    f"+ {len(result.recent_queries)} recent messages"
                )

            # Build history from recent queries
            self.history = result.as_history()

        except Exception as e:
            logger.error(
                f"Error handling compression, falling back to standard history: {str(e)}",
                exc_info=True,
            )
            # Fallback to original behavior
            self.history = [
                {"prompt": query["prompt"], "response": query["response"]}
                for query in conversation.get("queries", [])
            ]
        else:
            self.history = limit_chat_history(
                json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
            )

    def _process_attachments(self):
        """Process any attachments in the request"""

@@ -143,6 +212,30 @@ class StreamProcessor:
        )
        return attachments

    def _validate_and_set_model(self):
        """Validate and set model_id from request"""
        from application.core.model_settings import ModelRegistry

        requested_model = self.data.get("model_id")

        if requested_model:
            if not validate_model_id(requested_model):
                registry = ModelRegistry.get_instance()
                available_models = [m.id for m in registry.get_enabled_models()]
                raise ValueError(
                    f"Invalid model_id '{requested_model}'. "
                    f"Available models: {', '.join(available_models[:5])}"
                    + (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
                )
            self.model_id = requested_model
        else:
            # Check if agent has a default model configured
            agent_default_model = self.agent_config.get("default_model_id", "")
            if agent_default_model and validate_model_id(agent_default_model):
                self.model_id = agent_default_model
            else:
                self.model_id = get_default_model_id()

    def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
        """Get API key for agent with access control"""
        if not agent_id:

@@ -214,6 +307,10 @@ class StreamProcessor:
                data["sources"] = sources_list
            else:
                data["sources"] = []

            # Preserve model configuration from agent
            data["default_model_id"] = data.get("default_model_id", "")

        return data

    def _configure_source(self):

@@ -266,6 +363,7 @@ class StreamProcessor:
                    "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                    "user_api_key": api_key,
                    "json_schema": data_key.get("json_schema"),
                    "default_model_id": data_key.get("default_model_id", ""),
                }
            )
            self.initial_user_id = data_key.get("user")

@@ -290,6 +388,7 @@ class StreamProcessor:
                    "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                    "user_api_key": self.agent_key,
                    "json_schema": data_key.get("json_schema"),
                    "default_model_id": data_key.get("default_model_id", ""),
                }
            )
            self.decoded_token = (

@@ -316,13 +415,14 @@ class StreamProcessor:
                    "agent_type": settings.AGENT_NAME,
                    "user_api_key": None,
                    "json_schema": None,
                    "default_model_id": "",
                }
            )

    def _configure_retriever(self):
        history_token_limit = int(self.data.get("token_limit", 2000))
        doc_token_limit = calculate_doc_token_budget(
            gpt_model=self.gpt_model, history_token_limit=history_token_limit
            model_id=self.model_id, history_token_limit=history_token_limit
        )

        self.retriever_config = {

@@ -344,7 +444,7 @@ class StreamProcessor:
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            chunks=self.retriever_config["chunks"],
            doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
            gpt_model=self.gpt_model,
            model_id=self.model_id,
            user_api_key=self.agent_config["user_api_key"],
            decoded_token=self.decoded_token,
        )

@@ -626,12 +726,19 @@ class StreamProcessor:
            tools_data=tools_data,
        )

        return AgentCreator.create_agent(
        provider = (
            get_provider_from_model_id(self.model_id)
            if self.model_id
            else settings.LLM_PROVIDER
        )
        system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)

        agent = AgentCreator.create_agent(
            self.agent_config["agent_type"],
            endpoint="stream",
            llm_name=settings.LLM_PROVIDER,
            gpt_model=self.gpt_model,
            api_key=settings.API_KEY,
            llm_name=provider or settings.LLM_PROVIDER,
            model_id=self.model_id,
            api_key=system_api_key,
            user_api_key=self.agent_config["user_api_key"],
            prompt=rendered_prompt,
            chat_history=self.history,

@@ -639,4 +746,10 @@ class StreamProcessor:
            decoded_token=self.decoded_token,
            attachments=self.attachments,
            json_schema=self.agent_config.get("json_schema"),
            compressed_summary=self.compressed_summary,
        )

        agent.conversation_id = self.conversation_id
        agent.initial_user_id = self.initial_user_id

        return agent
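A hedged, standalone sketch of the provider/key resolution step introduced above; here the registry lookup is stubbed with plain dicts for illustration, whereas the real lookups live in application.core.model_utils:

MODEL_PROVIDERS = {"gpt-5.1": "openai", "claude-3-opus": "anthropic"}
API_KEYS = {"openai": "sk-openai-example", "anthropic": "sk-ant-example"}
DEFAULT_PROVIDER, DEFAULT_KEY = "docsgpt", None

def resolve(model_id):
    # Mirror the diff: derive provider from model_id, then pick that provider's key.
    provider = MODEL_PROVIDERS.get(model_id) or DEFAULT_PROVIDER
    return provider, API_KEYS.get(provider, DEFAULT_KEY)

print(resolve("gpt-5.1"))          # ('openai', 'sk-openai-example')
print(resolve("unknown-model"))    # ('docsgpt', None)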
@@ -95,6 +95,8 @@ class GetAgent(Resource):
                "shared": agent.get("shared_publicly", False),
                "shared_metadata": agent.get("shared_metadata", {}),
                "shared_token": agent.get("shared_token", ""),
                "models": agent.get("models", []),
                "default_model_id": agent.get("default_model_id", ""),
            }
            return make_response(jsonify(data), 200)
        except Exception as e:

@@ -172,6 +174,8 @@ class GetAgents(Resource):
                    "shared": agent.get("shared_publicly", False),
                    "shared_metadata": agent.get("shared_metadata", {}),
                    "shared_token": agent.get("shared_token", ""),
                    "models": agent.get("models", []),
                    "default_model_id": agent.get("default_model_id", ""),
                }
                for agent in agents
                if "source" in agent or "retriever" in agent

@@ -230,6 +234,14 @@ class CreateAgent(Resource):
                required=False,
                description="Request limit for the agent in limited mode",
            ),
            "models": fields.List(
                fields.String,
                required=False,
                description="List of available model IDs for this agent",
            ),
            "default_model_id": fields.String(
                required=False, description="Default model ID for this agent"
            ),
        },
    )

@@ -258,6 +270,11 @@ class CreateAgent(Resource):
                data["json_schema"] = json.loads(data["json_schema"])
            except json.JSONDecodeError:
                data["json_schema"] = None
        if "models" in data:
            try:
                data["models"] = json.loads(data["models"])
            except json.JSONDecodeError:
                data["models"] = []
        print(f"Received data: {data}")

        # Validate JSON schema if provided

@@ -399,6 +416,8 @@ class CreateAgent(Resource):
            "updatedAt": datetime.datetime.now(datetime.timezone.utc),
            "lastUsedAt": None,
            "key": key,
            "models": data.get("models", []),
            "default_model_id": data.get("default_model_id", ""),
        }
        if new_agent["chunks"] == "":
            new_agent["chunks"] = "2"

@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
                required=False,
                description="Request limit for the agent in limited mode",
            ),
            "models": fields.List(
                fields.String,
                required=False,
                description="List of available model IDs for this agent",
            ),
            "default_model_id": fields.String(
                required=False, description="Default model ID for this agent"
            ),
        },
    )

@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
            data = request.get_json()
        else:
            data = request.form.to_dict()
            json_fields = ["tools", "sources", "json_schema"]
            json_fields = ["tools", "sources", "json_schema", "models"]
            for field in json_fields:
                if field in data and data[field]:
                    try:

@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
            "token_limit",
            "limited_request_mode",
            "request_limit",
            "models",
            "default_model_id",
        ]

        for field in allowed_fields:
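A hypothetical client call exercising the new agent fields; the host and endpoint path are assumptions for illustration, while "models" and "default_model_id" come from the API model above:

import requests

payload = {
    "name": "Support agent",
    "models": ["gpt-5.1", "claude-3-opus"],       # subset of registry model IDs
    "default_model_id": "gpt-5.1",
}
# Endpoint path assumed for the sketch; check the agents namespace for the real route.
resp = requests.post("http://localhost:7091/api/create_agent", json=payload)
print(resp.status_code)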
@@ -25,7 +25,7 @@ class StoreAttachment(Resource):
        api.model(
            "AttachmentModel",
            {
                "file": fields.Raw(required=True, description="File to upload"),
                "file": fields.Raw(required=True, description="File(s) to upload"),
                "api_key": fields.String(
                    required=False, description="API key (optional)"
                ),

@@ -33,18 +33,24 @@ class StoreAttachment(Resource):
        )
    )
    @api.doc(
        description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
        description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
    )
    def post(self):
        decoded_token = getattr(request, "decoded_token", None)
        api_key = request.form.get("api_key") or request.args.get("api_key")
        file = request.files.get("file")

        if not file or file.filename == "":

        files = request.files.getlist("file")
        if not files:
            single_file = request.files.get("file")
            if single_file:
                files = [single_file]

        if not files or all(f.filename == "" for f in files):
            return make_response(
                jsonify({"status": "error", "message": "Missing file"}),
                jsonify({"status": "error", "message": "Missing file(s)"}),
                400,
            )

        user = None
        if decoded_token:
            user = safe_filename(decoded_token.get("sub"))

@@ -59,32 +65,74 @@ class StoreAttachment(Resource):
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        try:
            attachment_id = ObjectId()
            original_filename = safe_filename(os.path.basename(file.filename))
            relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
            tasks = []
            errors = []
            original_file_count = len(files)

            for idx, file in enumerate(files):
                try:
                    attachment_id = ObjectId()
                    original_filename = safe_filename(os.path.basename(file.filename))
                    relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"

            metadata = storage.save_file(file, relative_path)

            file_info = {
                "filename": original_filename,
                "attachment_id": str(attachment_id),
                "path": relative_path,
                "metadata": metadata,
            }

            task = store_attachment.delay(file_info, user)

            return make_response(
                jsonify(
                    {
                        "success": True,
                        "task_id": task.id,
                        "message": "File uploaded successfully. Processing started.",
                    metadata = storage.save_file(file, relative_path)
                    file_info = {
                        "filename": original_filename,
                        "attachment_id": str(attachment_id),
                        "path": relative_path,
                        "metadata": metadata,
                    }
                ),
                200,
            )

                    task = store_attachment.delay(file_info, user)
                    tasks.append({
                        "task_id": task.id,
                        "filename": original_filename,
                        "attachment_id": str(attachment_id),
                    })
                except Exception as file_err:
                    current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
                    errors.append({
                        "filename": file.filename,
                        "error": str(file_err)
                    })

            if not tasks:
                error_msg = "No valid files to upload"
                if errors:
                    error_msg += f". Errors: {errors}"
                return make_response(
                    jsonify({"status": "error", "message": error_msg, "errors": errors}),
                    400,
                )

            if original_file_count == 1 and len(tasks) == 1:
                current_app.logger.info("Returning single task_id response")
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "task_id": tasks[0]["task_id"],
                            "message": "File uploaded successfully. Processing started.",
                        }
                    ),
                    200,
                )
            else:
                response_data = {
                    "success": True,
                    "tasks": tasks,
                    "message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
                }
                if errors:
                    response_data["errors"] = errors
                    response_data["message"] += f" {len(errors)} file(s) failed."

                return make_response(
                    jsonify(response_data),
                    200,
                )
        except Exception as err:
            current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
            return make_response(jsonify({"success": False, "error": str(err)}), 400)

@@ -130,15 +178,11 @@ class TextToSpeech(Resource):
    @api.expect(tts_model)
    @api.doc(description="Synthesize audio speech from text")
    def post(self):
        from application.utils import clean_text_for_tts

        data = request.get_json()
        text = data["text"]
        cleaned_text = clean_text_for_tts(text)

        try:
            tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
            audio_base64, detected_language = tts_instance.text_to_speech(cleaned_text)
            audio_base64, detected_language = tts_instance.text_to_speech(text)
            return make_response(
                jsonify(
                    {
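A hypothetical client-side upload of several files in one multipart request against the handler above; the /api/store_attachment path is an assumption for the sketch:

import requests

files = [
    ("file", ("a.pdf", open("a.pdf", "rb"), "application/pdf")),
    ("file", ("b.png", open("b.png", "rb"), "image/png")),
]
resp = requests.post("http://localhost:7091/api/store_attachment", files=files)
# Multi-file requests get a "tasks" list; a single file keeps the old "task_id" shape.
print(resp.json())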
3 application/api/user/models/__init__.py Normal file
@@ -0,0 +1,3 @@
from .routes import models_ns

__all__ = ["models_ns"]
25 application/api/user/models/routes.py Normal file
@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource

from application.core.model_settings import ModelRegistry

models_ns = Namespace("models", description="Available models", path="/api")


@models_ns.route("/models")
class ModelsListResource(Resource):
    def get(self):
        """Get list of available models with their capabilities."""
        try:
            registry = ModelRegistry.get_instance()
            models = registry.get_enabled_models()

            response = {
                "models": [model.to_dict() for model in models],
                "default_model_id": registry.default_model_id,
                "count": len(models),
            }
        except Exception as err:
            current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
        return make_response(jsonify(response), 200)
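A hedged usage sketch against a locally running backend; the host is an assumption, and the response shape mirrors the handler above:

import requests

resp = requests.get("http://localhost:7091/api/models")
data = resp.json()
print(data["default_model_id"], data["count"])
for m in data["models"]:
    print(m["id"], m["provider"], m["context_window"])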
@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns

@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)

# Models
api.add_namespace(models_ns)

# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
@@ -13,7 +13,6 @@ from application.api.user.base import (
    agents_collection,
    attachments_collection,
    conversations_collection,
    db,
    shared_conversations_collections,
)
from application.utils import check_required_fields

@@ -97,9 +96,7 @@ class ShareConversation(Resource):
            api_uuid = pre_existing_api_document["key"]
            pre_existing = shared_conversations_collections.find_one(
                {
                    "conversation_id": DBRef(
                        "conversations", ObjectId(conversation_id)
                    ),
                    "conversation_id": ObjectId(conversation_id),
                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,

@@ -120,10 +117,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
                        "conversation_id": {
                            "$ref": "conversations",
                            "$id": ObjectId(conversation_id),
                        },
                        "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,

@@ -154,10 +148,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
                        "conversation_id": {
                            "$ref": "conversations",
                            "$id": ObjectId(conversation_id),
                        },
                        "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,

@@ -175,9 +166,7 @@ class ShareConversation(Resource):
            )
            pre_existing = shared_conversations_collections.find_one(
                {
                    "conversation_id": DBRef(
                        "conversations", ObjectId(conversation_id)
                    ),
                    "conversation_id": ObjectId(conversation_id),
                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,

@@ -197,10 +186,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
                        "conversation_id": {
                            "$ref": "conversations",
                            "$id": ObjectId(conversation_id),
                        },
                        "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,

@@ -233,10 +219,12 @@ class GetPubliclySharedConversations(Resource):
            if (
                shared
                and "conversation_id" in shared
                and isinstance(shared["conversation_id"], DBRef)
            ):
                conversation_ref = shared["conversation_id"]
                conversation = db.dereference(conversation_ref)
                # conversation_id is now stored as an ObjectId, not a DBRef
                conversation_id = shared["conversation_id"]
                conversation = conversations_collection.find_one(
                    {"_id": conversation_id}
                )
            if conversation is None:
                return make_response(
                    jsonify(
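A minimal pymongo sketch of the migration above: storing a plain ObjectId and looking the conversation up directly instead of dereferencing a DBRef. The "shared_conversations" collection name is an assumption (the diff only shows the shared_conversations_collections handle):

from bson.objectid import ObjectId
from pymongo import MongoClient

db = MongoClient()["docsgpt"]
conv_id = ObjectId()
db["shared_conversations"].insert_one({"conversation_id": conv_id})

# Direct lookup replaces db.dereference(DBRef("conversations", conv_id)).
shared = db["shared_conversations"].find_one({"conversation_id": conv_id})
conversation = db["conversations"].find_one({"_id": shared["conversation_id"]})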
189 application/core/model_configs.py Normal file
@@ -0,0 +1,189 @@
"""
Model configurations for all supported LLM providers.
"""

from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
)

OPENAI_ATTACHMENTS = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/gif",
]

GOOGLE_ATTACHMENTS = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/gif",
]


OPENAI_MODELS = [
    AvailableModel(
        id="gpt-5.1",
        provider=ModelProvider.OPENAI,
        display_name="GPT-5.1",
        description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="gpt-5-mini",
        provider=ModelProvider.OPENAI,
        display_name="GPT-5 Mini",
        description="Faster, cost-effective variant of GPT-5.1",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=200000,
        ),
    )
]


ANTHROPIC_MODELS = [
    AvailableModel(
        id="claude-3-5-sonnet-20241022",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet (Latest)",
        description="Latest Claude 3.5 Sonnet with enhanced capabilities",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-5-sonnet",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet",
        description="Balanced performance and capability",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-opus",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Opus",
        description="Most capable Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-haiku",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Haiku",
        description="Fastest Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=200000,
        ),
    ),
]


GOOGLE_MODELS = [
    AvailableModel(
        id="gemini-flash-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash (Latest)",
        description="Latest experimental Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-flash-lite-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash Lite (Latest)",
        description="Fast with huge context window",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-3-pro-preview",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini 3 Pro",
        description="Most capable Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=2000000,
        ),
    ),
]


GROQ_MODELS = [
    AvailableModel(
        id="llama-3.3-70b-versatile",
        provider=ModelProvider.GROQ,
        display_name="Llama 3.3 70B",
        description="Latest Llama model with high-speed inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="llama-3.1-8b-instant",
        provider=ModelProvider.GROQ,
        display_name="Llama 3.1 8B",
        description="Ultra-fast inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="mixtral-8x7b-32768",
        provider=ModelProvider.GROQ,
        display_name="Mixtral 8x7B",
        description="High-speed inference with tools",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=32768,
        ),
    ),
]


AZURE_OPENAI_MODELS = [
    AvailableModel(
        id="azure-gpt-4",
        provider=ModelProvider.AZURE_OPENAI,
        display_name="Azure OpenAI GPT-4",
        description="Azure-hosted GPT model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=8192,
        ),
    ),
]
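A hedged sketch of filtering these static lists by capability, e.g. to pick the first configured model that supports structured output:

from application.core.model_configs import GOOGLE_MODELS, OPENAI_MODELS

candidates = OPENAI_MODELS + GOOGLE_MODELS
structured = [m for m in candidates if m.capabilities.supports_structured_output]
print(structured[0].id if structured else "none")   # e.g. "gpt-5.1"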
236 application/core/model_settings.py Normal file
@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


class ModelProvider(str, Enum):
    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"
    GOOGLE = "google"
    HUGGINGFACE = "huggingface"
    LLAMA_CPP = "llama.cpp"
    DOCSGPT = "docsgpt"
    PREMAI = "premai"
    SAGEMAKER = "sagemaker"
    NOVITA = "novita"


@dataclass
class ModelCapabilities:
    supports_tools: bool = False
    supports_structured_output: bool = False
    supports_streaming: bool = True
    supported_attachment_types: List[str] = field(default_factory=list)
    context_window: int = 128000
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None


@dataclass
class AvailableModel:
    id: str
    provider: ModelProvider
    display_name: str
    description: str = ""
    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
    enabled: bool = True
    base_url: Optional[str] = None

    def to_dict(self) -> Dict:
        result = {
            "id": self.id,
            "provider": self.provider.value,
            "display_name": self.display_name,
            "description": self.description,
            "supported_attachment_types": self.capabilities.supported_attachment_types,
            "supports_tools": self.capabilities.supports_tools,
            "supports_structured_output": self.capabilities.supports_structured_output,
            "supports_streaming": self.capabilities.supports_streaming,
            "context_window": self.capabilities.context_window,
            "enabled": self.enabled,
        }
        if self.base_url:
            result["base_url"] = self.base_url
        return result


class ModelRegistry:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelRegistry._initialized:
            self.models: Dict[str, AvailableModel] = {}
            self.default_model_id: Optional[str] = None
            self._load_models()
            ModelRegistry._initialized = True

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        return cls()

    def _load_models(self):
        from application.core.settings import settings

        self.models.clear()

        self._add_docsgpt_models(settings)
        if settings.OPENAI_API_KEY or (
            settings.LLM_PROVIDER == "openai" and settings.API_KEY
        ):
            self._add_openai_models(settings)
        if settings.OPENAI_API_BASE or (
            settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
        ):
            self._add_azure_openai_models(settings)
        if settings.ANTHROPIC_API_KEY or (
            settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
        ):
            self._add_anthropic_models(settings)
        if settings.GOOGLE_API_KEY or (
            settings.LLM_PROVIDER == "google" and settings.API_KEY
        ):
            self._add_google_models(settings)
        if settings.GROQ_API_KEY or (
            settings.LLM_PROVIDER == "groq" and settings.API_KEY
        ):
            self._add_groq_models(settings)
        if settings.HUGGINGFACE_API_KEY or (
            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
        ):
            self._add_huggingface_models(settings)

        # Default model selection
        if settings.LLM_NAME and settings.LLM_NAME in self.models:
            self.default_model_id = settings.LLM_NAME
        elif settings.LLM_PROVIDER and settings.API_KEY:
            for model_id, model in self.models.items():
                if model.provider.value == settings.LLM_PROVIDER:
                    self.default_model_id = model_id
                    break
        else:
            self.default_model_id = next(iter(self.models.keys()))
        logger.info(
            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
        )

    def _add_openai_models(self, settings):
        from application.core.model_configs import OPENAI_MODELS

        if settings.OPENAI_API_KEY:
            for model in OPENAI_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
            for model in OPENAI_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in OPENAI_MODELS:
            self.models[model.id] = model

    def _add_azure_openai_models(self, settings):
        from application.core.model_configs import AZURE_OPENAI_MODELS

        if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
            for model in AZURE_OPENAI_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in AZURE_OPENAI_MODELS:
            self.models[model.id] = model

    def _add_anthropic_models(self, settings):
        from application.core.model_configs import ANTHROPIC_MODELS

        if settings.ANTHROPIC_API_KEY:
            for model in ANTHROPIC_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
            for model in ANTHROPIC_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in ANTHROPIC_MODELS:
            self.models[model.id] = model

    def _add_google_models(self, settings):
        from application.core.model_configs import GOOGLE_MODELS

        if settings.GOOGLE_API_KEY:
            for model in GOOGLE_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
            for model in GOOGLE_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in GOOGLE_MODELS:
            self.models[model.id] = model

    def _add_groq_models(self, settings):
        from application.core.model_configs import GROQ_MODELS

        if settings.GROQ_API_KEY:
            for model in GROQ_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
            for model in GROQ_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in GROQ_MODELS:
            self.models[model.id] = model

    def _add_docsgpt_models(self, settings):
        model_id = "docsgpt-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.DOCSGPT,
            display_name="DocsGPT Model",
            description="Local model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def _add_huggingface_models(self, settings):
        model_id = "huggingface-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.HUGGINGFACE,
            display_name="Hugging Face Model",
            description="Local Hugging Face model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def get_model(self, model_id: str) -> Optional[AvailableModel]:
        return self.models.get(model_id)

    def get_all_models(self) -> List[AvailableModel]:
        return list(self.models.values())

    def get_enabled_models(self) -> List[AvailableModel]:
        return [m for m in self.models.values() if m.enabled]

    def model_exists(self, model_id: str) -> bool:
        return model_id in self.models
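A hedged usage sketch of the singleton above; __new__/__init__ guard the instance, so repeated get_instance() calls return the same registry:

from application.core.model_settings import ModelRegistry

registry = ModelRegistry.get_instance()
assert registry is ModelRegistry.get_instance()   # same instance both times
for model in registry.get_enabled_models():
    print(model.id, model.provider.value, model.capabilities.context_window)
print("default:", registry.default_model_id)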
91 application/core/model_utils.py Normal file
@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional

from application.core.model_settings import ModelRegistry


def get_api_key_for_provider(provider: str) -> Optional[str]:
    """Get the appropriate API key for a provider"""
    from application.core.settings import settings

    provider_key_map = {
        "openai": settings.OPENAI_API_KEY,
        "anthropic": settings.ANTHROPIC_API_KEY,
        "google": settings.GOOGLE_API_KEY,
        "groq": settings.GROQ_API_KEY,
        "huggingface": settings.HUGGINGFACE_API_KEY,
        "azure_openai": settings.API_KEY,
        "docsgpt": None,
        "llama.cpp": None,
    }

    provider_key = provider_key_map.get(provider)
    if provider_key:
        return provider_key
    return settings.API_KEY


def get_all_available_models() -> Dict[str, Dict[str, Any]]:
    """Get all available models with metadata for API response"""
    registry = ModelRegistry.get_instance()
    return {model.id: model.to_dict() for model in registry.get_enabled_models()}


def validate_model_id(model_id: str) -> bool:
    """Check if a model ID exists in registry"""
    registry = ModelRegistry.get_instance()
    return registry.model_exists(model_id)


def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
    """Get capabilities for a specific model"""
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    if model:
        return {
            "supported_attachment_types": model.capabilities.supported_attachment_types,
            "supports_tools": model.capabilities.supports_tools,
            "supports_structured_output": model.capabilities.supports_structured_output,
            "context_window": model.capabilities.context_window,
        }
    return None


def get_default_model_id() -> str:
    """Get the system default model ID"""
    registry = ModelRegistry.get_instance()
    return registry.default_model_id


def get_provider_from_model_id(model_id: str) -> Optional[str]:
    """Get the provider name for a given model_id"""
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    if model:
        return model.provider.value
    return None


def get_token_limit(model_id: str) -> int:
    """
    Get context window (token limit) for a model.
    Returns model's context_window or default 128000 if model not found.
    """
    from application.core.settings import settings

    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    if model:
        return model.capabilities.context_window
    return settings.DEFAULT_LLM_TOKEN_LIMIT


def get_base_url_for_model(model_id: str) -> Optional[str]:
    """
    Get the custom base_url for a specific model if configured.
    Returns None if no custom base_url is set.
    """
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    if model:
        return model.base_url
    return None
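A hedged sketch tying the helpers above together for one request (the model id is an example from model_configs.py; the printed values depend on which providers are configured):

from application.core.model_utils import (
    get_api_key_for_provider,
    get_provider_from_model_id,
    get_token_limit,
)

model_id = "gpt-5.1"
provider = get_provider_from_model_id(model_id)   # "openai" if the model is registered
api_key = get_api_key_for_provider(provider)      # OPENAI_API_KEY, else API_KEY fallback
limit = get_token_limit(model_id)                 # 200000 per the config above
print(provider, limit)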
@@ -22,15 +22,7 @@ class Settings(BaseSettings):
    MONGO_DB_NAME: str = "docsgpt"
    LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
    DEFAULT_MAX_HISTORY: int = 150
    LLM_TOKEN_LIMITS: dict = {
        "gpt-4o": 128000,
        "gpt-4o-mini": 128000,
        "gpt-4": 8192,
        "gpt-3.5-turbo": 4096,
        "claude-2": int(1e5),
        "gemini-2.5-flash": int(1e6),
    }
    DEFAULT_LLM_TOKEN_LIMIT: int = 128000
    DEFAULT_LLM_TOKEN_LIMIT: int = 128000  # Fallback when model not found in registry
    RESERVED_TOKENS: dict = {
        "system_prompt": 500,
        "current_query": 500,

@@ -64,14 +56,22 @@ class Settings(BaseSettings):
    )

    # GitHub source
    GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access
    GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

    # LLM Cache
    CACHE_REDIS_URL: str = "redis://localhost:6379/2"

    API_URL: str = "http://localhost:7091"  # backend url for celery worker

    API_KEY: Optional[str] = None  # LLM api key
    API_KEY: Optional[str] = None  # LLM api key (used by LLM_PROVIDER)

    # Provider-specific API keys (for multi-model support)
    OPENAI_API_KEY: Optional[str] = None
    ANTHROPIC_API_KEY: Optional[str] = None
    GOOGLE_API_KEY: Optional[str] = None
    GROQ_API_KEY: Optional[str] = None
    HUGGINGFACE_API_KEY: Optional[str] = None

    EMBEDDINGS_KEY: Optional[str] = (
        None  # api key for embeddings (if using openai, just copy API_KEY)
    )

@@ -138,11 +138,19 @@ class Settings(BaseSettings):
    # Encryption settings
    ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"

    TTS_PROVIDER: str = "google_tts"  # google_tts or elevenlabs
    TTS_PROVIDER: str = "google_tts"  # google_tts or elevenlabs
    ELEVENLABS_API_KEY: Optional[str] = None

    # Tool pre-fetch settings
    ENABLE_TOOL_PREFETCH: bool = True

    # Conversation Compression Settings
    ENABLE_CONVERSATION_COMPRESSION: bool = True
    COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8  # Trigger at 80% of context
    COMPRESSION_MODEL_OVERRIDE: Optional[str] = None  # Use different model for compression
    COMPRESSION_PROMPT_VERSION: str = "v1.0"  # Track prompt iterations
    COMPRESSION_MAX_HISTORY_POINTS: int = 3  # Keep only last N compression points to prevent DB bloat


path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
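Because these are pydantic BaseSettings fields, each one can be overridden from the environment (or .env) without code changes. A hedged sketch with example values:

import os

os.environ["ENABLE_CONVERSATION_COMPRESSION"] = "true"
os.environ["COMPRESSION_THRESHOLD_PERCENTAGE"] = "0.7"   # trigger earlier, at 70%
os.environ["COMPRESSION_MAX_HISTORY_POINTS"] = "5"

from application.core.settings import Settings
print(Settings().COMPRESSION_THRESHOLD_PERCENTAGE)   # 0.7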
@@ -1,30 +1,41 @@
from application.llm.base import BaseLLM
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT

from application.core.settings import settings
from application.llm.base import BaseLLM


class AnthropicLLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.api_key = (
            api_key or settings.ANTHROPIC_API_KEY
        )  # If not provided, use a default from settings
        self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
        self.user_api_key = user_api_key
        self.anthropic = Anthropic(api_key=self.api_key)

        # Use custom base_url if provided
        if base_url:
            self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
        else:
            self.anthropic = Anthropic(api_key=self.api_key)

        self.HUMAN_PROMPT = HUMAN_PROMPT
        self.AI_PROMPT = AI_PROMPT

    def _raw_gen(
        self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
        self,
        baseself,
        model,
        messages,
        stream=False,
        tools=None,
        max_tokens=300,
        **kwargs,
    ):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        if stream:
            return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)

        completion = self.anthropic.completions.create(
            model=model,
            max_tokens_to_sample=max_tokens,

@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
        return completion.completion

    def _raw_gen_stream(
        self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
        self,
        baseself,
        model,
        messages,
        stream=True,
        tools=None,
        max_tokens=300,
        **kwargs,
    ):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]

@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
            for completion in stream_response:
                yield completion.completion
        finally:
            if hasattr(stream_response, 'close'):
            if hasattr(stream_response, "close"):
                stream_response.close()
@@ -13,30 +13,32 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
|
||||
self.fallback_model_name = settings.FALLBACK_LLM_NAME
|
||||
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
|
||||
self._fallback_llm = None
|
||||
self._fallback_sequence_index = 0
|
||||
|
||||
@property
|
||||
def fallback_llm(self):
|
||||
"""Lazy-loaded fallback LLM instance."""
|
||||
if (
|
||||
self._fallback_llm is None
|
||||
and self.fallback_provider
|
||||
and self.fallback_model_name
|
||||
):
|
||||
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
|
||||
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
|
||||
try:
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
self.fallback_provider,
|
||||
self.fallback_llm_api_key,
|
||||
None,
|
||||
self.decoded_token,
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Unified method execution with fallback support.
|
||||
Execute method with fallback support.
|
||||
|
||||
Args:
|
||||
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
|
||||
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
raise
|
||||
logger.warning(
|
||||
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
|
||||
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
|
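For reference, the fallback path only activates when the FALLBACK_* settings are present. A minimal sketch of what that configuration might look like; provider, model, and key are illustrative, not repo defaults:

# Illustrative configuration; all values are placeholders.
from application.core.settings import settings

settings.FALLBACK_LLM_PROVIDER = "openai"         # key looked up in LLMCreator.llms
settings.FALLBACK_LLM_NAME = "gpt-4o-mini"        # becomes model_id of the fallback client
settings.FALLBACK_LLM_API_KEY = "sk-placeholder"  # falls back to settings.API_KEY when unset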
@@ -1,5 +1,7 @@
import json

from openai import OpenAI

from application.core.settings import settings
from application.llm.base import BaseLLM

@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
class DocsGPTAPILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        from openai import OpenAI

        super().__init__(*args, **kwargs)
        self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
        self.api_key = "sk-docsgpt-public"
        self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
        self.user_api_key = user_api_key
        self.api_key = api_key

    def _clean_messages_openai(self, messages):
        cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):

            if role == "model":
                role = "assistant"

            if role and content is not None:
                if isinstance(content, str):
                    cleaned_messages.append({"role": role, "content": content})
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
                    )
                else:
                    raise ValueError(f"Unexpected content type: {type(content)}")

        return cleaned_messages

    def _raw_gen(
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
        response = self.client.chat.completions.create(
            model="docsgpt", messages=messages, stream=stream, **kwargs
        )

        try:
            for line in response:
                if (
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
                elif len(line.choices) > 0:
                    yield line.choices[0]
        finally:
            if hasattr(response, 'close'):
            if hasattr(response, "close"):
                response.close()

    def _supports_tools(self):
        return True
        return True
@@ -1,4 +1,3 @@
import json
import logging

from google import genai
@@ -11,10 +10,13 @@ from application.storage.storage_creator import StorageCreator


class GoogleLLM(BaseLLM):
    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
    def __init__(
        self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
        self.user_api_key = user_api_key

        self.client = genai.Client(api_key=self.api_key)
        self.storage = StorageCreator.get_storage()

@@ -32,6 +34,12 @@ class GoogleLLM(BaseLLM):
        "image/jpg",
        "image/webp",
        "image/gif",
        "application/pdf",
        "image/png",
        "image/jpeg",
        "image/jpg",
        "image/webp",
        "image/gif",
    ]

    def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -47,21 +55,19 @@ class GoogleLLM(BaseLLM):
        """
        if not attachments:
            return messages

        prepared_messages = messages.copy()

        # Find the user message to attach files to the last one

        user_message_index = None
        for i in range(len(prepared_messages) - 1, -1, -1):
            if prepared_messages[i].get("role") == "user":
                user_message_index = i
                break

        if user_message_index is None:
            user_message = {"role": "user", "content": []}
            prepared_messages.append(user_message)
            user_message_index = len(prepared_messages) - 1

        if isinstance(prepared_messages[user_message_index].get("content"), str):
            text_content = prepared_messages[user_message_index]["content"]
            prepared_messages[user_message_index]["content"] = [
@@ -69,7 +75,6 @@ class GoogleLLM(BaseLLM):
            ]
        elif not isinstance(prepared_messages[user_message_index].get("content"), list):
            prepared_messages[user_message_index]["content"] = []

        files = []
        for attachment in attachments:
            mime_type = attachment.get("mime_type")
@@ -92,11 +97,9 @@ class GoogleLLM(BaseLLM):
                        "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
                    }
                )

        if files:
            logging.info(f"GoogleLLM: Adding {len(files)} files to message")
            prepared_messages[user_message_index]["content"].append({"files": files})

        return prepared_messages

    def _upload_file_to_google(self, attachment):
@@ -111,14 +114,11 @@ class GoogleLLM(BaseLLM):
        """
        if "google_file_uri" in attachment:
            return attachment["google_file_uri"]

        file_path = attachment.get("path")
        if not file_path:
            raise ValueError("No file path provided in attachment")

        if not self.storage.file_exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            file_uri = self.storage.process_file(
                file_path,
@@ -136,24 +136,48 @@ class GoogleLLM(BaseLLM):
            attachments_collection.update_one(
                {"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
            )

            return file_uri
        except Exception as e:
            logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
            raise

    def _clean_messages_google(self, messages):
        """Convert OpenAI format messages to Google AI format."""
        """
        Convert OpenAI format messages to Google AI format and collect system prompts.

        Returns:
            tuple[list[types.Content], Optional[str]]: cleaned messages and optional
            combined system instruction.
        """
        cleaned_messages = []
        system_instructions = []

        def _extract_system_text(content):
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                parts = []
                for item in content:
                    if isinstance(item, dict) and "text" in item and item["text"] is not None:
                        parts.append(item["text"])
                return "\n".join(parts)
            return ""

        for message in messages:
            role = message.get("role")
            content = message.get("content")

            # Gemini only accepts user/model in the contents list.
            if role == "system":
                sys_text = _extract_system_text(content)
                if sys_text:
                    system_instructions.append(sys_text)
                continue

            if role == "assistant":
                role = "model"
            elif role == "tool":
                role = "model"

            parts = []
            if role and content is not None:
                if isinstance(content, str):
@@ -164,15 +188,31 @@ class GoogleLLM(BaseLLM):
                        parts.append(types.Part.from_text(text=item["text"]))
                    elif "function_call" in item:
                        # Remove null values from args to avoid API errors

                        cleaned_args = self._remove_null_values(
                            item["function_call"]["args"]
                        )
                        parts.append(
                            types.Part.from_function_call(
                                name=item["function_call"]["name"],
                                args=cleaned_args,
                        # Create function call part with thought_signature if present
                        # For Gemini 3 models, we need to include thought_signature
                        if "thought_signature" in item:
                            # Use Part constructor with functionCall and thoughtSignature
                            parts.append(
                                types.Part(
                                    functionCall=types.FunctionCall(
                                        name=item["function_call"]["name"],
                                        args=cleaned_args,
                                    ),
                                    thoughtSignature=item["thought_signature"],
                                )
                            )
                        else:
                            # Use helper method when no thought_signature
                            parts.append(
                                types.Part.from_function_call(
                                    name=item["function_call"]["name"],
                                    args=cleaned_args,
                                )
                            )
                        )
                    elif "function_response" in item:
                        parts.append(
                            types.Part.from_function_response(
@@ -194,11 +234,10 @@ class GoogleLLM(BaseLLM):
                        )
                else:
                    raise ValueError(f"Unexpected content type: {type(content)}")

            if parts:
                cleaned_messages.append(types.Content(role=role, parts=parts))

        return cleaned_messages
        system_instruction = "\n\n".join(system_instructions) if system_instructions else None
        return cleaned_messages, system_instruction
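A small sketch of the new return shape, using an invented message list for illustration:

# Hypothetical input; _clean_messages_google now returns a (contents, system) pair.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
contents, system_instruction = llm._clean_messages_google(messages)
# contents           -> [Content(role="user", ...), Content(role="model", ...)]
# system_instruction -> "You are terse."  (multiple system turns are joined with "\n\n")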
    def _clean_schema(self, schema_obj):
        """
@@ -233,8 +272,8 @@ class GoogleLLM(BaseLLM):
                cleaned[key] = [self._clean_schema(item) for item in value]
            else:
                cleaned[key] = value

        # Validate that required properties actually exist in properties

        if "required" in cleaned and "properties" in cleaned:
            valid_required = []
            properties_keys = set(cleaned["properties"].keys())
@@ -247,7 +286,6 @@ class GoogleLLM(BaseLLM):
                cleaned.pop("required", None)
        elif "required" in cleaned and "properties" not in cleaned:
            cleaned.pop("required", None)

        return cleaned

    def _clean_tools_format(self, tools_list):
@@ -263,7 +301,6 @@ class GoogleLLM(BaseLLM):
            cleaned_properties = {}
            for k, v in properties.items():
                cleaned_properties[k] = self._clean_schema(v)

            genai_function = dict(
                name=function["name"],
                description=function["description"],
@@ -282,12 +319,65 @@ class GoogleLLM(BaseLLM):
                    name=function["name"],
                    description=function["description"],
                )

            genai_tool = types.Tool(function_declarations=[genai_function])
            genai_tools.append(genai_tool)

        return genai_tools

    def _extract_preview_from_message(self, message):
        """Get a short, human-readable preview from the last message."""
        try:
            if hasattr(message, "parts"):
                for part in reversed(message.parts):
                    if getattr(part, "text", None):
                        return part.text
                    function_call = getattr(part, "function_call", None)
                    if function_call:
                        name = getattr(function_call, "name", "") or "function_call"
                        return f"function_call:{name}"
                    function_response = getattr(part, "function_response", None)
                    if function_response:
                        name = getattr(function_response, "name", "") or "function_response"
                        return f"function_response:{name}"
            if isinstance(message, dict):
                content = message.get("content")
                if isinstance(content, str):
                    return content
                if isinstance(content, list):
                    for item in reversed(content):
                        if isinstance(item, str):
                            return item
                        if isinstance(item, dict):
                            if item.get("text"):
                                return item["text"]
                            if item.get("function_call"):
                                fn = item["function_call"]
                                if isinstance(fn, dict):
                                    name = fn.get("name") or "function_call"
                                    return f"function_call:{name}"
                                return "function_call"
                            if item.get("function_response"):
                                resp = item["function_response"]
                                if isinstance(resp, dict):
                                    name = resp.get("name") or "function_response"
                                    return f"function_response:{name}"
                                return "function_response"
                if "text" in message and isinstance(message["text"], str):
                    return message["text"]
        except Exception:
            pass
        return str(message)

    def _summarize_messages_for_log(self, messages, preview_chars=20):
        """Return a compact summary for logging to avoid huge payloads."""
        message_count = len(messages) if messages else 0
        last_preview = ""
        if messages:
            last_preview = self._extract_preview_from_message(messages[-1]) or ""
            last_preview = str(last_preview).replace("\n", " ")
            if len(last_preview) > preview_chars:
                last_preview = f"{last_preview[:preview_chars]}..."
        return f"count={message_count}, last='{last_preview}'"

    def _raw_gen(
        self,
        baseself,
@@ -301,22 +391,20 @@ class GoogleLLM(BaseLLM):
    ):
        """Generate content using Google AI API without streaming."""
        client = genai.Client(api_key=self.api_key)
        system_instruction = None
        if formatting == "openai":
            messages = self._clean_messages_google(messages)
            messages, system_instruction = self._clean_messages_google(messages)
        config = types.GenerateContentConfig()
        if messages[0].role == "system":
            config.system_instruction = messages[0].parts[0].text
            messages = messages[1:]

        if system_instruction:
            config.system_instruction = system_instruction
        if tools:
            cleaned_tools = self._clean_tools_format(tools)
            config.tools = cleaned_tools

        # Add response schema for structured output if provided

        if response_schema:
            config.response_schema = response_schema
            config.response_mime_type = "application/json"

        response = client.models.generate_content(
            model=model,
            contents=messages,
@@ -341,23 +429,22 @@ class GoogleLLM(BaseLLM):
    ):
        """Generate content using Google AI API with streaming."""
        client = genai.Client(api_key=self.api_key)
        system_instruction = None
        if formatting == "openai":
            messages = self._clean_messages_google(messages)
            messages, system_instruction = self._clean_messages_google(messages)
        config = types.GenerateContentConfig()
        if messages[0].role == "system":
            config.system_instruction = messages[0].parts[0].text
            messages = messages[1:]

        if system_instruction:
            config.system_instruction = system_instruction
        if tools:
            cleaned_tools = self._clean_tools_format(tools)
            config.tools = cleaned_tools

        # Add response schema for structured output if provided

        if response_schema:
            config.response_schema = response_schema
            config.response_mime_type = "application/json"

        # Check if we have both tools and file attachments

        has_attachments = False
        for message in messages:
            for part in message.parts:
@@ -366,9 +453,12 @@ class GoogleLLM(BaseLLM):
                    break
            if has_attachments:
                break

        messages_summary = self._summarize_messages_for_log(messages)
        logging.info(
            f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
            "GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
            model,
            messages_summary,
            has_attachments,
        )

        response = client.models.generate_content_stream(
@@ -405,7 +495,6 @@ class GoogleLLM(BaseLLM):
        """Convert JSON schema to Google AI structured output format."""
        if not json_schema:
            return None

        type_map = {
            "object": "OBJECT",
            "array": "ARRAY",
@@ -418,12 +507,10 @@ class GoogleLLM(BaseLLM):
        def convert(schema):
            if not isinstance(schema, dict):
                return schema

            result = {}
            schema_type = schema.get("type")
            if schema_type:
                result["type"] = type_map.get(schema_type.lower(), schema_type.upper())

            for key in [
                "description",
                "nullable",
@@ -435,7 +522,6 @@ class GoogleLLM(BaseLLM):
            ]:
                if key in schema:
                    result[key] = schema[key]

            if "format" in schema:
                format_value = schema["format"]
                if schema_type == "string":
@@ -445,21 +531,17 @@ class GoogleLLM(BaseLLM):
                        result["format"] = format_value
                else:
                    result["format"] = format_value

            if "properties" in schema:
                result["properties"] = {
                    k: convert(v) for k, v in schema["properties"].items()
                }
                if "propertyOrdering" not in result and result.get("type") == "OBJECT":
                    result["propertyOrdering"] = list(result["properties"].keys())

            if "items" in schema:
                result["items"] = convert(schema["items"])

            for field in ["anyOf", "oneOf", "allOf"]:
                if field in schema:
                    result[field] = [convert(s) for s in schema[field]]

            return result

        try:
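A worked instance of convert() above, assuming keys not shown in the truncated key list simply pass through; the input schema is invented:

# Hypothetical JSON schema fed to the converter.
json_schema = {
    "type": "object",
    "properties": {"title": {"type": "string"}, "year": {"type": "integer"}},
}
# Expected shape: types uppercased via type_map (or .upper() as a fallback),
# and an explicit key order recorded for OBJECT schemas:
# {"type": "OBJECT",
#  "properties": {"title": {"type": "STRING"}, "year": {"type": "INTEGER"}},
#  "propertyOrdering": ["title", "year"]}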
@@ -1,13 +1,18 @@
from application.llm.base import BaseLLM
from openai import OpenAI

from application.core.settings import settings
from application.llm.base import BaseLLM


class GroqLLM(BaseLLM):
    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
        self.api_key = api_key
        self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
        self.user_api_key = user_api_key
        self.client = OpenAI(
            api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
        )

    def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
        if tools:
@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
    name: str
    arguments: Union[str, Dict]
    index: Optional[int] = None
    thought_signature: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict) -> "ToolCall":
@@ -178,6 +180,406 @@ class LLMHandler(ABC):
            system_msg["content"] += f"\n\n{combined_text}"
        return prepared_messages

    def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
        """
        Build a minimal context: system prompt + latest user message only.
        Drops all tool/function messages to shrink context aggressively.
        """
        system_message = next((m for m in messages if m.get("role") == "system"), None)
        if not system_message:
            logger.warning("Cannot prune messages minimally: missing system message.")
            return None
        last_non_system = None
        for m in reversed(messages):
            if m.get("role") == "user":
                last_non_system = m
                break
            if not last_non_system and m.get("role") not in ("system", None):
                last_non_system = m
        if not last_non_system:
            logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
            return None
        logger.info("Pruning context to system + latest user/assistant message to proceed.")
        return [system_message, last_non_system]
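Concretely, the minimal pruning above collapses a long transcript to two messages; a sketch with invented messages:

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Summarize the repo."},
    {"role": "assistant", "content": [{"function_call": {"name": "search", "args": {}}}]},  # dropped
    {"role": "tool", "content": "very large tool output ..."},                              # dropped
    {"role": "user", "content": "Now list open issues."},
]
pruned = handler._prune_messages_minimal(messages)
# -> [system message, {"role": "user", "content": "Now list open issues."}]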

    def _extract_text_from_content(self, content: Any) -> str:
        """
        Convert message content (str or list of parts) to plain text for compression.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts_text = []
            for item in content:
                if isinstance(item, dict):
                    if "text" in item and item["text"] is not None:
                        parts_text.append(str(item["text"]))
                    elif "function_call" in item or "function_response" in item:
                        # Keep serialized function calls/responses so the compressor sees actions
                        parts_text.append(str(item))
                    elif "files" in item:
                        parts_text.append(str(item))
            return "\n".join(parts_text)
        return ""

    def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
        """
        Build a conversation-like dict from current messages so we can compress
        even when the conversation isn't persisted yet. Includes tool calls/results.
        """
        queries = []
        current_prompt = None
        current_tool_calls = {}

        def _commit_query(response_text: str):
            nonlocal current_prompt, current_tool_calls
            if current_prompt is None and not response_text:
                return
            tool_calls_list = list(current_tool_calls.values())
            queries.append(
                {
                    "prompt": current_prompt or "",
                    "response": response_text,
                    "tool_calls": tool_calls_list,
                }
            )
            current_prompt = None
            current_tool_calls = {}

        for message in messages:
            role = message.get("role")
            content = message.get("content")

            if role == "user":
                current_prompt = self._extract_text_from_content(content)

            elif role in {"assistant", "model"}:
                # If this assistant turn contains tool calls, collect them; otherwise commit a response.
                if isinstance(content, list):
                    for item in content:
                        if "function_call" in item:
                            fc = item["function_call"]
                            call_id = fc.get("call_id") or str(uuid.uuid4())
                            current_tool_calls[call_id] = {
                                "tool_name": "unknown_tool",
                                "action_name": fc.get("name"),
                                "arguments": fc.get("args"),
                                "result": None,
                                "status": "called",
                                "call_id": call_id,
                            }
                        elif "function_response" in item:
                            fr = item["function_response"]
                            call_id = fr.get("call_id") or str(uuid.uuid4())
                            current_tool_calls[call_id] = {
                                "tool_name": "unknown_tool",
                                "action_name": fr.get("name"),
                                "arguments": None,
                                "result": fr.get("response", {}).get("result"),
                                "status": "completed",
                                "call_id": call_id,
                            }
                    # No direct assistant text here; continue to next message
                    continue

                response_text = self._extract_text_from_content(content)
                _commit_query(response_text)

            elif role == "tool":
                # Attach tool outputs to the latest pending tool call if possible
                tool_text = self._extract_text_from_content(content)
                # Attempt to parse function_response style
                call_id = None
                if isinstance(content, list):
                    for item in content:
                        if "function_response" in item and item["function_response"].get("call_id"):
                            call_id = item["function_response"]["call_id"]
                            break
                if call_id and call_id in current_tool_calls:
                    current_tool_calls[call_id]["result"] = tool_text
                    current_tool_calls[call_id]["status"] = "completed"
                elif queries:
                    queries[-1].setdefault("tool_calls", []).append(
                        {
                            "tool_name": "unknown_tool",
                            "action_name": "unknown_action",
                            "arguments": {},
                            "result": tool_text,
                            "status": "completed",
                        }
                    )

        # If there's an unfinished prompt with tool_calls but no response yet, commit it
        if current_prompt is not None or current_tool_calls:
            _commit_query(response_text="")

        if not queries:
            return None

        return {
            "queries": queries,
            "compression_metadata": {
                "is_compressed": False,
                "compression_points": [],
            },
        }

    def _rebuild_messages_after_compression(
        self,
        messages: List[Dict],
        compressed_summary: Optional[str],
        recent_queries: List[Dict],
        include_current_execution: bool = False,
        include_tool_calls: bool = False,
    ) -> Optional[List[Dict]]:
        """
        Rebuild the message list after compression so tool execution can continue.

        Delegates to MessageBuilder for the actual reconstruction.
        """
        from application.api.answer.services.compression.message_builder import (
            MessageBuilder,
        )

        return MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary=compressed_summary,
            recent_queries=recent_queries,
            include_current_execution=include_current_execution,
            include_tool_calls=include_tool_calls,
        )

    def _perform_mid_execution_compression(
        self, agent, messages: List[Dict]
    ) -> tuple[bool, Optional[List[Dict]]]:
        """
        Perform compression during tool execution and rebuild messages.

        Uses the new orchestrator for simplified compression.

        Args:
            agent: The agent instance
            messages: Current conversation messages

        Returns:
            (success: bool, rebuilt_messages: Optional[List[Dict]])
        """
        try:
            from application.api.answer.services.compression import (
                CompressionOrchestrator,
            )
            from application.api.answer.services.conversation_service import (
                ConversationService,
            )

            conversation_service = ConversationService()
            orchestrator = CompressionOrchestrator(conversation_service)

            # Get conversation from database (may be None for new sessions)
            conversation = conversation_service.get_conversation(
                agent.conversation_id, agent.initial_user_id
            )

            if conversation:
                # Merge current in-flight messages (including tool calls)
                conversation_from_msgs = self._build_conversation_from_messages(messages)
                if conversation_from_msgs:
                    conversation = conversation_from_msgs
            else:
                logger.warning(
                    "Could not load conversation for compression; attempting in-memory compression"
                )
                return self._perform_in_memory_compression(agent, messages)

            # Use orchestrator to perform compression
            result = orchestrator.compress_mid_execution(
                conversation_id=agent.conversation_id,
                user_id=agent.initial_user_id,
                model_id=agent.model_id,
                decoded_token=getattr(agent, "decoded_token", {}),
                current_conversation=conversation,
            )

            if not result.success:
                logger.warning(f"Mid-execution compression failed: {result.error}")
                # Try minimal pruning as fallback
                pruned = self._prune_messages_minimal(messages)
                if pruned:
                    agent.context_limit_reached = False
                    agent.current_token_count = 0
                    return True, pruned
                return False, None

            if not result.compression_performed:
                logger.warning("Compression not performed")
                return False, None

            # Check if compression actually reduced tokens
            if result.metadata:
                if result.metadata.compressed_token_count >= result.metadata.original_token_count:
                    logger.warning(
                        "Compression did not reduce token count; falling back to minimal pruning"
                    )
                    pruned = self._prune_messages_minimal(messages)
                    if pruned:
                        agent.context_limit_reached = False
                        agent.current_token_count = 0
                        return True, pruned
                    return False, None

                logger.info(
                    f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
                    f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
                )

            # Also store the compression summary as a visible message
            if result.metadata:
                conversation_service.append_compression_message(
                    agent.conversation_id, result.metadata.to_dict()
                )

            # Update agent's compressed summary for downstream persistence
            agent.compressed_summary = result.compressed_summary
            agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
            agent.compression_saved = False

            # Reset the context limit flag so tools can continue
            agent.context_limit_reached = False
            agent.current_token_count = 0

            # Rebuild messages
            rebuilt_messages = self._rebuild_messages_after_compression(
                messages,
                result.compressed_summary,
                result.recent_queries,
                include_current_execution=False,
                include_tool_calls=False,
            )

            if rebuilt_messages is None:
                return False, None

            return True, rebuilt_messages

        except Exception as e:
            logger.error(
                f"Error performing mid-execution compression: {str(e)}", exc_info=True
            )
            return False, None

    def _perform_in_memory_compression(
        self, agent, messages: List[Dict]
    ) -> tuple[bool, Optional[List[Dict]]]:
        """
        Fallback compression path when the conversation is not yet persisted.

        Uses CompressionService directly without DB persistence.
        """
        try:
            from application.api.answer.services.compression.service import (
                CompressionService,
            )
            from application.core.model_utils import (
                get_api_key_for_provider,
                get_provider_from_model_id,
            )
            from application.core.settings import settings
            from application.llm.llm_creator import LLMCreator

            conversation = self._build_conversation_from_messages(messages)
            if not conversation:
                logger.warning(
                    "Cannot perform in-memory compression: no user/assistant turns found"
                )
                return False, None

            compression_model = (
                settings.COMPRESSION_MODEL_OVERRIDE
                if settings.COMPRESSION_MODEL_OVERRIDE
                else agent.model_id
            )
            provider = get_provider_from_model_id(compression_model)
            api_key = get_api_key_for_provider(provider)
            compression_llm = LLMCreator.create_llm(
                provider,
                api_key,
                getattr(agent, "user_api_key", None),
                getattr(agent, "decoded_token", None),
                model_id=compression_model,
            )

            # Create service without DB persistence capability
            compression_service = CompressionService(
                llm=compression_llm,
                model_id=compression_model,
                conversation_service=None,  # No DB updates for in-memory
            )

            queries_count = len(conversation.get("queries", []))
            compress_up_to = queries_count - 1

            if compress_up_to < 0 or queries_count == 0:
                logger.warning("Not enough queries to compress in-memory context")
                return False, None

            metadata = compression_service.compress_conversation(
                conversation,
                compress_up_to_index=compress_up_to,
            )

            # If compression doesn't reduce tokens, fall back to minimal pruning
            if (
                metadata.compressed_token_count
                >= metadata.original_token_count
            ):
                logger.warning(
                    "In-memory compression did not reduce token count; falling back to minimal pruning"
                )
                pruned = self._prune_messages_minimal(messages)
                if pruned:
                    agent.context_limit_reached = False
                    agent.current_token_count = 0
                    return True, pruned
                return False, None

            # Attach metadata to synthetic conversation
            conversation["compression_metadata"] = {
                "is_compressed": True,
                "compression_points": [metadata.to_dict()],
            }

            compressed_summary, recent_queries = (
                compression_service.get_compressed_context(conversation)
            )

            agent.compressed_summary = compressed_summary
            agent.compression_metadata = metadata.to_dict()
            agent.compression_saved = False
            agent.context_limit_reached = False
            agent.current_token_count = 0

            rebuilt_messages = self._rebuild_messages_after_compression(
                messages,
                compressed_summary,
                recent_queries,
                include_current_execution=False,
                include_tool_calls=False,
            )
            if rebuilt_messages is None:
                return False, None

            logger.info(
                f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
                f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
            )
            return True, rebuilt_messages

        except Exception as e:
            logger.error(
                f"Error performing in-memory compression: {str(e)}", exc_info=True
            )
            return False, None
    def handle_tool_calls(
        self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
    ) -> Generator:
@@ -195,7 +597,110 @@ class LLMHandler(ABC):
        """
        updated_messages = messages.copy()

        for call in tool_calls:
        for i, call in enumerate(tool_calls):
            # Check context limit before executing tool call
            if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
                # Context limit reached - attempt mid-execution compression
                compression_attempted = False
                compression_successful = False

                try:
                    from application.core.settings import settings
                    compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
                except Exception:
                    compression_enabled = False

                if compression_enabled:
                    compression_attempted = True
                    try:
                        logger.info(
                            f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
                            f"Attempting mid-execution compression..."
                        )

                        # Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
                        compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
                            agent, updated_messages
                        )

                        if compression_successful and rebuilt_messages is not None:
                            # Update the messages list with rebuilt compressed version
                            updated_messages = rebuilt_messages

                            # Yield compression success message
                            yield {
                                "type": "info",
                                "data": {
                                    "message": "Context window limit reached. Compressed conversation history to continue processing."
                                }
                            }

                            logger.info(
                                f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
                            )
                            # Proceed to execute the current tool call with the reduced context
                        else:
                            logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
                    except Exception as e:
                        logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
                        compression_attempted = True
                        compression_successful = False

                # If compression wasn't attempted or failed, skip remaining tools
                if not compression_successful:
                    if i == 0:
                        # Special case: limit reached before executing any tools
                        # This can happen when previous tool responses pushed context over limit
                        if compression_attempted:
                            logger.warning(
                                f"Context limit reached before executing any tools. "
                                f"Compression attempted but failed. "
                                f"Skipping all {len(tool_calls)} pending tool call(s). "
                                f"This typically occurs when previous tool responses contained large amounts of data."
                            )
                        else:
                            logger.warning(
                                f"Context limit reached before executing any tools. "
                                f"Skipping all {len(tool_calls)} pending tool call(s). "
                                f"This typically occurs when previous tool responses contained large amounts of data. "
                                f"Consider enabling compression or using a model with larger context window."
                            )
                    else:
                        # Normal case: executed some tools, now stopping
                        tool_word = "tool call" if i == 1 else "tool calls"
                        remaining = len(tool_calls) - i
                        remaining_word = "tool call" if remaining == 1 else "tool calls"
                        if compression_attempted:
                            logger.warning(
                                f"Context limit reached after executing {i} {tool_word}. "
                                f"Compression attempted but failed. "
                                f"Skipping remaining {remaining} {remaining_word}."
                            )
                        else:
                            logger.warning(
                                f"Context limit reached after executing {i} {tool_word}. "
                                f"Skipping remaining {remaining} {remaining_word}. "
                                f"Consider enabling compression or using a model with larger context window."
                            )

                    # Mark remaining tools as skipped
                    for remaining_call in tool_calls[i:]:
                        skip_message = {
                            "type": "tool_call",
                            "data": {
                                "tool_name": "system",
                                "call_id": remaining_call.id,
                                "action_name": remaining_call.name,
                                "arguments": {},
                                "result": "Skipped: Context limit reached. Too many tool calls in conversation.",
                                "status": "skipped"
                            }
                        }
                        yield skip_message

                    # Set flag on agent
                    agent.context_limit_reached = True
                    break
            try:
                self.tool_calls.append(call)
                tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,21 +710,26 @@ class LLMHandler(ABC):
                except StopIteration as e:
                    tool_response, call_id = e.value
                    break

                function_call_content = {
                    "function_call": {
                        "name": call.name,
                        "args": call.arguments,
                        "call_id": call_id,
                    }
                }
                # Include thought_signature for Google Gemini 3 models
                # It should be at the same level as function_call, not inside it
                if call.thought_signature:
                    function_call_content["thought_signature"] = call.thought_signature
                updated_messages.append(
                    {
                        "role": "assistant",
                        "content": [
                            {
                                "function_call": {
                                    "name": call.name,
                                    "args": call.arguments,
                                    "call_id": call_id,
                                }
                            }
                        ],
                        "content": [function_call_content],
                    }
                )

                updated_messages.append(self.create_tool_message(call, tool_response))
            except Exception as e:
                logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -282,7 +792,7 @@ class LLMHandler(ABC):
                messages = e.value
                break
        response = agent.llm.gen(
            model=agent.gpt_model, messages=messages, tools=agent.tools
            model=agent.model_id, messages=messages, tools=agent.tools
        )
        parsed = self.parse_response(response)
        self.llm_calls.append(build_stack_data(agent.llm))
@@ -324,6 +834,9 @@ class LLMHandler(ABC):
                    existing.name = call.name
                if call.arguments:
                    existing.arguments += call.arguments
                # Preserve thought_signature for Google Gemini 3 models
                if call.thought_signature:
                    existing.thought_signature = call.thought_signature
            if parsed.finish_reason == "tool_calls":
                tool_handler_gen = self.handle_tool_calls(
                    agent, list(tool_calls.values()), tools_dict, messages
@@ -336,8 +849,21 @@ class LLMHandler(ABC):
                    break
            tool_calls = {}

        # Check if context limit was reached during tool execution
        if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
            # Add system message warning about context limit
            messages.append({
                "role": "system",
                "content": (
                    "WARNING: Context window limit has been reached. "
                    "Please provide a final response to the user without making additional tool calls. "
                    "Summarize the work completed so far."
                )
            })
            logger.info("Context limit reached - instructing agent to wrap up")

        response = agent.llm.gen_stream(
            model=agent.gpt_model, messages=messages, tools=agent.tools
            model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
        )
        self.llm_calls.append(build_stack_data(agent.llm))
@@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler):
        )
        if hasattr(response, "candidates"):
            parts = response.candidates[0].content.parts if response.candidates else []
            tool_calls = [
                ToolCall(
                    id=str(uuid.uuid4()),
                    name=part.function_call.name,
                    arguments=part.function_call.args,
                )
                for part in parts
                if hasattr(part, "function_call") and part.function_call is not None
            ]
            tool_calls = []
            for idx, part in enumerate(parts):
                if hasattr(part, "function_call") and part.function_call is not None:
                    has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
                    thought_sig = part.thought_signature if has_sig else None
                    tool_calls.append(
                        ToolCall(
                            id=str(uuid.uuid4()),
                            name=part.function_call.name,
                            arguments=part.function_call.args,
                            index=idx,
                            thought_signature=thought_sig,
                        )
                    )

            content = " ".join(
                part.text
@@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler):
                raw_response=response,
            )
        else:
            # This branch handles individual Part objects from streaming responses
            tool_calls = []
            if hasattr(response, "function_call"):
            if hasattr(response, "function_call") and response.function_call is not None:
                has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
                thought_sig = response.thought_signature if has_sig else None
                tool_calls.append(
                    ToolCall(
                        id=str(uuid.uuid4()),
                        name=response.function_call.name,
                        arguments=response.function_call.args,
                        thought_signature=thought_sig,
                    )
                )
        return LLMResponse(
@@ -1,13 +1,17 @@
from application.llm.groq import GroqLLM
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
import logging

from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM

logger = logging.getLogger(__name__)


class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
    }

    @classmethod
    def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
    def create_llm(
        cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
    ):
        from application.core.model_utils import get_base_url_for_model

        llm_class = cls.llms.get(type.lower())
        if not llm_class:
            raise ValueError(f"No LLM class found for type {type}")

        # Extract base_url from model configuration if model_id is provided
        base_url = None
        if model_id:
            base_url = get_base_url_for_model(model_id)

        return llm_class(
            api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
            api_key,
            user_api_key,
            decoded_token=decoded_token,
            model_id=model_id,
            base_url=base_url,
            *args,
            **kwargs,
        )
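A hedged usage sketch of the extended factory; the provider, key, and model values are illustrative:

# Hypothetical call site: model_id now drives the base_url lookup inside the factory.
llm = LLMCreator.create_llm(
    "openai",            # provider key in LLMCreator.llms
    "sk-placeholder",    # api_key
    None,                # user_api_key
    None,                # decoded_token
    model_id="gpt-4o",   # get_base_url_for_model(model_id) may supply a base_url
)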
@@ -2,6 +2,8 @@ import base64
import json
import logging

from openai import OpenAI

from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator

class OpenAILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        from openai import OpenAI
    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

        super().__init__(*args, **kwargs)
        if (
        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
        self.user_api_key = user_api_key

        # Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
        effective_base_url = None
        if base_url and isinstance(base_url, str) and base_url.strip():
            effective_base_url = base_url
        elif (
            isinstance(settings.OPENAI_BASE_URL, str)
            and settings.OPENAI_BASE_URL.strip()
        ):
            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
            effective_base_url = settings.OPENAI_BASE_URL
        else:
            DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
            self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
        self.api_key = api_key
        self.user_api_key = user_api_key
            effective_base_url = "https://api.openai.com/v1"

        self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
        self.storage = StorageCreator.get_storage()
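The resolution above follows a strict priority order; a sketch of the effective outcomes (key and URLs are placeholders):

# 1) an explicit argument wins:
OpenAILLM(api_key="sk-placeholder", base_url="https://proxy.example/v1")
# 2) otherwise a non-empty settings.OPENAI_BASE_URL is used:
OpenAILLM(api_key="sk-placeholder")   # -> settings.OPENAI_BASE_URL
# 3) with neither set, the public default applies:
OpenAILLM(api_key="sk-placeholder")   # -> "https://api.openai.com/v1"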
    def _clean_messages_openai(self, messages):
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):

            if role == "model":
                role = "assistant"

            if role and content is not None:
                if isinstance(content, str):
                    cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
                    )
                else:
                    raise ValueError(f"Unexpected content type: {type(content)}")

        return cleaned_messages

    def _raw_gen(
@@ -123,6 +128,10 @@ class OpenAILLM(BaseLLM):
    ):
        messages = self._clean_messages_openai(messages)

        # Convert max_tokens to max_completion_tokens for newer models
        if "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

        request_params = {
            "model": model,
            "messages": messages,
@@ -132,10 +141,8 @@ class OpenAILLM(BaseLLM):

        if tools:
            request_params["tools"] = tools

        if response_format:
            request_params["response_format"] = response_format

        response = self.client.chat.completions.create(**request_params)

        if tools:
@@ -156,6 +163,10 @@ class OpenAILLM(BaseLLM):
    ):
        messages = self._clean_messages_openai(messages)

        # Convert max_tokens to max_completion_tokens for newer models
        if "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

        request_params = {
            "model": model,
            "messages": messages,
@@ -165,10 +176,8 @@ class OpenAILLM(BaseLLM):

        if tools:
            request_params["tools"] = tools

        if response_format:
            request_params["response_format"] = response_format

        response = self.client.chat.completions.create(**request_params)

        try:
@@ -194,7 +203,6 @@ class OpenAILLM(BaseLLM):
    def prepare_structured_output_format(self, json_schema):
        if not json_schema:
            return None

        try:

            def add_additional_properties_false(schema_obj):
@@ -204,11 +212,11 @@ class OpenAILLM(BaseLLM):
                if schema_copy.get("type") == "object":
                    schema_copy["additionalProperties"] = False
                    # Ensure 'required' includes all properties for OpenAI strict mode

                    if "properties" in schema_copy:
                        schema_copy["required"] = list(
                            schema_copy["properties"].keys()
                        )

                    for key, value in schema_copy.items():
                        if key == "properties" and isinstance(value, dict):
                            schema_copy[key] = {
@@ -224,7 +232,6 @@ class OpenAILLM(BaseLLM):
                                add_additional_properties_false(sub_schema)
                                for sub_schema in value
                            ]

                    return schema_copy
                return schema_obj

@@ -243,7 +250,6 @@ class OpenAILLM(BaseLLM):
            }

            return result

        except Exception as e:
            logging.error(f"Error preparing structured output format: {e}")
            return None
@@ -277,21 +283,19 @@ class OpenAILLM(BaseLLM):
        """
        if not attachments:
            return messages

        prepared_messages = messages.copy()

        # Find the user message to attach file_id to the last one

        user_message_index = None
        for i in range(len(prepared_messages) - 1, -1, -1):
            if prepared_messages[i].get("role") == "user":
                user_message_index = i
                break

        if user_message_index is None:
            user_message = {"role": "user", "content": []}
            prepared_messages.append(user_message)
            user_message_index = len(prepared_messages) - 1

        if isinstance(prepared_messages[user_message_index].get("content"), str):
            text_content = prepared_messages[user_message_index]["content"]
            prepared_messages[user_message_index]["content"] = [
@@ -299,7 +303,6 @@ class OpenAILLM(BaseLLM):
            ]
        elif not isinstance(prepared_messages[user_message_index].get("content"), list):
            prepared_messages[user_message_index]["content"] = []

        for attachment in attachments:
            mime_type = attachment.get("mime_type")

@@ -326,6 +329,7 @@ class OpenAILLM(BaseLLM):
                    }
                )
            # Handle PDFs using the file API

            elif mime_type == "application/pdf":
                try:
                    file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +345,6 @@ class OpenAILLM(BaseLLM):
                            "text": f"File content:\n\n{attachment['content']}",
                        }
                    )

        return prepared_messages

    def _get_base64_image(self, attachment):
@@ -357,7 +360,6 @@ class OpenAILLM(BaseLLM):
        file_path = attachment.get("path")
        if not file_path:
            raise ValueError("No file path provided in attachment")

        try:
            with self.storage.get_file(file_path) as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +383,10 @@ class OpenAILLM(BaseLLM):

        if "openai_file_id" in attachment:
            return attachment["openai_file_id"]

        file_path = attachment.get("path")

        if not self.storage.file_exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            file_id = self.storage.process_file(
                file_path,
@@ -404,7 +404,6 @@ class OpenAILLM(BaseLLM):
            attachments_collection.update_one(
                {"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
            )

            return file_id
        except Exception as e:
            logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
application/prompts/compression/v1.0.txt (new file, 35 lines)
@@ -0,0 +1,35 @@
Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.

This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing work without losing context.

Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:

1. Chronologically analyze each message, tool call and section of the conversation. For each section thoroughly identify:
   - The user's explicit requests and intents
   - Your approach to addressing the user's requests
   - Key decisions, concepts and patterns
   - Specific details, if applicable:
     - file names
     - full code snippets
     - function signatures
     - file edits
   - Errors that you ran into and how you fixed them
   - Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.

2. Double-check for accuracy and completeness, addressing each required element thoroughly.

Your summary should include the following sections:

1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail.
2. Key Concepts: List all important concepts discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages, include full code snippets where applicable, and add a summary of why each file read or edit is important.
4. Errors and Fixes: List all errors that you ran into and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
6. All User Messages: List ALL user messages that are not tool results. These are critical for understanding the user's feedback and changing intent.
7. Tool Calls: List ALL tool calls made, including their inputs and relevant parts of the outputs.
8. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
9. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
10. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests and the task you were working on immediately before this summary request. If your last task was concluded, only list next steps if they are explicitly in line with the user's request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first.
    If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.

Please provide your summary based on the conversation and tools used so far, following this structure and ensuring precision and thoroughness in your response.
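For context, a minimal sketch of how a prompt file like this might be loaded and sent to the model; the loader below is illustrative, not the repo's actual wiring:

from pathlib import Path

# Illustrative only: the real pipeline lives in the compression service.
prompt = Path("application/prompts/compression/v1.0.txt").read_text()
messages = history + [{"role": "user", "content": prompt}]  # history: prior turns
summary = llm.gen(model=model_id, messages=messages)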
@@ -15,7 +15,7 @@ Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.3.0
google-genai==1.49.0
google-api-python-client==2.179.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.2

@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
        prompt="",
        chunks=2,
        doc_token_limit=50000,
        gpt_model="docsgpt",
        model_id="docsgpt-local",
        user_api_key=None,
        llm_name=settings.LLM_PROVIDER,
        api_key=settings.API_KEY,
@@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever):
            f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
            f"sources={'active_docs' in source and source['active_docs'] is not None}"
        )
        self.gpt_model = gpt_model
        self.model_id = model_id
        self.doc_token_limit = doc_token_limit
        self.user_api_key = user_api_key
        self.llm_name = llm_name
@@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever):
        ]

        try:
            rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
            rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
            print(f"Rephrased query: {rephrased_query}")
            return rephrased_query if rephrased_query else self.original_question
        except Exception as e:
||||
@@ -7,6 +7,8 @@ import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename

from application.core.model_utils import get_token_limit

from application.core.settings import settings


@@ -75,11 +77,9 @@ def count_tokens_docs(docs):


def calculate_doc_token_budget(
    gpt_model: str = "gpt-4o", history_token_limit: int = 2000
    model_id: str = "gpt-4o", history_token_limit: int = 2000
) -> int:
    total_context = settings.LLM_TOKEN_LIMITS.get(
        gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
    )
    total_context = get_token_limit(model_id)
    reserved = sum(settings.RESERVED_TOKENS.values())
    doc_budget = total_context - history_token_limit - reserved
    return max(doc_budget, 1000)
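To make the budget arithmetic concrete, here is a minimal sketch; the 128,000-token context window and 3,000 reserved tokens are illustrative assumptions, not values taken from the repository settings:

# Hypothetical numbers, for illustration only.
total_context = 128_000          # assumed get_token_limit(model_id) result
history_token_limit = 2_000      # webhook default used elsewhere in this diff
reserved = 3_000                 # assumed sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - history_token_limit - reserved   # 123_000
doc_budget = max(doc_budget, 1_000)                           # floor keeps the budget positive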
@@ -144,16 +144,13 @@ def get_hash(data):
    return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()


def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
    """Limit chat history to fit within token limit."""
    from application.core.settings import settings

    model_token_limit = get_token_limit(model_id)
    max_token_limit = (
        max_token_limit
        if max_token_limit
        and max_token_limit
        < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
        else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
        if max_token_limit and max_token_limit < model_token_limit
        else model_token_limit
    )

    if not history:
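The rewritten conditional simply clamps the caller's limit to the model's context window. A sketch of the equivalent behaviour, with an assumed 8,000-token model limit:

# Equivalent clamping behaviour; numbers are illustrative only.
model_token_limit = 8_000                      # assumed get_token_limit(model_id)
for requested in (None, 4_000, 50_000):
    effective = (
        requested
        if requested and requested < model_token_limit
        else model_token_limit
    )
    # None -> 8000, 4000 -> 4000, 50000 -> 8000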
@@ -200,42 +197,67 @@ def generate_image_url(image_path):
    return f"{base_url}/api/images/{image_path}"


def calculate_compression_threshold(
    model_id: str, threshold_percentage: float = 0.8
) -> int:
    """
    Calculate token threshold for triggering compression.

    Args:
        model_id: Model identifier
        threshold_percentage: Percentage of context window (default 80%)

    Returns:
        Token count threshold
    """
    total_context = get_token_limit(model_id)
    threshold = int(total_context * threshold_percentage)
    return threshold
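The threshold is just a fixed fraction of the context window; for example, assuming a model whose get_token_limit returns 128,000 tokens:

calculate_compression_threshold("gpt-4o")        # int(128_000 * 0.8) == 102_400 (assumed limit)
calculate_compression_threshold("gpt-4o", 0.5)   # int(128_000 * 0.5) == 64_000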
def clean_text_for_tts(text: str) -> str:
    """
    clean text for Text-to-Speech processing.
    """
    # Handle code blocks and links
    text = re.sub(r'```mermaid[\s\S]*?```', ' flowchart, ', text)  ## ```mermaid...```
    text = re.sub(r'```[\s\S]*?```', ' code block, ', text)  ## ```code```
    text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)  ## [text](url)
    text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text)  ## ![image](url)

    text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text)  ## ```mermaid...```
    text = re.sub(r"```[\s\S]*?```", " code block, ", text)  ## ```code```
    text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)  ## [text](url)
    text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text)  ## ![image](url)

    # Remove markdown formatting
    text = re.sub(r'`([^`]+)`', r'\1', text)  ## `code`
    text = re.sub(r'\{([^}]*)\}', r' \1 ', text)  ## {text}
    text = re.sub(r'[{}]', ' ', text)  ## unmatched {}
    text = re.sub(r'\[([^\]]+)\]', r' \1 ', text)  ## [text]
    text = re.sub(r'[\[\]]', ' ', text)  ## unmatched []
    text = re.sub(r'(\*\*|__)(.*?)\1', r'\2', text)  ## **bold** __bold__
    text = re.sub(r'(\*|_)(.*?)\1', r'\2', text)  ## *italic* _italic_
    text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)  ## # headers
    text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE)  ## > blockquotes
    text = re.sub(r'^[\s]*[-\*\+]\s+', '', text, flags=re.MULTILINE)  ## - * + lists
    text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE)  ## 1. numbered lists
    text = re.sub(r'^[\*\-_]{3,}\s*$', '', text, flags=re.MULTILINE)  ## --- *** ___ rules
    text = re.sub(r'<[^>]*>', '', text)  ## <html> tags

    #Remove non-ASCII (emojis, special Unicode)
    text = re.sub(r'[^\x20-\x7E\n\r\t]', '', text)
    text = re.sub(r"`([^`]+)`", r"\1", text)  ## `code`
    text = re.sub(r"\{([^}]*)\}", r" \1 ", text)  ## {text}
    text = re.sub(r"[{}]", " ", text)  ## unmatched {}
    text = re.sub(r"\[([^\]]+)\]", r" \1 ", text)  ## [text]
    text = re.sub(r"[\[\]]", " ", text)  ## unmatched []
    text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text)  ## **bold** __bold__
    text = re.sub(r"(\*|_)(.*?)\1", r"\2", text)  ## *italic* _italic_
    text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)  ## # headers
    text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE)  ## > blockquotes
    text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE)  ## - * + lists
    text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE)  ## 1. numbered lists
    text = re.sub(
        r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
    )  ## --- *** ___ rules
    text = re.sub(r"<[^>]*>", "", text)  ## <html> tags

    #Replace special sequences
    text = re.sub(r'-->', ', ', text)  ## -->
    text = re.sub(r'<--', ', ', text)  ## <--
    text = re.sub(r'=>', ', ', text)  ## =>
    text = re.sub(r'::', ' ', text)  ## ::
    # Remove non-ASCII (emojis, special Unicode)

    #Normalize whitespace
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)

    # Replace special sequences

    text = re.sub(r"-->", ", ", text)  ## -->
    text = re.sub(r"<--", ", ", text)  ## <--
    text = re.sub(r"=>", ", ", text)  ## =>
    text = re.sub(r"::", " ", text)  ## ::

    # Normalize whitespace

    text = re.sub(r"\s+", " ", text)
    text = text.strip()

    return text

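A quick sanity check of the rewritten rules; the sample string is hypothetical:

clean_text_for_tts("**Hello** `world` [docs](https://example.com)")
# -> "Hello world docs"  (bold, inline code, and link markup stripped; whitespace normalized)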
@@ -146,6 +146,14 @@ def upload_index(full_path, file_data):

def run_agent_logic(agent_config, input_data):
    try:
        from application.core.model_utils import (
            get_api_key_for_provider,
            get_default_model_id,
            get_provider_from_model_id,
            validate_model_id,
        )
        from application.utils import calculate_doc_token_budget

        source = agent_config.get("source")
        retriever = agent_config.get("retriever", "classic")
        if isinstance(source, DBRef):
@@ -160,31 +168,62 @@ def run_agent_logic(agent_config, input_data):
        user_api_key = agent_config["key"]
        agent_type = agent_config.get("agent_type", "classic")
        decoded_token = {"sub": agent_config.get("user")}
        json_schema = agent_config.get("json_schema")
        prompt = get_prompt(prompt_id, db["prompts"])
        agent = AgentCreator.create_agent(
            agent_type,
            endpoint="webhook",
            llm_name=settings.LLM_PROVIDER,
            gpt_model=settings.LLM_NAME,
            api_key=settings.API_KEY,
            user_api_key=user_api_key,
            prompt=prompt,
            chat_history=[],
            decoded_token=decoded_token,
            attachments=[],

        # Determine model_id: check agent's default_model_id, fallback to system default
        agent_default_model = agent_config.get("default_model_id", "")
        if agent_default_model and validate_model_id(agent_default_model):
            model_id = agent_default_model
        else:
            model_id = get_default_model_id()

        # Get provider and API key for the selected model
        provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
        system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)

        # Calculate proper doc_token_limit based on model's context window
        history_token_limit = 2000  # Default for webhooks
        doc_token_limit = calculate_doc_token_budget(
            model_id=model_id, history_token_limit=history_token_limit
        )

        retriever = RetrieverCreator.create_retriever(
            retriever,
            source=source,
            chat_history=[],
            prompt=prompt,
            chunks=chunks,
            token_limit=settings.DEFAULT_MAX_HISTORY,
            gpt_model=settings.LLM_NAME,
            doc_token_limit=doc_token_limit,
            model_id=model_id,
            user_api_key=user_api_key,
            decoded_token=decoded_token,
        )
        answer = agent.gen(query=input_data, retriever=retriever)

        # Pre-fetch documents using the retriever
        retrieved_docs = []
        try:
            docs = retriever.search(input_data)
            if docs:
                retrieved_docs = docs
        except Exception as e:
            logging.warning(f"Failed to retrieve documents: {e}")

        agent = AgentCreator.create_agent(
            agent_type,
            endpoint="webhook",
            llm_name=provider or settings.LLM_PROVIDER,
            model_id=model_id,
            api_key=system_api_key,
            user_api_key=user_api_key,
            prompt=prompt,
            chat_history=[],
            retrieved_docs=retrieved_docs,
            decoded_token=decoded_token,
            attachments=[],
            json_schema=json_schema,
        )
        answer = agent.gen(query=input_data)
        response_full = ""
        thought = ""
        source_log_docs = []
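For orientation, a minimal agent_config for this webhook path might look like the sketch below; the concrete values are hypothetical, and only keys read by the code above are shown:

# Hypothetical agent_config, illustrating the keys run_agent_logic reads.
agent_config = {
    "user": "user-123",               # becomes decoded_token["sub"]
    "key": "user-api-key",            # passed through as user_api_key
    "agent_type": "classic",
    "retriever": "classic",
    "source": None,                   # or a DBRef to a sources document
    "json_schema": None,
    "default_model_id": "gpt-4o",     # validated, else get_default_model_id() is used
}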
@@ -57,7 +57,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This

* **4) Connect Cloud API Provider:** This option lets you connect DocsGPT to a commercial Cloud API provider such as OpenAI, Google (Vertex AI/Gemini), Anthropic (Claude), Groq, HuggingFace Inference API, or Azure OpenAI. You will need an API key from your chosen provider. Select this if you prefer to use a powerful cloud-based LLM.

* **5) Modify DocsGPT's source code and rebuild the Docker images locally. Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.
* **5) Modify DocsGPT's source code and rebuild the Docker images locally.** Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.

After selecting an option and providing any required information (like API keys or model names), the script will configure your `.env` file and start DocsGPT using Docker Compose.

@@ -119,4 +119,4 @@ If you prefer a more manual approach, you can follow our [Docker Deployment docu

For more advanced customization of DocsGPT settings, such as configuring vector stores, embedding models, and other parameters, please refer to the [DocsGPT Settings documentation](/Deploying/DocsGPT-Settings). This guide explains how to modify the `.env` file or `settings.py` for deeper configuration.

Enjoy using DocsGPT!
Enjoy using DocsGPT!

@@ -3,4 +3,4 @@ VITE_BASE_URL=http://localhost:5173
VITE_API_HOST=http://127.0.0.1:7091
VITE_API_STREAMING=true
VITE_NOTIFICATION_TEXT="What's new in 0.14.0 — Changelog"
VITE_NOTIFICATION_LINK="#"
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-14-agents-automate-integrate-and-innovate/"

@@ -1,17 +0,0 @@
node_modules/
dist/
prettier.config.cjs
.eslintrc.cjs
env.d.ts
public/
assets/
vite-env.d.ts
.prettierignore
package-lock.json
package.json
postcss.config.cjs
prettier.config.cjs
tailwind.config.cjs
tsconfig.json
tsconfig.node.json
vite.config.ts
@@ -1,45 +0,0 @@
module.exports = {
  env: {
    browser: true,
    es2021: true,
    node: true,
  },
  extends: [
    'eslint:recommended',
    'plugin:@typescript-eslint/recommended',
    'plugin:react/recommended',
    'plugin:prettier/recommended',
  ],
  overrides: [],
  parser: '@typescript-eslint/parser',
  parserOptions: {
    ecmaVersion: 'latest',
    sourceType: 'module',
  },
  plugins: ['react', 'unused-imports'],
  rules: {
    'react/prop-types': 'off',
    'unused-imports/no-unused-imports': 'error',
    'react/react-in-jsx-scope': 'off',
    'prettier/prettier': [
      'error',
      {
        endOfLine: 'auto',
      },
    ],
  },
  settings: {
    'import/parsers': {
      '@typescript-eslint/parser': ['.ts', '.tsx'],
    },
    react: {
      version: 'detect',
    },
    'import/resolver': {
      node: {
        paths: ['src'],
        extensions: ['.js', '.jsx', '.ts', '.tsx'],
      },
    },
  },
};
frontend/eslint.config.js (new file, 78 lines)
@@ -0,0 +1,78 @@
import js from '@eslint/js'
import tsParser from '@typescript-eslint/parser'
import tsPlugin from '@typescript-eslint/eslint-plugin'
import react from 'eslint-plugin-react'
import unusedImports from 'eslint-plugin-unused-imports'
import prettier from 'eslint-plugin-prettier'
import globals from 'globals'

export default [
  {
    ignores: [
      'node_modules/',
      'dist/',
      'prettier.config.cjs',
      '.eslintrc.cjs',
      'env.d.ts',
      'public/',
      'assets/',
      'vite-env.d.ts',
      '.prettierignore',
      'package-lock.json',
      'package.json',
      'postcss.config.cjs',
      'tailwind.config.cjs',
      'tsconfig.json',
      'tsconfig.node.json',
      'vite.config.ts',
    ],
  },
  {
    files: ['**/*.{js,jsx,ts,tsx}'],
    languageOptions: {
      ecmaVersion: 'latest',
      sourceType: 'module',
      parser: tsParser,
      parserOptions: {
        ecmaFeatures: {
          jsx: true,
        },
      },
      globals: {
        ...globals.browser,
        ...globals.es2021,
        ...globals.node,
      },
    },
    plugins: {
      '@typescript-eslint': tsPlugin,
      react,
      'unused-imports': unusedImports,
      prettier,
    },
    rules: {
      ...js.configs.recommended.rules,
      ...tsPlugin.configs.recommended.rules,
      ...react.configs.recommended.rules,
      ...prettier.configs.recommended.rules,
      'react/prop-types': 'off',
      'unused-imports/no-unused-imports': 'error',
      'react/react-in-jsx-scope': 'off',
      'no-undef': 'off',
      '@typescript-eslint/no-explicit-any': 'warn',
      '@typescript-eslint/no-unused-vars': 'warn',
      '@typescript-eslint/no-unused-expressions': 'warn',
      'prettier/prettier': [
        'error',
        {
          endOfLine: 'auto',
        },
      ],
    },
    settings: {
      react: {
        version: 'detect',
      },
    },
  },
]
frontend/package-lock.json (generated, 1480 lines changed): diff suppressed because it is too large.
@@ -19,7 +19,7 @@
  ]
},
"dependencies": {
  "@reduxjs/toolkit": "^2.8.2",
  "@reduxjs/toolkit": "^2.10.1",
  "chart.js": "^4.4.4",
  "clsx": "^2.1.1",
  "copy-to-clipboard": "^3.3.3",
@@ -33,7 +33,7 @@
  "react-dom": "^19.1.1",
  "react-dropzone": "^14.3.8",
  "react-google-drive-picker": "^1.2.2",
  "react-i18next": "^15.4.0",
  "react-i18next": "^16.2.4",
  "react-markdown": "^9.0.1",
  "react-redux": "^9.2.0",
  "react-router-dom": "^7.6.1",
@@ -46,30 +46,28 @@
"devDependencies": {
  "@tailwindcss/postcss": "^4.1.10",
  "@types/lodash": "^4.17.20",
  "@types/mermaid": "^9.1.0",
  "@types/react": "^19.1.8",
  "@types/react-dom": "^19.1.7",
  "@types/react-syntax-highlighter": "^15.5.13",
  "@typescript-eslint/eslint-plugin": "^6.21.0",
  "@typescript-eslint/parser": "^6.21.0",
  "@typescript-eslint/eslint-plugin": "^8.46.3",
  "@typescript-eslint/parser": "^8.46.3",
  "@vitejs/plugin-react": "^4.3.4",
  "eslint": "^8.57.1",
  "eslint": "^9.39.1",
  "eslint-config-prettier": "^10.1.5",
  "eslint-config-standard-with-typescript": "^43.0.1",
  "eslint-plugin-import": "^2.31.0",
  "eslint-plugin-n": "^16.6.2",
  "eslint-plugin-n": "^17.23.1",
  "eslint-plugin-prettier": "^5.5.4",
  "eslint-plugin-promise": "^6.6.0",
  "eslint-plugin-react": "^7.37.5",
  "eslint-plugin-unused-imports": "^4.1.4",
  "husky": "^8.0.0",
  "husky": "^9.1.7",
  "lint-staged": "^15.3.0",
  "postcss": "^8.4.49",
  "prettier": "^3.5.3",
  "prettier-plugin-tailwindcss": "^0.7.1",
  "tailwindcss": "^4.1.11",
  "typescript": "^5.8.3",
  "vite": "^6.3.5",
  "vite": "^7.2.0",
  "vite-plugin-svgr": "^4.3.0"
}
}

@@ -1,6 +1,8 @@
import DocsGPT3 from './assets/cute_docsgpt3.svg';
import { useTranslation } from 'react-i18next';

import DocsGPT3 from './assets/cute_docsgpt3.svg';
import DropdownModel from './components/DropdownModel';

export default function Hero({
  handleQuestion,
}: {
@@ -26,6 +28,10 @@ export default function Hero({
        <span className="text-4xl font-semibold">DocsGPT</span>
        <img className="mb-1 inline w-14" src={DocsGPT3} alt="docsgpt" />
      </div>
      {/* Model Selector */}
      <div className="relative w-72">
        <DropdownModel />
      </div>
    </div>

    {/* Demo Buttons Section */}
@@ -38,7 +44,7 @@
      <button
        key={key}
        onClick={() => handleQuestion({ question: demo.query })}
        className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
        className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''}`}
      >
        <p className="text-black-1000 dark:text-bright-gray mb-2 font-semibold">
          {demo.header}

@@ -567,7 +567,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
  <div className="flex items-center gap-1 pr-4">
    <NavLink
      target="_blank"
      to={'https://discord.gg/WHJdfbQDR4'}
      to={'https://discord.gg/vN7YFfdMpj'}
      className={
        'rounded-full hover:bg-gray-100 dark:hover:bg-[#28292E]'
      }

@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import { useNavigate, useParams } from 'react-router-dom';

import modelService from '../api/services/modelService';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import SourceIcon from '../assets/source.svg';
@@ -26,6 +27,7 @@ import { UserToolType } from '../settings/types';
import AgentPreview from './AgentPreview';
import { Agent, ToolSummary } from './types';

import type { Model } from '../models/types';
const embeddingsName =
  import.meta.env.VITE_EMBEDDINGS_NAME ||
  'huggingface_sentence-transformers/all-mpnet-base-v2';
@@ -59,18 +61,25 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
    token_limit: undefined,
    limited_request_mode: false,
    request_limit: undefined,
    models: [],
    default_model_id: '',
  });
  const [imageFile, setImageFile] = useState<File | null>(null);
  const [prompts, setPrompts] = useState<
    { name: string; id: string; type: string }[]
  >([]);
  const [userTools, setUserTools] = useState<OptionType[]>([]);
  const [availableModels, setAvailableModels] = useState<Model[]>([]);
  const [isSourcePopupOpen, setIsSourcePopupOpen] = useState(false);
  const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
  const [isModelsPopupOpen, setIsModelsPopupOpen] = useState(false);
  const [selectedSourceIds, setSelectedSourceIds] = useState<
    Set<string | number>
  >(new Set());
  const [selectedTools, setSelectedTools] = useState<ToolSummary[]>([]);
  const [selectedModelIds, setSelectedModelIds] = useState<Set<string>>(
    new Set(),
  );
  const [deleteConfirmation, setDeleteConfirmation] =
    useState<ActiveState>('INACTIVE');
  const [agentDetails, setAgentDetails] = useState<ActiveState>('INACTIVE');
@@ -86,6 +95,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
  const initialAgentRef = useRef<Agent | null>(null);
  const sourceAnchorButtonRef = useRef<HTMLButtonElement>(null);
  const toolAnchorButtonRef = useRef<HTMLButtonElement>(null);
  const modelAnchorButtonRef = useRef<HTMLButtonElement>(null);

  const modeConfig = {
    new: {
@@ -224,6 +234,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
      formData.append('json_schema', JSON.stringify(agent.json_schema));
    }

    if (agent.models && agent.models.length > 0) {
      formData.append('models', JSON.stringify(agent.models));
    }
    if (agent.default_model_id) {
      formData.append('default_model_id', agent.default_model_id);
    }

    try {
      setDraftLoading(true);
      const response =
@@ -320,6 +337,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
      formData.append('request_limit', '0');
    }

    if (agent.models && agent.models.length > 0) {
      formData.append('models', JSON.stringify(agent.models));
    }
    if (agent.default_model_id) {
      formData.append('default_model_id', agent.default_model_id);
    }

    try {
      setPublishLoading(true);
      const response =
@@ -388,8 +412,16 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
      const data = await response.json();
      setPrompts(data);
    };
    const getModels = async () => {
      const response = await modelService.getModels(null);
      if (!response.ok) throw new Error('Failed to fetch models');
      const data = await response.json();
      const transformed = modelService.transformModels(data.models || []);
      setAvailableModels(transformed);
    };
    getTools();
    getPrompts();
    getModels();
  }, [token]);

  // Auto-select default source if none selected
@@ -462,6 +494,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
    }
  }, [agentId, mode, token]);

  useEffect(() => {
    if (agent.models && agent.models.length > 0 && availableModels.length > 0) {
      const agentModelIds = new Set(agent.models);
      if (agentModelIds.size > 0 && selectedModelIds.size === 0) {
        setSelectedModelIds(agentModelIds);
      }
    }
  }, [agent.models, availableModels.length]);

  useEffect(() => {
    const modelsArray = Array.from(selectedModelIds);
    if (modelsArray.length > 0) {
      setAgent((prev) => ({
        ...prev,
        models: modelsArray,
        default_model_id: modelsArray.includes(prev.default_model_id || '')
          ? prev.default_model_id
          : modelsArray[0],
      }));
    } else {
      setAgent((prev) => ({
        ...prev,
        models: [],
        default_model_id: '',
      }));
    }
  }, [selectedModelIds]);

  useEffect(() => {
    const selectedSources = Array.from(selectedSourceIds)
      .map((id) =>
@@ -882,6 +942,82 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
            />
          </div>
        </div>
        <div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
          <h2 className="text-lg font-semibold">
            {t('agents.form.sections.models')}
          </h2>
          <div className="mt-3 flex flex-col gap-3">
            <button
              ref={modelAnchorButtonRef}
              onClick={() => setIsModelsPopupOpen(!isModelsPopupOpen)}
              className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
                selectedModelIds.size > 0
                  ? 'text-jet dark:text-bright-gray'
                  : 'dark:text-silver text-gray-400'
              }`}
            >
              {selectedModelIds.size > 0
                ? availableModels
                    .filter((m) => selectedModelIds.has(m.id))
                    .map((m) => m.display_name)
                    .join(', ')
                : t('agents.form.placeholders.selectModels')}
            </button>
            <MultiSelectPopup
              isOpen={isModelsPopupOpen}
              onClose={() => setIsModelsPopupOpen(false)}
              anchorRef={modelAnchorButtonRef}
              options={availableModels.map((model) => ({
                id: model.id,
                label: model.display_name,
              }))}
              selectedIds={selectedModelIds}
              onSelectionChange={(newSelectedIds: Set<string | number>) =>
                setSelectedModelIds(
                  new Set(Array.from(newSelectedIds).map(String)),
                )
              }
              title={t('agents.form.modelsPopup.title')}
              searchPlaceholder={t(
                'agents.form.modelsPopup.searchPlaceholder',
              )}
              noOptionsMessage={t('agents.form.modelsPopup.noOptionsMessage')}
            />
            {selectedModelIds.size > 0 && (
              <div>
                <label className="mb-2 block text-sm font-medium">
                  {t('agents.form.labels.defaultModel')}
                </label>
                <Dropdown
                  options={availableModels
                    .filter((m) => selectedModelIds.has(m.id))
                    .map((m) => ({
                      label: m.display_name,
                      value: m.id,
                    }))}
                  selectedValue={
                    availableModels.find(
                      (m) => m.id === agent.default_model_id,
                    )?.display_name || null
                  }
                  onSelect={(option: { label: string; value: string }) =>
                    setAgent({ ...agent, default_model_id: option.value })
                  }
                  size="w-full"
                  rounded="3xl"
                  border="border"
                  buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
                  optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
                  placeholder={t(
                    'agents.form.placeholders.selectDefaultModel',
                  )}
                  placeholderClassName="text-gray-400 dark:text-silver"
                  contentSize="text-sm"
                />
              </div>
            )}
          </div>
        </div>
        <div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
          <button
            onClick={() =>

@@ -52,6 +52,10 @@ export const fetchPreviewAnswer = createAsyncThunk<
    }

    if (state.preference) {
      const modelId =
        state.preference.selectedAgent?.default_model_id ||
        state.preference.selectedModel?.id;

      if (API_STREAMING) {
        await handleFetchAnswerSteaming(
          question,
@@ -120,22 +124,23 @@ export const fetchPreviewAnswer = createAsyncThunk<
          indx,
          state.preference.selectedAgent?.id,
          attachmentIds,
          false, // Don't save preview conversations
          false,
          modelId,
        );
      } else {
        // Non-streaming implementation
        const answer = await handleFetchAnswer(
          question,
          signal,
          state.preference.token,
          state.preference.selectedDocs,
          null, // No conversation ID for previews
          null,
          state.preference.prompt.id,
          state.preference.chunks,
          state.preference.token_limit,
          state.preference.selectedAgent?.id,
          attachmentIds,
          false, // Don't save preview conversations
          false,
          modelId,
        );

        if (answer) {

@@ -32,4 +32,6 @@ export type Agent = {
  token_limit?: number;
  limited_request_mode?: boolean;
  request_limit?: number;
  models?: string[];
  default_model_id?: string;
};
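These two fields are what the backend webhook path reads back: run_agent_logic (earlier in this diff) looks up agent_config["default_model_id"] and validates it before falling back to the system default. A minimal sketch of the stored agent record, with hypothetical values:

# Hypothetical agent record; only the new fields are annotated.
agent = {
    "name": "Support bot",                        # illustrative
    "models": ["gpt-4o", "claude-3-5-sonnet"],    # ids the agent may use
    "default_model_id": "gpt-4o",                 # preselected for webhook runs
}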
@@ -2,6 +2,7 @@ const endpoints = {
  USER: {
    CONFIG: '/api/config',
    NEW_TOKEN: '/api/generate_token',
    MODELS: '/api/models',
    DOCS: '/api/sources',
    DOCS_PAGINATED: '/api/sources/paginated',
    API_KEYS: '/api/get_api_keys',

frontend/src/api/services/modelService.ts (new file, 25 lines)
@@ -0,0 +1,25 @@
import apiClient from '../client';
import endpoints from '../endpoints';

import type { AvailableModel, Model } from '../../models/types';

const modelService = {
  getModels: (token: string | null): Promise<Response> =>
    apiClient.get(endpoints.USER.MODELS, token, {}),

  transformModels: (models: AvailableModel[]): Model[] =>
    models.map((model) => ({
      id: model.id,
      value: model.id,
      provider: model.provider,
      display_name: model.display_name,
      description: model.description,
      context_window: model.context_window,
      supported_attachment_types: model.supported_attachment_types,
      supports_tools: model.supports_tools,
      supports_structured_output: model.supports_structured_output,
      supports_streaming: model.supports_streaming,
    })),
};

export default modelService;
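For reference, transformModels and DropdownModel (below) together imply a response shape roughly like the following; this is an inferred sketch, shown as a Python dict, not the documented API contract:

# Hypothetical GET /api/models response, inferred from the fields read above.
response = {
    "default_model_id": "gpt-4o",
    "models": [
        {
            "id": "gpt-4o",                          # illustrative values throughout
            "provider": "openai",
            "display_name": "GPT-4o",
            "description": "General-purpose model",
            "context_window": 128000,
            "supported_attachment_types": ["image/png"],
            "supports_tools": True,
            "supports_structured_output": True,
            "supports_streaming": True,
        }
    ],
}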
frontend/src/assets/rounded-tick.svg (new file, 3 lines; 1.1 KiB)
@@ -0,0 +1,3 @@
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10 0.75C4.62391 0.75 0.25 5.12391 0.25 10.5C0.25 15.8761 4.62391 20.25 10 20.25C15.3761 20.25 19.75 15.8761 19.75 10.5C19.75 5.12391 15.3761 0.75 10 0.75ZM15.0742 7.23234L8.77422 14.7323C8.70511 14.8147 8.61912 14.8812 8.52207 14.9273C8.42502 14.9735 8.31918 14.9983 8.21172 15H8.19906C8.09394 15 7.99 14.9778 7.89398 14.935C7.79797 14.8922 7.71202 14.8297 7.64172 14.7516L4.94172 11.7516C4.87315 11.6788 4.81981 11.5931 4.78483 11.4995C4.74986 11.4059 4.73395 11.3062 4.73805 11.2063C4.74215 11.1064 4.76617 11.0084 4.8087 10.9179C4.85124 10.8275 4.91142 10.7464 4.98572 10.6796C5.06002 10.6127 5.14694 10.5614 5.24136 10.5286C5.33579 10.4958 5.43581 10.4822 5.53556 10.4886C5.63531 10.495 5.73277 10.5213 5.82222 10.5659C5.91166 10.6106 5.99128 10.6726 6.05641 10.7484L8.17938 13.1072L13.9258 6.26766C14.0547 6.11863 14.237 6.02631 14.4335 6.01066C14.6299 5.99501 14.8246 6.05728 14.9754 6.18402C15.1263 6.31075 15.2212 6.49176 15.2397 6.68793C15.2582 6.8841 15.1988 7.07966 15.0742 7.23234Z" fill="#B5B5B5"/>
</svg>

frontend/src/components/DropdownModel.tsx (new file, 138 lines)
@@ -0,0 +1,138 @@
import React, { useEffect } from 'react';
import { useDispatch, useSelector } from 'react-redux';

import modelService from '../api/services/modelService';
import Arrow2 from '../assets/dropdown-arrow.svg';
import RoundedTick from '../assets/rounded-tick.svg';
import {
  selectAvailableModels,
  selectSelectedModel,
  setAvailableModels,
  setModelsLoading,
  setSelectedModel,
} from '../preferences/preferenceSlice';

import type { Model } from '../models/types';

export default function DropdownModel() {
  const dispatch = useDispatch();
  const selectedModel = useSelector(selectSelectedModel);
  const availableModels = useSelector(selectAvailableModels);
  const dropdownRef = React.useRef<HTMLDivElement>(null);
  const [isOpen, setIsOpen] = React.useState(false);

  useEffect(() => {
    const loadModels = async () => {
      if ((availableModels?.length ?? 0) > 0) {
        return;
      }
      dispatch(setModelsLoading(true));
      try {
        const response = await modelService.getModels(null);
        if (!response.ok) {
          throw new Error(`API error: ${response.status}`);
        }
        const data = await response.json();
        const models = data.models || [];
        const transformed = modelService.transformModels(models);

        dispatch(setAvailableModels(transformed));
        if (!selectedModel && transformed.length > 0) {
          const defaultModel =
            transformed.find((m) => m.id === data.default_model_id) ||
            transformed[0];
          dispatch(setSelectedModel(defaultModel));
        } else if (selectedModel && transformed.length > 0) {
          const isValid = transformed.find((m) => m.id === selectedModel.id);
          if (!isValid) {
            const defaultModel =
              transformed.find((m) => m.id === data.default_model_id) ||
              transformed[0];
            dispatch(setSelectedModel(defaultModel));
          }
        }
      } catch (error) {
        console.error('Failed to load models:', error);
      } finally {
        dispatch(setModelsLoading(false));
      }
    };

    loadModels();
  }, [availableModels?.length, dispatch, selectedModel]);

  const handleClickOutside = (event: MouseEvent) => {
    if (
      dropdownRef.current &&
      !dropdownRef.current.contains(event.target as Node)
    ) {
      setIsOpen(false);
    }
  };

  useEffect(() => {
    document.addEventListener('mousedown', handleClickOutside);
    return () => {
      document.removeEventListener('mousedown', handleClickOutside);
    };
  }, []);

  return (
    <div ref={dropdownRef}>
      <div
        className={`bg-gray-1000 dark:bg-dark-charcoal mx-auto flex w-full cursor-pointer justify-between p-1 dark:text-white ${isOpen ? 'rounded-t-3xl' : 'rounded-3xl'}`}
        onClick={() => setIsOpen(!isOpen)}
      >
        {selectedModel?.display_name ? (
          <p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
            {selectedModel.display_name}
          </p>
        ) : (
          <p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
            Select Model
          </p>
        )}
        <img
          src={Arrow2}
          alt="arrow"
          className={`${
            isOpen ? 'rotate-360' : 'rotate-270'
          } mr-3 w-3 transition-all select-none`}
        />
      </div>
      {isOpen && (
        <div className="no-scrollbar dark:bg-dark-charcoal absolute right-0 left-0 z-20 -mt-1 max-h-52 w-full overflow-y-auto rounded-b-3xl bg-white shadow-md">
          {availableModels && (availableModels?.length ?? 0) > 0 ? (
            availableModels.map((model: Model) => (
              <div
                key={model.id}
                onClick={() => {
                  dispatch(setSelectedModel(model));
                  setIsOpen(false);
                }}
                className={`border-gray-3000/75 dark:border-purple-taupe/50 hover:bg-gray-3000/75 dark:hover:bg-purple-taupe flex h-10 w-full cursor-pointer items-center justify-between border-t`}
              >
                <div className="flex w-full items-center justify-between">
                  <p className="overflow-hidden py-3 pr-2 pl-5 overflow-ellipsis whitespace-nowrap">
                    {model.display_name}
                  </p>
                  {model.id === selectedModel?.id ? (
                    <img
                      src={RoundedTick}
                      alt="selected"
                      className="mr-3.5 h-4 w-4"
                    />
                  ) : null}
                </div>
              </div>
            ))
          ) : (
            <div className="h-10 w-full border-x-2 border-b-2">
              <p className="ml-5 py-3 text-gray-500">No models available</p>
            </div>
          )}
        </div>
      )}
    </div>
  );
}
@@ -19,8 +19,8 @@ import {
  removeAttachment,
  selectAttachments,
  updateAttachment,
  reorderAttachments,
} from '../upload/uploadSlice';
import { reorderAttachments } from '../upload/uploadSlice';

import { ActiveState } from '../models/misc';
import {
@@ -77,7 +77,7 @@ export default function MessageInput({
      (browserOS === 'mac' && event.metaKey && event.key === 'k')
    ) {
      event.preventDefault();
      setIsSourcesPopupOpen(!isSourcesPopupOpen);
      setIsSourcesPopupOpen((s) => !s);
    }
  };

@@ -89,8 +89,198 @@ export default function MessageInput({

  const uploadFiles = useCallback(
    (files: File[]) => {
      if (!files || files.length === 0) return;

      const apiHost = import.meta.env.VITE_API_HOST;

      if (files.length > 1) {
        const formData = new FormData();
        const indexToUiId: Record<number, string> = {};

        files.forEach((file, i) => {
          formData.append('file', file);
          const uiId = crypto.randomUUID();
          indexToUiId[i] = uiId;
          dispatch(
            addAttachment({
              id: uiId,
              fileName: file.name,
              progress: 0,
              status: 'uploading' as const,
              taskId: '',
            }),
          );
        });

        const xhr = new XMLHttpRequest();

        xhr.upload.addEventListener('progress', (event) => {
          if (event.lengthComputable) {
            const progress = Math.round((event.loaded / event.total) * 100);
            Object.values(indexToUiId).forEach((uiId) =>
              dispatch(
                updateAttachment({
                  id: uiId,
                  updates: { progress },
                }),
              ),
            );
          }
        });

        xhr.onload = () => {
          const status = xhr.status;
          if (status === 200) {
            try {
              const response = JSON.parse(xhr.responseText);

              if (Array.isArray(response?.tasks)) {
                const tasks = response.tasks as Array<{
                  task_id?: string;
                  filename?: string;
                  attachment_id?: string;
                  path?: string;
                }>;

                tasks.forEach((t, idx) => {
                  const uiId = indexToUiId[idx];
                  if (!uiId) return;
                  if (t?.task_id) {
                    dispatch(
                      updateAttachment({
                        id: uiId,
                        updates: {
                          taskId: t.task_id,
                          status: 'processing',
                          progress: 10,
                        },
                      }),
                    );
                  } else {
                    dispatch(
                      updateAttachment({
                        id: uiId,
                        updates: { status: 'failed' },
                      }),
                    );
                  }
                });

                if (tasks.length < files.length) {
                  for (let i = tasks.length; i < files.length; i++) {
                    const uiId = indexToUiId[i];
                    if (uiId) {
                      dispatch(
                        updateAttachment({
                          id: uiId,
                          updates: { status: 'failed' },
                        }),
                      );
                    }
                  }
                }
              } else if (response?.task_id) {
                if (files.length === 1) {
                  const uiId = indexToUiId[0];
                  if (uiId) {
                    dispatch(
                      updateAttachment({
                        id: uiId,
                        updates: {
                          taskId: response.task_id,
                          status: 'processing',
                          progress: 10,
                        },
                      }),
                    );
                  }
                } else {
                  console.warn(
                    'Server returned a single task_id for multiple files. Update backend to return tasks[].',
                  );
                  const firstUi = indexToUiId[0];
                  if (firstUi) {
                    dispatch(
                      updateAttachment({
                        id: firstUi,
                        updates: {
                          taskId: response.task_id,
                          status: 'processing',
                          progress: 10,
                        },
                      }),
                    );
                  }
                  for (let i = 1; i < files.length; i++) {
                    const uiId = indexToUiId[i];
                    if (uiId) {
                      dispatch(
                        updateAttachment({
                          id: uiId,
                          updates: { status: 'failed' },
                        }),
                      );
                    }
                  }
                }
              } else {
                console.error('Unexpected upload response shape', response);
                Object.values(indexToUiId).forEach((id) =>
                  dispatch(
                    updateAttachment({
                      id,
                      updates: { status: 'failed' },
                    }),
                  ),
                );
              }
            } catch (err) {
              console.error(
                'Failed to parse upload response',
                err,
                xhr.responseText,
              );
              Object.values(indexToUiId).forEach((id) =>
                dispatch(
                  updateAttachment({
                    id,
                    updates: { status: 'failed' },
                  }),
                ),
              );
            }
          } else {
            console.error('Upload failed', status, xhr.responseText);
            Object.values(indexToUiId).forEach((id) =>
              dispatch(
                updateAttachment({
                  id,
                  updates: { status: 'failed' },
                }),
              ),
            );
          }
        };

        xhr.onerror = () => {
          console.error('Upload network error');
          Object.values(indexToUiId).forEach((id) =>
            dispatch(
              updateAttachment({
                id,
                updates: { status: 'failed' },
              }),
            ),
          );
        };

        xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
        if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
        xhr.send(formData);
        return;
      }

      // Single-file path: upload each file individually (original repo behavior)
      files.forEach((file) => {
        const formData = new FormData();
        formData.append('file', file);
@@ -121,16 +311,54 @@ export default function MessageInput({

        xhr.onload = () => {
          if (xhr.status === 200) {
            const response = JSON.parse(xhr.responseText);
            if (response.task_id) {
            try {
              const response = JSON.parse(xhr.responseText);
              if (response.task_id) {
                dispatch(
                  updateAttachment({
                    id: uniqueId,
                    updates: {
                      taskId: response.task_id,
                      status: 'processing',
                      progress: 10,
                    },
                  }),
                );
              } else {
                // If backend returned tasks[] for single-file, handle gracefully:
                if (
                  Array.isArray(response?.tasks) &&
                  response.tasks[0]?.task_id
                ) {
                  dispatch(
                    updateAttachment({
                      id: uniqueId,
                      updates: {
                        taskId: response.tasks[0].task_id,
                        status: 'processing',
                        progress: 10,
                      },
                    }),
                  );
                } else {
                  dispatch(
                    updateAttachment({
                      id: uniqueId,
                      updates: { status: 'failed' },
                    }),
                  );
                }
              }
            } catch (err) {
              console.error(
                'Failed to parse upload response',
                err,
                xhr.responseText,
              );
              dispatch(
                updateAttachment({
                  id: uniqueId,
                  updates: {
                    taskId: response.task_id,
                    status: 'processing',
                    progress: 10,
                  },
                  updates: { status: 'failed' },
                }),
              );
            }
@@ -154,7 +382,7 @@ export default function MessageInput({
        };

        xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
        xhr.setRequestHeader('Authorization', `Bearer ${token}`);
        if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
        xhr.send(formData);
      });
    },
@@ -163,15 +391,13 @@ export default function MessageInput({

  const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
    if (!e.target.files || e.target.files.length === 0) return;

    const files = Array.from(e.target.files);
    uploadFiles(files);

    // clear input so same file can be selected again
    e.target.value = '';
  };

  // Drag and drop handler
  // Drag & drop via react-dropzone
  const onDrop = useCallback(
    (acceptedFiles: File[]) => {
      uploadFiles(acceptedFiles);
@@ -321,11 +547,8 @@ export default function MessageInput({
      handleAbort();
    };

  // Drag state for reordering
  const [draggingId, setDraggingId] = useState<string | null>(null);

  // no preview object URLs to revoke (preview removed per reviewer request)

  const findIndexById = (id: string) =>
    attachments.findIndex((a) => a.id === id);

@@ -359,7 +582,9 @@ export default function MessageInput({

  return (
    <div {...getRootProps()} className="flex w-full flex-col">
      {/* react-dropzone input (for drag/drop) */}
      <input {...getInputProps()} />

      <div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
        <div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
          {attachments.map((attachment) => {
@@ -374,7 +599,11 @@ export default function MessageInput({
              attachment.status !== 'completed'
                ? 'opacity-70'
                : 'opacity-100'
            } ${draggingId === attachment.id ? 'ring-dashed opacity-60 ring-2 ring-purple-200' : ''}`}
            } ${
              draggingId === attachment.id
                ? 'ring-dashed opacity-60 ring-2 ring-purple-200'
                : ''
            }`}
            title={attachment.fileName}
          >
            <div className="bg-purple-30 mr-2 flex h-8 w-8 items-center justify-center rounded-md p-1">
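The console.warn above pins down the contract this client expects from the attachment endpoint: one task entry per uploaded file, in order. A minimal backend sketch of that contract in Flask-style Python; the route name, handler, and task function are hypothetical, not the repository's actual implementation:

# Hypothetical sketch of the tasks[] response shape the upload code expects.
from flask import request, jsonify

@app.route("/api/store_attachment", methods=["POST"])   # route name assumed
def store_attachment():
    files = request.files.getlist("file")                # matches formData.append('file', ...)
    tasks = []
    for f in files:
        task = process_attachment.delay(f.filename)      # hypothetical background task
        tasks.append({"task_id": task.id, "filename": f.filename})
    return jsonify({"tasks": tasks})                     # one entry per file, in upload order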
@@ -15,6 +15,7 @@ export function handleFetchAnswer(
  agentId?: string,
  attachments?: string[],
  save_conversation = true,
  modelId?: string,
): Promise<
  | {
      result: any;
@@ -47,6 +48,10 @@ export function handleFetchAnswer(
    save_conversation: save_conversation,
  };

  if (modelId) {
    payload.model_id = modelId;
  }

  // Add attachments to payload if they exist
  if (attachments && attachments.length > 0) {
    payload.attachments = attachments;
@@ -101,6 +106,7 @@ export function handleFetchAnswerSteaming(
  agentId?: string,
  attachments?: string[],
  save_conversation = true,
  modelId?: string,
): Promise<Answer> {
  const payload: RetrievalPayload = {
    question: question,
@@ -114,6 +120,10 @@ export function handleFetchAnswerSteaming(
    save_conversation: save_conversation,
  };

  if (modelId) {
    payload.model_id = modelId;
  }

  // Add attachments to payload if they exist
  if (attachments && attachments.length > 0) {
    payload.attachments = attachments;

@@ -65,4 +65,5 @@ export interface RetrievalPayload
  agent_id?: string;
  attachments?: string[];
  save_conversation?: boolean;
  model_id?: string;
}
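Putting the pieces together, an answer request body now carries the model choice alongside the existing fields. A sketch of such a payload as a Python dict; the values are illustrative:

# Hypothetical request payload, illustrating the new optional model_id field.
payload = {
    "question": "What is DocsGPT?",
    "save_conversation": False,        # preview flows pass False
    "attachments": ["att-123"],        # only included when attachments exist
    "model_id": "gpt-4o",              # only included when a model is selected
}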
@@ -49,6 +49,9 @@ export const fetchAnswer
    }

    const currentConversationId = state.conversation.conversationId;
    const modelId =
      state.preference.selectedAgent?.default_model_id ||
      state.preference.selectedModel?.id;

    if (state.preference) {
      if (API_STREAMING) {
@@ -156,7 +159,8 @@ export const fetchAnswer
          indx,
          state.preference.selectedAgent?.id,
          attachmentIds,
          true, // Always save conversation
          true,
          modelId,
        );
      } else {
        const answer = await handleFetchAnswer(
@@ -170,7 +174,8 @@ export const fetchAnswer
          state.preference.token_limit,
          state.preference.selectedAgent?.id,
          attachmentIds,
          true, // Always save conversation
          true,
          modelId,
        );
        if (answer) {
          let sourcesPrepped = [];

646
frontend/src/locale/de.json
Normal file
646
frontend/src/locale/de.json
Normal file
@@ -0,0 +1,646 @@
|
||||
{
|
||||
"language": "Deutsch",
|
||||
"chat": "Chat",
|
||||
"chats": "Chats",
|
||||
"newChat": "Neuer Chat",
|
||||
"inputPlaceholder": "Wie kann DocsGPT dir helfen?",
|
||||
"tagline": "DocsGPT verwendet GenAI, bitte überprüfe kritische Informationen anhand der Quellen.",
|
||||
"sourceDocs": "Quelle",
|
||||
"none": "Keine",
|
||||
"cancel": "Abbrechen",
|
||||
"help": "Hilfe",
|
||||
"emailUs": "E-Mail senden",
|
||||
"documentation": "Dokumentation",
|
||||
"manageAgents": "Agenten verwalten",
|
||||
"demo": [
|
||||
{
|
||||
"header": "Über DocsGPT lernen",
|
||||
"query": "Was ist DocsGPT?"
|
||||
},
|
||||
{
|
||||
"header": "Dokumentation zusammenfassen",
|
||||
"query": "Fasse den aktuellen Kontext zusammen"
|
||||
},
|
||||
{
|
||||
"header": "Code schreiben",
|
||||
"query": "Schreibe Code für eine API-Anfrage an /api/answer"
|
||||
},
|
||||
{
|
||||
"header": "Lernunterstützung",
|
||||
"query": "Schreibe mögliche Fragen zum Kontext"
|
||||
}
|
||||
],
|
||||
"settings": {
|
||||
"label": "Einstellungen",
|
||||
"general": {
|
||||
"label": "Allgemein",
|
||||
"selectTheme": "Design auswählen",
|
||||
"light": "Hell",
|
||||
"dark": "Dunkel",
|
||||
"selectLanguage": "Sprache auswählen",
|
||||
"chunks": "Chunks pro Anfrage",
|
||||
"prompt": "Aktiver Prompt",
|
||||
"deleteAllLabel": "Alle Konversationen löschen",
|
||||
"deleteAllBtn": "Alle löschen",
|
||||
"addNew": "Neu hinzufügen",
|
||||
"convHistory": "Konversationsverlauf",
|
||||
"none": "Keine",
|
||||
"low": "Niedrig",
|
||||
"medium": "Mittel",
|
||||
"high": "Hoch",
|
||||
"unlimited": "Unbegrenzt",
|
||||
"default": "Standard",
|
||||
"add": "Hinzufügen"
|
||||
},
|
||||
"sources": {
|
||||
"title": "Hier kannst du alle verfügbaren Quelldateien verwalten, die dir zur Verfügung stehen und die du hochgeladen hast.",
|
||||
"label": "Quellen",
|
||||
"name": "Quellenname",
|
||||
"date": "Vektor-Datum",
|
||||
"type": "Typ",
|
||||
"tokenUsage": "Token-Verbrauch",
|
||||
"noData": "Keine vorhandenen Quellen",
|
||||
"searchPlaceholder": "Suchen...",
|
||||
"addNew": "Neu hinzufügen",
|
||||
"addSource": "Quelle hinzufügen",
|
||||
"addChunk": "Chunk hinzufügen",
|
||||
"preLoaded": "Vorgeladen",
|
||||
"private": "Privat",
|
||||
"sync": "Synchronisieren",
|
||||
"syncing": "Synchronisiere...",
|
||||
"syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.",
|
||||
"syncFrequency": {
|
||||
"never": "Nie",
|
||||
"daily": "Täglich",
|
||||
"weekly": "Wöchentlich",
|
||||
"monthly": "Monatlich"
|
||||
},
|
||||
"actions": "Aktionen",
|
||||
"view": "Anzeigen",
|
||||
"deleteWarning": "Bist du sicher, dass du \"{{name}}\" löschen möchtest?",
|
||||
"confirmDelete": "Bist du sicher, dass du diese Datei löschen möchtest? Diese Aktion kann nicht rückgängig gemacht werden.",
|
||||
"backToAll": "Zurück zu allen Quellen",
|
||||
"chunks": "Chunks",
|
||||
"noChunks": "Keine Chunks gefunden",
|
||||
"noChunksAlt": "Keine Chunks gefunden",
|
||||
"goToSources": "Zu den Quellen",
|
||||
"uploadNew": "Neu hochladen",
|
||||
"searchFiles": "Dateien suchen...",
|
||||
"noResults": "Keine Ergebnisse gefunden",
|
||||
"fileName": "Name",
|
||||
"tokens": "Tokens",
|
||||
"size": "Größe",
|
||||
"fileAlt": "Datei",
|
||||
"folderAlt": "Ordner",
|
||||
"parentFolderAlt": "Übergeordneter Ordner",
|
||||
"menuAlt": "Menü",
|
||||
"tokensUnit": "Tokens",
|
||||
"editAlt": "Bearbeiten",
|
||||
"uploading": "Wird hochgeladen…",
|
||||
"deleting": "Wird gelöscht…",
|
||||
"queued": "In Warteschlange: {{count}}",
|
||||
"addFile": "Datei hinzufügen",
|
||||
"uploadingFilesTitle": "Dateien werden hochgeladen...",
|
||||
"deletingTitle": "Wird gelöscht...",
|
||||
"deleteDirectoryWarning": "Bist du sicher, dass du das Verzeichnis \"{{name}}\" und seinen gesamten Inhalt löschen möchtest? Diese Aktion kann nicht rückgängig gemacht werden.",
|
||||
"searchAlt": "Suchen"
|
||||
},
|
||||
"apiKeys": {
|
||||
"label": "Chatbots",
|
||||
"name": "Name",
|
||||
"key": "API-Schlüssel",
|
||||
"sourceDoc": "Quelldokument",
|
||||
"createNew": "Neu erstellen",
|
||||
"noData": "Keine vorhandenen Chatbots",
|
||||
"deleteConfirmation": "Bist du sicher, dass du den API-Schlüssel '{{name}}' löschen möchtest?",
|
||||
"description": "Hier kannst du deine Chatbots erstellen und verwalten. Chatbots können als Widgets auf Websites eingebunden oder in deinen Anwendungen verwendet werden."
|
||||
},
|
||||
"analytics": {
|
||||
"label": "Analytik",
|
||||
"filterByChatbot": "Nach Chatbot filtern",
|
||||
"selectChatbot": "Chatbot auswählen",
|
||||
"filterOptions": {
|
||||
"hour": "Stunde",
|
||||
"last24Hours": "24 Stunden",
|
||||
"last7Days": "7 Tage",
|
||||
"last15Days": "15 Tage",
|
||||
"last30Days": "30 Tage"
|
||||
},
|
||||
"messages": "Nachrichten",
|
||||
"tokenUsage": "Token-Verbrauch",
|
||||
"userFeedback": "Benutzer-Feedback",
|
||||
"filterPlaceholder": "Filter",
|
||||
"none": "Keine",
|
||||
"positiveFeedback": "Positives Feedback",
|
||||
"negativeFeedback": "Negatives Feedback"
|
||||
},
|
||||
"logs": {
|
||||
"label": "Protokolle",
|
||||
"filterByChatbot": "Nach Chatbot filtern",
|
||||
"selectChatbot": "Chatbot auswählen",
|
||||
"none": "Keine",
|
||||
"tableHeader": "API-generierte / Chatbot-Konversationen"
|
||||
},
|
||||
"tools": {
|
||||
"label": "Werkzeuge",
|
||||
"searchPlaceholder": "Werkzeuge suchen...",
|
||||
"addTool": "Werkzeug hinzufügen",
|
||||
"noToolsFound": "Keine Werkzeuge gefunden",
|
||||
"selectToolSetup": "Wähle ein Werkzeug zur Einrichtung",
|
||||
"settingsIconAlt": "Einstellungssymbol",
|
||||
"configureToolAria": "{{toolName}} konfigurieren",
|
||||
"toggleToolAria": "{{toolName}} umschalten",
|
||||
"manageTools": "Zu den Werkzeugen",
|
||||
"edit": "Bearbeiten",
|
||||
"delete": "Löschen",
|
||||
"deleteWarning": "Bist du sicher, dass du das Werkzeug \"{{toolName}}\" löschen möchtest?",
|
||||
"unsavedChanges": "Du hast ungespeicherte Änderungen, die verloren gehen, wenn du ohne Speichern verlässt.",
|
||||
"leaveWithoutSaving": "Ohne Speichern verlassen",
|
||||
"saveAndLeave": "Speichern und verlassen",
|
||||
"customName": "Benutzerdefinierter Name",
|
||||
"customNamePlaceholder": "Gib einen benutzerdefinierten Namen ein (optional)",
|
||||
"authentication": "Authentifizierung",
|
||||
"actions": "Aktionen",
|
||||
"addAction": "Aktion hinzufügen",
|
||||
"noActionsFound": "Keine Aktionen gefunden",
|
||||
"url": "URL",
|
||||
"urlPlaceholder": "URL eingeben",
|
||||
"method": "Methode",
|
||||
"description": "Beschreibung",
|
||||
"descriptionPlaceholder": "Beschreibung eingeben",
|
||||
"headers": "Header",
|
||||
"queryParameters": "Abfrageparameter",
|
||||
"body": "Body",
|
||||
"deleteActionWarning": "Bist du sicher, dass du die Aktion \"{{name}}\" löschen möchtest?",
|
||||
"backToAllTools": "Zurück zu allen Werkzeugen",
|
||||
"save": "Speichern",
|
||||
"fieldName": "Feldname",
|
||||
"fieldType": "Feldtyp",
|
||||
"filledByLLM": "Vom LLM ausgefüllt",
|
||||
"fieldDescription": "Feldbeschreibung",
|
||||
"value": "Wert",
|
||||
"addProperty": "Eigenschaft hinzufügen",
|
||||
"propertyName": "Neuer Eigenschaftsschlüssel",
|
||||
"add": "Hinzufügen",
|
||||
"cancel": "Abbrechen",
|
||||
"addNew": "Neu hinzufügen",
|
||||
"name": "Name",
|
||||
"type": "Typ",
|
||||
"mcp": {
|
||||
"addServer": "MCP-Server hinzufügen",
|
||||
"editServer": "Server bearbeiten",
|
||||
"serverName": "Servername",
|
||||
"serverUrl": "Server-URL",
|
||||
"headerName": "Header-Name",
|
||||
"timeout": "Timeout (Sekunden)",
|
||||
"testConnection": "Verbindung testen",
|
||||
"testing": "Teste...",
|
||||
"saving": "Speichere...",
|
||||
"save": "Speichern",
|
||||
"cancel": "Abbrechen",
|
||||
"noAuth": "Keine Authentifizierung",
|
||||
"oauthInProgress": "Warte auf OAuth-Abschluss...",
|
||||
"oauthCompleted": "OAuth erfolgreich abgeschlossen",
|
||||
"authType": "Authentifizierungstyp",
|
||||
"defaultServerName": "Mein MCP-Server",
|
||||
"authTypes": {
|
||||
"none": "Keine Authentifizierung",
|
||||
"apiKey": "API-Schlüssel",
|
||||
"bearer": "Bearer-Token",
|
||||
"oauth": "OAuth",
|
||||
"basic": "Basis-Authentifizierung"
|
||||
},
|
||||
"placeholders": {
|
||||
"serverUrl": "https://api.beispiel.com",
|
||||
"apiKey": "Dein geheimer API-Schlüssel",
|
||||
"bearerToken": "Dein geheimes Token",
|
||||
"username": "Dein Benutzername",
|
||||
"password": "Dein Passwort",
|
||||
"oauthScopes": "OAuth-Bereiche (kommagetrennt)"
|
||||
},
|
||||
"errors": {
|
||||
"nameRequired": "Servername ist erforderlich",
|
||||
"urlRequired": "Server-URL ist erforderlich",
|
||||
"invalidUrl": "Bitte gib eine gültige URL ein",
|
||||
"apiKeyRequired": "API-Schlüssel ist erforderlich",
|
||||
"tokenRequired": "Bearer-Token ist erforderlich",
|
||||
"usernameRequired": "Benutzername ist erforderlich",
|
||||
"passwordRequired": "Passwort ist erforderlich",
|
||||
"testFailed": "Verbindungstest fehlgeschlagen",
|
||||
"saveFailed": "MCP-Server konnte nicht gespeichert werden",
|
||||
"oauthFailed": "OAuth-Prozess fehlgeschlagen oder abgebrochen",
|
||||
"oauthTimeout": "OAuth-Prozess abgelaufen, bitte erneut versuchen",
|
||||
"timeoutRange": "Timeout muss zwischen 1 und 300 Sekunden liegen"
|
||||
}
|
||||
}
|
||||
},
|
||||
"scrollTabsLeft": "Tabs nach links scrollen",
|
||||
"tabsAriaLabel": "Einstellungs-Tabs",
|
||||
"scrollTabsRight": "Tabs nach rechts scrollen"
|
||||
},
|
||||
"modals": {
|
||||
"uploadDoc": {
|
||||
"label": "Neues Dokument hochladen",
|
||||
"select": "Wähle, wie du dein Dokument zu DocsGPT hochladen möchtest",
|
||||
"selectSource": "Wähle die Art, wie du deine Quelle hinzufügen möchtest",
|
||||
"selectedFiles": "Ausgewählte Dateien",
|
||||
"noFilesSelected": "Keine Dateien ausgewählt",
|
||||
"file": "Vom Gerät hochladen",
|
||||
"back": "Zurück",
|
||||
"wait": "Bitte warten ...",
|
||||
"remote": "Von Website sammeln",
|
||||
"start": "Chat starten",
|
||||
"name": "Name",
|
||||
"choose": "Dateien auswählen",
|
||||
"info": "Bitte lade .pdf, .txt, .rst, .csv, .xlsx, .docx, .md, .html, .epub, .json, .pptx, .zip hoch (max. 25 MB)",
|
||||
"uploadedFiles": "Hochgeladene Dateien",
|
||||
"cancel": "Abbrechen",
|
||||
"train": "Trainieren",
|
||||
"link": "Link",
|
||||
"urlLink": "URL-Link",
|
||||
"repoUrl": "Repository-URL",
|
||||
"reddit": {
|
||||
"id": "Client-ID",
|
||||
"secret": "Client-Secret",
|
||||
"agent": "User-Agent",
|
||||
"searchQueries": "Suchanfragen",
|
||||
"numberOfPosts": "Anzahl der Beiträge",
|
||||
"addQuery": "Anfrage hinzufügen"
|
||||
},
|
||||
"drag": {
|
||||
"title": "Anhänge hier ablegen",
|
||||
"description": "Loslassen, um deine Anhänge hochzuladen"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Upload läuft",
|
||||
"training": "Upload läuft",
|
||||
"completed": "Upload abgeschlossen",
|
||||
"failed": "Upload fehlgeschlagen",
|
||||
"wait": "Dies kann einige Minuten dauern",
|
||||
"preparing": "Upload wird vorbereitet",
|
||||
"tokenLimit": "Token-Limit überschritten, bitte lade ein kleineres Dokument hoch",
|
||||
"expandDetails": "Upload-Details erweitern",
|
||||
"collapseDetails": "Upload-Details einklappen",
|
||||
"dismiss": "Upload-Benachrichtigung schließen",
|
||||
"uploadProgress": "Upload-Fortschritt {{progress}}%",
|
||||
"clear": "Löschen"
|
||||
},
|
||||
"showAdvanced": "Erweiterte Optionen anzeigen",
|
||||
"hideAdvanced": "Erweiterte Optionen ausblenden",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Datei hochladen",
|
||||
"heading": "Neues Dokument hochladen"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Crawler",
|
||||
"heading": "Inhalt mit Web-Crawler hinzufügen"
|
||||
},
|
||||
"url": {
|
||||
"label": "Link",
|
||||
"heading": "Inhalt von URL hinzufügen"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Inhalt von GitHub hinzufügen"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Inhalt von Reddit hinzufügen"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Von Google Drive hochladen"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Verbundener Benutzer",
|
||||
"authFailed": "Authentifizierung fehlgeschlagen",
|
||||
"authUrlFailed": "Autorisierungs-URL konnte nicht abgerufen werden",
|
||||
"popupBlocked": "Authentifizierungsfenster konnte nicht geöffnet werden. Bitte erlaube Popups.",
|
||||
"authCancelled": "Authentifizierung wurde abgebrochen",
|
||||
"connectedAs": "Verbunden als {{email}}",
|
||||
"disconnect": "Trennen"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Mit Google Drive verbinden",
|
||||
"sessionExpired": "Sitzung abgelaufen. Bitte verbinde dich erneut mit Google Drive.",
|
||||
"sessionExpiredGeneric": "Sitzung abgelaufen. Bitte verbinde dein Konto erneut.",
|
||||
"validateFailed": "Sitzung konnte nicht validiert werden. Bitte verbinde dich erneut.",
|
||||
"noSession": "Keine gültige Sitzung gefunden. Bitte verbinde dich erneut mit Google Drive.",
|
||||
"noAccessToken": "Kein Zugriffstoken verfügbar. Bitte verbinde dich erneut mit Google Drive.",
|
||||
"pickerFailed": "Dateiauswahl konnte nicht geöffnet werden. Bitte versuche es erneut.",
|
||||
"selectedFiles": "Ausgewählte Dateien",
|
||||
"selectFiles": "Dateien auswählen",
|
||||
"loading": "Laden...",
|
||||
"noFilesSelected": "Keine Dateien oder Ordner ausgewählt",
|
||||
"folders": "Ordner",
|
||||
"files": "Dateien",
|
||||
"remove": "Entfernen",
|
||||
"folderAlt": "Ordner",
|
||||
"fileAlt": "Datei"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Neuen API-Schlüssel erstellen",
|
||||
"apiKeyName": "API-Schlüssel-Name",
|
||||
"chunks": "Chunks pro Anfrage",
|
||||
"prompt": "Aktiven Prompt auswählen",
|
||||
"sourceDoc": "Quelldokument",
|
||||
"create": "Erstellen"
|
||||
},
|
||||
"saveKey": {
|
||||
"note": "Bitte speichere deinen Schlüssel",
|
||||
"disclaimer": "Dies ist das einzige Mal, dass dein Schlüssel angezeigt wird.",
|
||||
"copy": "Kopieren",
|
||||
"copied": "Kopiert",
|
||||
"confirm": "Ich habe den Schlüssel gespeichert",
|
||||
"apiKeyLabel": "API-Schlüssel"
|
||||
},
|
||||
"deleteConv": {
|
||||
"confirm": "Bist du sicher, dass du alle Konversationen löschen möchtest?",
|
||||
"delete": "Löschen"
|
||||
},
|
||||
"shareConv": {
|
||||
"label": "Öffentliche Seite zum Teilen erstellen",
|
||||
"note": "Quelldokument, persönliche Informationen und weitere Konversationen bleiben privat",
|
||||
"create": "Erstellen",
|
||||
"option": "Benutzern weitere Eingaben erlauben"
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Werkzeug-Konfiguration",
|
||||
"type": "Typ",
|
||||
"apiKeyLabel": "API-Schlüssel / OAuth",
|
||||
"apiKeyPlaceholder": "API-Schlüssel / OAuth eingeben",
|
||||
"addButton": "Werkzeug hinzufügen",
|
||||
"closeButton": "Schließen",
|
||||
"customNamePlaceholder": "Benutzerdefinierten Namen eingeben (optional)"
|
||||
},
|
||||
"prompts": {
|
||||
"addPrompt": "Prompt hinzufügen",
|
||||
"addDescription": "Füge deinen benutzerdefinierten Prompt hinzu und speichere ihn in DocsGPT",
|
||||
"editPrompt": "Prompt bearbeiten",
|
||||
"editDescription": "Bearbeite deinen benutzerdefinierten Prompt und speichere ihn in DocsGPT",
|
||||
"promptName": "Prompt-Name",
|
||||
"promptText": "Prompt-Text",
|
||||
"save": "Speichern",
|
||||
"cancel": "Abbrechen",
|
||||
"nameExists": "Name existiert bereits",
|
||||
"deleteConfirmation": "Bist du sicher, dass du den Prompt '{{name}}' löschen möchtest?",
|
||||
"placeholderText": "Gib hier deinen Prompt-Text ein...",
|
||||
"addExamplePlaceholder": "Bitte fasse diesen Text zusammen:",
|
||||
"variablesLabel": "Variablen",
|
||||
"variablesSubtext": "Klicken zum Einfügen in den Prompt",
|
||||
"variablesDescription": "Klicken zum Einfügen in den Prompt",
|
||||
"systemVariables": "Klicken zum Einfügen in den Prompt",
|
||||
"toolVariables": "Werkzeug-Variablen",
|
||||
"systemVariablesDropdownLabel": "System-Variablen",
|
||||
"systemVariableOptions": {
|
||||
"sourceContent": "Quelleninhalte",
|
||||
"sourceSummaries": "Alias für Inhalte (abwärtskompatibel)",
|
||||
"sourceDocuments": "Dokumentenobjekte-Liste",
|
||||
"sourceCount": "Anzahl der abgerufenen Dokumente",
|
||||
"systemDate": "Aktuelles Datum (JJJJ-MM-TT)",
|
||||
"systemTime": "Aktuelle Uhrzeit (HH:MM:SS)",
|
||||
"systemTimestamp": "ISO 8601 Zeitstempel",
|
||||
"systemRequestId": "Eindeutige Anfrage-ID",
|
||||
"systemUserId": "Aktuelle Benutzer-ID"
|
||||
},
|
||||
"learnAboutPrompts": "Mehr über Prompts erfahren →",
|
||||
"publicPromptEditDisabled": "Öffentliche Prompts können nicht bearbeitet werden",
|
||||
"promptTypePublic": "öffentlich",
|
||||
"promptTypePrivate": "privat"
|
||||
},
|
||||
"chunk": {
|
||||
"add": "Chunk hinzufügen",
|
||||
"edit": "Bearbeiten",
|
||||
"title": "Titel",
|
||||
"enterTitle": "Titel eingeben",
|
||||
"bodyText": "Textkörper",
|
||||
"promptText": "Prompt-Text",
|
||||
"save": "Speichern",
|
||||
"close": "Schließen",
|
||||
"cancel": "Abbrechen",
|
||||
"delete": "Löschen",
|
||||
"deleteConfirmation": "Bist du sicher, dass du diesen Chunk löschen möchtest?"
|
||||
},
|
||||
"addAction": {
|
||||
"title": "Neue Aktion",
|
||||
"actionNamePlaceholder": "Aktionsname",
|
||||
"invalidFormat": "Ungültiges Funktionsnamenformat. Verwende nur Buchstaben, Zahlen, Unterstriche und Bindestriche.",
|
||||
"formatHelp": "Verwende nur Buchstaben, Zahlen, Unterstriche und Bindestriche (z.B. `get_data`, `send_report`, etc.)",
|
||||
"addButton": "Hinzufügen"
|
||||
},
|
||||
"agentDetails": {
|
||||
"title": "Zugangsdaten",
|
||||
"publicLink": "Öffentlicher Link",
|
||||
"apiKey": "API-Schlüssel",
|
||||
"webhookUrl": "Webhook-URL",
|
||||
"generate": "Generieren",
|
||||
"test": "Testen",
|
||||
"learnMore": "Mehr erfahren"
|
||||
}
|
||||
},
|
||||
"sharedConv": {
|
||||
"subtitle": "Erstellt mit",
|
||||
"button": "Mit DocsGPT starten",
|
||||
"meta": "DocsGPT verwendet GenAI, bitte überprüfe kritische Informationen anhand der Quellen."
|
||||
},
|
||||
"convTile": {
|
||||
"share": "Teilen",
|
||||
"delete": "Löschen",
|
||||
"rename": "Umbenennen",
|
||||
"deleteWarning": "Bist du sicher, dass du diese Konversation löschen möchtest?"
|
||||
},
|
||||
"pagination": {
|
||||
"rowsPerPage": "Zeilen pro Seite",
|
||||
"pageOf": "Seite {{currentPage}} von {{totalPages}}",
|
||||
"firstPage": "Erste Seite",
|
||||
"previousPage": "Vorherige Seite",
|
||||
"nextPage": "Nächste Seite",
|
||||
"lastPage": "Letzte Seite"
|
||||
},
|
||||
"conversation": {
|
||||
"copy": "Kopieren",
|
||||
"copied": "Kopiert",
|
||||
"speak": "Vorlesen",
|
||||
"answer": "Antwort",
|
||||
"edit": {
|
||||
"update": "Aktualisieren",
|
||||
"cancel": "Abbrechen",
|
||||
"placeholder": "Aktualisierte Anfrage eingeben..."
|
||||
},
|
||||
"sources": {
|
||||
"title": "Quellen",
|
||||
"text": "Wähle deine Quellen",
|
||||
"link": "Quellen-Link",
|
||||
"view_more": "{{count}} weitere Quellen",
|
||||
"noSourcesAvailable": "Keine Quellen verfügbar"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Anhängen",
|
||||
"remove": "Anhang entfernen"
|
||||
},
|
||||
"retry": "Erneut versuchen",
|
||||
"reasoning": "Begründung"
|
||||
},
|
||||
"agents": {
|
||||
"title": "Agenten",
|
||||
"description": "Entdecke und erstelle benutzerdefinierte Versionen von DocsGPT, die Anweisungen, zusätzliches Wissen und beliebige Kombinationen von Fähigkeiten kombinieren",
|
||||
"newAgent": "Neuer Agent",
|
||||
"backToAll": "Zurück zu allen Agenten",
|
||||
"sections": {
|
||||
"template": {
|
||||
"title": "Von DocsGPT",
|
||||
"description": "Von DocsGPT bereitgestellte Agenten",
|
||||
"emptyState": "Keine Vorlagen-Agenten gefunden."
|
||||
},
|
||||
"user": {
|
||||
"title": "Von mir",
|
||||
"description": "Von dir erstellte oder veröffentlichte Agenten",
|
||||
"emptyState": "Du hast noch keine Agenten erstellt."
|
||||
},
|
||||
"shared": {
|
||||
"title": "Mit mir geteilt",
|
||||
"description": "Über einen öffentlichen Link importierte Agenten",
|
||||
"emptyState": "Keine geteilten Agenten gefunden."
|
||||
}
|
||||
},
|
||||
"form": {
|
||||
"headings": {
|
||||
"new": "Neuer Agent",
|
||||
"edit": "Agent bearbeiten",
|
||||
"draft": "Neuer Agent (Entwurf)"
|
||||
},
|
||||
"buttons": {
|
||||
"publish": "Veröffentlichen",
|
||||
"save": "Speichern",
|
||||
"saveDraft": "Entwurf speichern",
|
||||
"cancel": "Abbrechen",
|
||||
"delete": "Löschen",
|
||||
"logs": "Protokolle",
|
||||
"accessDetails": "Zugangsdaten",
|
||||
"add": "Hinzufügen"
|
||||
},
|
||||
"sections": {
|
||||
"meta": "Meta",
|
||||
"source": "Quelle",
|
||||
"prompt": "Prompt",
|
||||
"tools": "Werkzeuge",
|
||||
"agentType": "Agententyp",
|
||||
"models": "Modelle",
|
||||
"advanced": "Erweitert",
|
||||
"preview": "Vorschau"
|
||||
},
|
||||
"placeholders": {
|
||||
"agentName": "Agentenname",
|
||||
"describeAgent": "Beschreibe deinen Agenten",
|
||||
"selectSources": "Quellen auswählen",
|
||||
"chunksPerQuery": "Chunks pro Anfrage",
|
||||
"selectType": "Typ auswählen",
|
||||
"selectTools": "Werkzeuge auswählen",
|
||||
"selectModels": "Modelle für diesen Agenten auswählen",
|
||||
"selectDefaultModel": "Standardmodell auswählen",
|
||||
"enterTokenLimit": "Token-Limit eingeben",
|
||||
"enterRequestLimit": "Anfrage-Limit eingeben"
|
||||
},
|
||||
"sourcePopup": {
|
||||
"title": "Quellen auswählen",
|
||||
"searchPlaceholder": "Quellen suchen...",
|
||||
"noOptionsMessage": "Keine Quellen verfügbar"
|
||||
},
|
||||
"toolsPopup": {
|
||||
"title": "Werkzeuge auswählen",
|
||||
"searchPlaceholder": "Werkzeuge suchen...",
|
||||
"noOptionsMessage": "Keine Werkzeuge verfügbar"
|
||||
},
|
||||
"modelsPopup": {
|
||||
"title": "Modelle auswählen",
|
||||
"searchPlaceholder": "Modelle suchen...",
|
||||
"noOptionsMessage": "Keine Modelle verfügbar"
|
||||
},
|
||||
"upload": {
|
||||
"clickToUpload": "Klicken zum Hochladen",
|
||||
"dragAndDrop": " oder per Drag & Drop"
|
||||
},
|
||||
"agentTypes": {
|
||||
"classic": "Klassisch",
|
||||
"react": "ReAct"
|
||||
},
|
||||
"labels": {
|
||||
"defaultModel": "Standardmodell"
|
||||
},
|
||||
"advanced": {
|
||||
"jsonSchema": "JSON-Antwortschema",
|
||||
"jsonSchemaDescription": "Definiere ein JSON-Schema, um ein strukturiertes Ausgabeformat zu erzwingen",
|
||||
"validJson": "Gültiges JSON",
|
||||
"invalidJson": "Ungültiges JSON - zur Aktivierung des Speicherns beheben",
|
||||
"tokenLimiting": "Token-Limitierung",
|
||||
"tokenLimitingDescription": "Begrenze die täglich von diesem Agenten verwendbaren Tokens",
|
||||
"requestLimiting": "Anfrage-Limitierung",
|
||||
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Veröffentlichte Agenten können hier in der Vorschau angezeigt werden"
|
||||
},
|
||||
"externalKb": "Externe KB"
|
||||
},
|
||||
"logs": {
|
||||
"title": "Agenten-Protokolle",
|
||||
"lastUsedAt": "Zuletzt verwendet am",
|
||||
"noUsageHistory": "Kein Nutzungsverlauf",
|
||||
"tableHeader": "Agenten-Endpunkt-Protokolle"
|
||||
},
|
||||
"shared": {
|
||||
"notFound": "Kein Agent gefunden. Bitte stelle sicher, dass der Agent geteilt ist."
|
||||
},
|
||||
"preview": {
|
||||
"testMessage": "Teste deinen Agenten hier. Veröffentlichte Agenten können in Konversationen verwendet werden."
|
||||
},
|
||||
"deleteConfirmation": "Bist du sicher, dass du diesen Agenten löschen möchtest?"
|
||||
},
|
||||
"components": {
|
||||
"fileUpload": {
|
||||
"clickToUpload": "Klicken zum Hochladen oder per Drag & Drop",
|
||||
"dropFiles": "Dateien hier ablegen",
|
||||
"fileTypes": "PNG, JPG, JPEG bis zu",
|
||||
"sizeLimitUnit": "MB",
|
||||
"fileSizeError": "Datei überschreitet {{size}}MB-Limit"
|
||||
}
|
||||
},
|
||||
"pageNotFound": {
|
||||
"title": "404",
|
||||
"message": "Die gesuchte Seite existiert nicht.",
|
||||
"goHome": "Zur Startseite"
|
||||
},
|
||||
"filePicker": {
|
||||
"searchPlaceholder": "Dateien und Ordner suchen...",
|
||||
"itemsSelected": "{{count}} ausgewählt",
|
||||
"name": "Name",
|
||||
"lastModified": "Zuletzt geändert",
|
||||
"size": "Größe"
|
||||
},
|
||||
"actionButtons": {
|
||||
"openNewChat": "Neuen Chat öffnen",
|
||||
"share": "Teilen"
|
||||
},
|
||||
"mermaid": {
|
||||
"downloadOptions": "Download-Optionen",
|
||||
"viewCode": "Code anzeigen",
|
||||
"decreaseZoom": "Verkleinern",
|
||||
"resetZoom": "Zoom zurücksetzen",
|
||||
"increaseZoom": "Vergrößern"
|
||||
},
|
||||
"navigation": {
|
||||
"agents": "Agenten"
|
||||
},
|
||||
"notification": {
|
||||
"ariaLabel": "Benachrichtigung",
|
||||
"closeAriaLabel": "Benachrichtigung schließen"
|
||||
},
|
||||
"prompts": {
|
||||
"textAriaLabel": "Prompt-Text"
|
||||
}
|
||||
}
|
||||
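The {{toolName}}, {{progress}}, and similar placeholders in the locale file above are i18next interpolation slots. A minimal usage sketch follows; the settings.tools key prefix is an assumption, since the translation subtree's position inside de.json is not shown here.

import { useTranslation } from 'react-i18next';

// Inside a React component; the key path below is assumed, not confirmed.
const { t } = useTranslation();
const label = t('settings.tools.configureToolAria', { toolName: 'Brave Search' });
// -> "Brave Search konfigurieren" while the active language is German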
@@ -530,6 +530,7 @@
"prompt": "Prompt",
"tools": "Tools",
"agentType": "Agent type",
"models": "Models",
"advanced": "Advanced",
"preview": "Preview"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "Chunks per query",
"selectType": "Select type",
"selectTools": "Select tools",
"selectModels": "Select models for this agent",
"selectDefaultModel": "Select default model",
"enterTokenLimit": "Enter token limit",
"enterRequestLimit": "Enter request limit"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "Search tools...",
"noOptionsMessage": "No tools available"
},
"modelsPopup": {
"title": "Select Models",
"searchPlaceholder": "Search models...",
"noOptionsMessage": "No models available"
},
"upload": {
"clickToUpload": "Click to upload",
"dragAndDrop": " or drag and drop"
@@ -561,6 +569,9 @@
"classic": "Classic",
"react": "ReAct"
},
"labels": {
"defaultModel": "Default Model"
},
"advanced": {
"jsonSchema": "JSON response schema",
"jsonSchemaDescription": "Define a JSON schema to enforce structured output format",
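The new "JSON response schema" option implies agents can be pinned to structured output. A hedged illustration of the kind of schema a user might paste into that field; the field names are hypothetical, and the exact schema dialect DocsGPT accepts is not shown in this diff.

// Hypothetical response schema; `answer`, `confidence`, and `sources` are
// illustrative field names, not part of the diff.
const responseSchema = {
  type: 'object',
  properties: {
    answer: { type: 'string' },
    confidence: { type: 'number', minimum: 0, maximum: 1 },
    sources: { type: 'array', items: { type: 'string' } },
  },
  required: ['answer'],
} as const;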
@@ -530,6 +530,7 @@
"prompt": "Prompt",
"tools": "Herramientas",
"agentType": "Tipo de agente",
"models": "Modelos",
"advanced": "Avanzado",
"preview": "Vista previa"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "Fragmentos por consulta",
"selectType": "Seleccionar tipo",
"selectTools": "Seleccionar herramientas",
"selectModels": "Seleccionar modelos para este agente",
"selectDefaultModel": "Seleccionar modelo predeterminado",
"enterTokenLimit": "Ingresar límite de tokens",
"enterRequestLimit": "Ingresar límite de solicitudes"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "Buscar herramientas...",
"noOptionsMessage": "No hay herramientas disponibles"
},
"modelsPopup": {
"title": "Seleccionar Modelos",
"searchPlaceholder": "Buscar modelos...",
"noOptionsMessage": "No hay modelos disponibles"
},
"upload": {
"clickToUpload": "Haz clic para subir",
"dragAndDrop": " o arrastra y suelta"
@@ -561,6 +569,9 @@
"classic": "Clásico",
"react": "ReAct"
},
"labels": {
"defaultModel": "Modelo Predeterminado"
},
"advanced": {
"jsonSchema": "Esquema de respuesta JSON",
"jsonSchemaDescription": "Define un esquema JSON para aplicar formato de salida estructurado",

@@ -8,6 +8,7 @@ import jp from './jp.json'; //Japanese
import zh from './zh.json'; //Mandarin
import zhTW from './zh-TW.json'; //Traditional Chinese
import ru from './ru.json'; //Russian
import de from './de.json'; //German

i18n
  .use(LanguageDetector)
@@ -32,6 +33,9 @@ i18n
    ru: {
      translation: ru,
    },
    de: {
      translation: de,
    },
  },
  fallbackLng: 'en',
  detection: {
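With de registered in the resources map, switching languages at runtime is a single i18next call; LanguageDetector otherwise picks the initial language from the browser. A small sketch, with an assumed import path:

import i18n from './locale'; // path assumed

// Whether the choice survives reloads depends on the detector's `caches` config.
void i18n.changeLanguage('de');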
@@ -530,6 +530,7 @@
"prompt": "プロンプト",
"tools": "ツール",
"agentType": "エージェントタイプ",
"models": "モデル",
"advanced": "詳細設定",
"preview": "プレビュー"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "クエリごとのチャンク数",
"selectType": "タイプを選択",
"selectTools": "ツールを選択",
"selectModels": "このエージェントのモデルを選択",
"selectDefaultModel": "デフォルトモデルを選択",
"enterTokenLimit": "トークン制限を入力",
"enterRequestLimit": "リクエスト制限を入力"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "ツールを検索...",
"noOptionsMessage": "利用可能なツールがありません"
},
"modelsPopup": {
"title": "モデルを選択",
"searchPlaceholder": "モデルを検索...",
"noOptionsMessage": "利用可能なモデルがありません"
},
"upload": {
"clickToUpload": "クリックしてアップロード",
"dragAndDrop": " またはドラッグ&ドロップ"
@@ -561,6 +569,9 @@
"classic": "クラシック",
"react": "ReAct"
},
"labels": {
"defaultModel": "デフォルトモデル"
},
"advanced": {
"jsonSchema": "JSON応答スキーマ",
"jsonSchemaDescription": "構造化された出力形式を適用するためのJSONスキーマを定義します",

@@ -530,6 +530,7 @@
"prompt": "Промпт",
"tools": "Инструменты",
"agentType": "Тип агента",
"models": "Модели",
"advanced": "Расширенные",
"preview": "Предпросмотр"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "Фрагментов на запрос",
"selectType": "Выберите тип",
"selectTools": "Выберите инструменты",
"selectModels": "Выберите модели для этого агента",
"selectDefaultModel": "Выберите модель по умолчанию",
"enterTokenLimit": "Введите лимит токенов",
"enterRequestLimit": "Введите лимит запросов"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "Поиск инструментов...",
"noOptionsMessage": "Нет доступных инструментов"
},
"modelsPopup": {
"title": "Выберите Модели",
"searchPlaceholder": "Поиск моделей...",
"noOptionsMessage": "Нет доступных моделей"
},
"upload": {
"clickToUpload": "Нажмите для загрузки",
"dragAndDrop": " или перетащите"
@@ -561,6 +569,9 @@
"classic": "Классический",
"react": "ReAct"
},
"labels": {
"defaultModel": "Модель по умолчанию"
},
"advanced": {
"jsonSchema": "Схема ответа JSON",
"jsonSchemaDescription": "Определите схему JSON для применения структурированного формата вывода",

@@ -530,6 +530,7 @@
"prompt": "提示詞",
"tools": "工具",
"agentType": "代理類型",
"models": "模型",
"advanced": "進階",
"preview": "預覽"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "每次查詢的區塊數",
"selectType": "選擇類型",
"selectTools": "選擇工具",
"selectModels": "為此代理選擇模型",
"selectDefaultModel": "選擇預設模型",
"enterTokenLimit": "輸入權杖限制",
"enterRequestLimit": "輸入請求限制"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "搜尋工具...",
"noOptionsMessage": "沒有可用的工具"
},
"modelsPopup": {
"title": "選擇模型",
"searchPlaceholder": "搜尋模型...",
"noOptionsMessage": "沒有可用的模型"
},
"upload": {
"clickToUpload": "點擊上傳",
"dragAndDrop": " 或拖放"
@@ -561,6 +569,9 @@
"classic": "經典",
"react": "ReAct"
},
"labels": {
"defaultModel": "預設模型"
},
"advanced": {
"jsonSchema": "JSON回應架構",
"jsonSchemaDescription": "定義JSON架構以強制執行結構化輸出格式",

@@ -530,6 +530,7 @@
"prompt": "提示词",
"tools": "工具",
"agentType": "代理类型",
"models": "模型",
"advanced": "高级",
"preview": "预览"
},
@@ -540,6 +541,8 @@
"chunksPerQuery": "每次查询的块数",
"selectType": "选择类型",
"selectTools": "选择工具",
"selectModels": "为此代理选择模型",
"selectDefaultModel": "选择默认模型",
"enterTokenLimit": "输入令牌限制",
"enterRequestLimit": "输入请求限制"
},
@@ -553,6 +556,11 @@
"searchPlaceholder": "搜索工具...",
"noOptionsMessage": "没有可用的工具"
},
"modelsPopup": {
"title": "选择模型",
"searchPlaceholder": "搜索模型...",
"noOptionsMessage": "没有可用的模型"
},
"upload": {
"clickToUpload": "点击上传",
"dragAndDrop": " 或拖放"
@@ -561,6 +569,9 @@
"classic": "经典",
"react": "ReAct"
},
"labels": {
"defaultModel": "默认模型"
},
"advanced": {
"jsonSchema": "JSON响应架构",
"jsonSchemaDescription": "定义JSON架构以强制执行结构化输出格式",

frontend/src/models/types.ts (Normal file, 25 lines)
@@ -0,0 +1,25 @@
export interface AvailableModel {
  id: string;
  provider: string;
  display_name: string;
  description?: string;
  context_window: number;
  supported_attachment_types: string[];
  supports_tools: boolean;
  supports_structured_output: boolean;
  supports_streaming: boolean;
  enabled: boolean;
}

export interface Model {
  id: string;
  value: string;
  provider: string;
  display_name: string;
  description?: string;
  context_window: number;
  supported_attachment_types: string[];
  supports_tools: boolean;
  supports_structured_output: boolean;
  supports_streaming: boolean;
}
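The two interfaces differ only in that Model adds a value field and drops enabled, which suggests AvailableModel is the API-side shape and Model the UI-side one. The diff does not show the conversion, so the adapter below is a hypothetical sketch; reusing the model id for value gives dropdown options a stable key.

import type { AvailableModel, Model } from './models/types';

// Hypothetical adapter, not the repo's code: strip the API-only `enabled`
// flag and derive `value` from the id.
function toModel(available: AvailableModel): Model {
  const { enabled: _enabled, ...rest } = available;
  return { ...rest, value: available.id };
}

// e.g. const options = apiModels.filter((m) => m.enabled).map(toModel);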
@@ -191,7 +191,7 @@ const useToolVariables = () => {
        }
        filteredActions.push({
          label: `${action.name} (${tool.displayName || tool.name})`,
          value: `tools.${toolIdentifier}.${action.name}`,
          value: `tools['${toolIdentifier}'].${action.name}`,
        });
      }
    }
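This one-line change is the tool-name fix the branch is named for: a dot-path template like tools.${toolIdentifier}.${action.name} becomes ambiguous once a tool identifier itself contains a dot, while the bracket-quoted form keeps the identifier atomic. A standalone sketch of the failure mode; the 'brave.search' identifier and both resolver functions are hypothetical, not DocsGPT's actual variable resolver.

type ToolActions = Record<string, Record<string, unknown>>;

const tools: ToolActions = {
  'brave.search': { web_search: { endpoint: '/search' } },
};

// Naive dot-path resolution splits "tools.brave.search.web_search" into four
// segments, so the composite key "brave.search" is never matched.
function resolveDotPath(root: Record<string, any>, path: string): unknown {
  return path.split('.').reduce<any>((node, key) => node?.[key], root);
}

console.log(resolveDotPath({ tools }, 'tools.brave.search.web_search')); // undefined

// The bracket-quoted form keeps the identifier atomic, so a single pattern
// can recover it even when it contains dots.
const BRACKET_PATH = /^tools\['([^']+)'\]\.([\w-]+)$/;

function resolveBracketPath(root: ToolActions, path: string): unknown {
  const match = BRACKET_PATH.exec(path);
  return match ? root[match[1]]?.[match[2]] : undefined;
}

console.log(resolveBracketPath(tools, "tools['brave.search'].web_search")); // { endpoint: '/search' }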
@@ -9,11 +9,12 @@ import { Agent } from '../agents/types';
import { ActiveState, Doc } from '../models/misc';
import { RootState } from '../store';
import {
  getLocalRecentDocs,
  setLocalApiKey,
  setLocalRecentDocs,
  getLocalRecentDocs,
} from './preferenceApi';

import type { Model } from '../models/types';
export interface Preference {
  apiKey: string;
  prompt: { name: string; id: string; type: string };
@@ -32,6 +33,9 @@ export interface Preference {
  agents: Agent[] | null;
  sharedAgents: Agent[] | null;
  selectedAgent: Agent | null;
  selectedModel: Model | null;
  availableModels: Model[];
  modelsLoading: boolean;
}

const initialState: Preference = {
@@ -61,6 +65,9 @@ const initialState: Preference = {
  agents: null,
  sharedAgents: null,
  selectedAgent: null,
  selectedModel: null,
  availableModels: [],
  modelsLoading: false,
};

export const prefSlice = createSlice({
@@ -109,6 +116,15 @@ export const prefSlice = createSlice({
    setSelectedAgent: (state, action) => {
      state.selectedAgent = action.payload;
    },
    setSelectedModel: (state, action: PayloadAction<Model | null>) => {
      state.selectedModel = action.payload;
    },
    setAvailableModels: (state, action: PayloadAction<Model[]>) => {
      state.availableModels = action.payload;
    },
    setModelsLoading: (state, action: PayloadAction<boolean>) => {
      state.modelsLoading = action.payload;
    },
  },
});

@@ -127,6 +143,9 @@ export const {
  setAgents,
  setSharedAgents,
  setSelectedAgent,
  setSelectedModel,
  setAvailableModels,
  setModelsLoading,
} = prefSlice.actions;
export default prefSlice.reducer;

@@ -198,6 +217,19 @@ prefListenerMiddleware.startListening({
  },
});

prefListenerMiddleware.startListening({
  matcher: isAnyOf(setSelectedModel),
  effect: (action, listenerApi) => {
    const model = (listenerApi.getState() as RootState).preference
      .selectedModel;
    if (model) {
      localStorage.setItem('DocsGPTSelectedModel', JSON.stringify(model));
    } else {
      localStorage.removeItem('DocsGPTSelectedModel');
    }
  },
});

export const selectApiKey = (state: RootState) => state.preference.apiKey;
export const selectApiKeyStatus = (state: RootState) =>
  !!state.preference.apiKey;
@@ -227,3 +259,9 @@ export const selectSharedAgents = (state: RootState) =>
  state.preference.sharedAgents;
export const selectSelectedAgent = (state: RootState) =>
  state.preference.selectedAgent;
export const selectSelectedModel = (state: RootState) =>
  state.preference.selectedModel;
export const selectAvailableModels = (state: RootState) =>
  state.preference.availableModels;
export const selectModelsLoading = (state: RootState) =>
  state.preference.modelsLoading;
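Typical component-side consumption of the new selectors; a sketch with an assumed import path and a hypothetical component.

import { useSelector } from 'react-redux';
// Import path assumed; the selectors come from the slice above.
import {
  selectSelectedModel,
  selectModelsLoading,
} from './preferences/preferenceSlice';

function ModelBadge() {
  const selectedModel = useSelector(selectSelectedModel);
  const modelsLoading = useSelector(selectModelsLoading);
  if (modelsLoading) return null;
  return <span>{selectedModel?.display_name ?? 'No model selected'}</span>;
}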
@@ -30,6 +30,7 @@ export default function General() {

  const languageOptions = [
    { label: 'English', value: 'en' },
    { label: 'Deutsch', value: 'de' },
    { label: 'Español', value: 'es' },
    { label: '日本語', value: 'jp' },
    { label: '普通话', value: 'zh' },

@@ -15,6 +15,7 @@ const prompt = localStorage.getItem('DocsGPTPrompt');
const chunks = localStorage.getItem('DocsGPTChunks');
const token_limit = localStorage.getItem('DocsGPTTokenLimit');
const doc = localStorage.getItem('DocsGPTRecentDocs');
const selectedModel = localStorage.getItem('DocsGPTSelectedModel');

const preloadedState: { preference: Preference } = {
  preference: {
@@ -47,6 +48,9 @@ const preloadedState: { preference: Preference } = {
    agents: null,
    sharedAgents: null,
    selectedAgent: null,
    selectedModel: selectedModel ? JSON.parse(selectedModel) : null,
    availableModels: [],
    modelsLoading: false,
  },
};
const store = configureStore({
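The preloaded state parses DocsGPTSelectedModel straight out of localStorage, so a corrupt entry would make JSON.parse throw during store creation. A defensive variant; this is a sketch, not the repo's code.

import type { Model } from './models/types';

// Falls back to null on a corrupt entry and clears the bad value so the
// next load starts clean.
function loadSelectedModel(): Model | null {
  try {
    const raw = localStorage.getItem('DocsGPTSelectedModel');
    return raw ? (JSON.parse(raw) as Model) : null;
  } catch {
    localStorage.removeItem('DocsGPTSelectedModel');
    return null;
  }
}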
scripts/migrate_conversation_id_dbref_to_objectid.py (Normal file, 114 lines)
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Migration script to convert conversation_id from DBRef to ObjectId in shared_conversations collection.
"""

import pymongo
import logging
from tqdm import tqdm
from bson.dbref import DBRef
from bson.objectid import ObjectId

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

# Configuration
MONGO_URI = "mongodb://localhost:27017/"
DB_NAME = "docsgpt"

def backup_collection(collection, backup_collection_name):
    """Backup collection before migration."""
    logger.info(f"Backing up collection {collection.name} to {backup_collection_name}")
    collection.aggregate([{"$out": backup_collection_name}])
    logger.info("Backup completed")

def migrate_conversation_id_dbref_to_objectid():
    """Migrate conversation_id from DBRef to ObjectId."""
    client = pymongo.MongoClient(MONGO_URI)
    db = client[DB_NAME]
    shared_conversations_collection = db["shared_conversations"]

    try:
        # Backup collection before migration
        backup_collection(shared_conversations_collection, "shared_conversations_backup")

        # Find all documents and filter for DBRef conversation_id in Python
        all_documents = list(shared_conversations_collection.find({}))
        documents_with_dbref = []

        for doc in all_documents:
            conversation_id_field = doc.get("conversation_id")
            if isinstance(conversation_id_field, DBRef):
                documents_with_dbref.append(doc)

        if not documents_with_dbref:
            logger.info("No documents with DBRef conversation_id found. Migration not needed.")
            return

        logger.info(f"Found {len(documents_with_dbref)} documents with DBRef conversation_id")

        # Process each document
        migrated_count = 0
        error_count = 0

        for doc in tqdm(documents_with_dbref, desc="Migrating conversation_id"):
            try:
                conversation_id_field = doc.get("conversation_id")

                # Extract the ObjectId from the DBRef
                dbref_id = conversation_id_field.id

                if dbref_id and ObjectId.is_valid(dbref_id):
                    # Update the document to use direct ObjectId
                    result = shared_conversations_collection.update_one(
                        {"_id": doc["_id"]},
                        {"$set": {"conversation_id": dbref_id}}
                    )

                    if result.modified_count > 0:
                        migrated_count += 1
                        logger.debug(f"Successfully migrated document {doc['_id']}")
                    else:
                        error_count += 1
                        logger.warning(f"Failed to update document {doc['_id']}")
                else:
                    error_count += 1
                    logger.warning(f"Invalid ObjectId in DBRef for document {doc['_id']}: {dbref_id}")

            except Exception as e:
                error_count += 1
                logger.error(f"Error migrating document {doc['_id']}: {e}")

        # Final verification
        all_docs_after = list(shared_conversations_collection.find({}))
        remaining_dbref = 0
        for doc in all_docs_after:
            if isinstance(doc.get("conversation_id"), DBRef):
                remaining_dbref += 1

        logger.info("Migration completed:")
        logger.info(f"  - Total documents processed: {len(documents_with_dbref)}")
        logger.info(f"  - Successfully migrated: {migrated_count}")
        logger.info(f"  - Errors encountered: {error_count}")
        logger.info(f"  - Remaining DBRef documents: {remaining_dbref}")

        if remaining_dbref == 0:
            logger.info("✅ Migration successful: All DBRef conversation_id fields have been converted to ObjectId")
        else:
            logger.warning(f"⚠️ Migration incomplete: {remaining_dbref} DBRef documents still exist")

    except Exception as e:
        logger.error(f"Migration failed: {e}")
        raise
    finally:
        client.close()

if __name__ == "__main__":
    try:
        logger.info("Starting conversation_id DBRef to ObjectId migration...")
        migrate_conversation_id_dbref_to_objectid()
        logger.info("Migration completed successfully!")
    except Exception as e:
        logger.error(f"Migration failed due to error: {e}")
        logger.warning("Please verify database state or restore from backups if necessary.")
@@ -12,7 +12,7 @@ class TestAgentCreator:
        assert isinstance(agent, ClassicAgent)
        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]
        assert agent.gpt_model == agent_base_params["gpt_model"]
        assert agent.model_id == agent_base_params["model_id"]

    def test_create_react_agent(self, agent_base_params):
        agent = AgentCreator.create_agent("react", **agent_base_params)

@@ -15,7 +15,7 @@ class TestBaseAgentInitialization:

        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]
        assert agent.gpt_model == agent_base_params["gpt_model"]
        assert agent.model_id == agent_base_params["model_id"]
        assert agent.api_key == agent_base_params["api_key"]
        assert agent.prompt == agent_base_params["prompt"]
        assert agent.user == agent_base_params["decoded_token"]["sub"]
@@ -480,7 +480,7 @@ class TestBaseAgentLLMGeneration:

        mock_llm.gen_stream.assert_called_once()
        call_args = mock_llm.gen_stream.call_args[1]
        assert call_args["model"] == agent.gpt_model
        assert call_args["model"] == agent.model_id
        assert call_args["messages"] == messages

    def test_llm_gen_with_tools(

@@ -23,7 +23,7 @@ class TestReActAgent:

        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]
        assert agent.gpt_model == agent_base_params["gpt_model"]
        assert agent.model_id == agent_base_params["model_id"]


@pytest.mark.unit

@@ -274,8 +274,8 @@ class TestGPTModelRetrieval:
        with flask_app.app_context():
            resource = BaseAnswerResource()

            assert hasattr(resource, "gpt_model")
            assert resource.gpt_model is not None
            assert hasattr(resource, "default_model_id")
            assert resource.default_model_id is not None


@pytest.mark.unit
@@ -412,7 +412,7 @@ class TestCompleteStreamMethod:
            resource.complete_stream(
                question="Test?",
                agent=mock_agent,
                conversation_id=None,
                conversation_id=None,
                user_api_key=None,
                decoded_token=decoded_token,
                should_save_conversation=True,
@@ -500,9 +500,10 @@ class TestProcessResponseStream:

        result = resource.process_response_stream(iter(stream))

        assert len(result) == 5
        assert len(result) == 6
        assert result[0] is None
        assert result[4] == "Test error"
        assert result[5] is None

    def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
        from application.api.answer.routes.base import BaseAnswerResource

@@ -108,7 +108,7 @@ class TestConversationServiceSave:
            sources=[],
            tool_calls=[],
            llm=mock_llm,
            gpt_model="gpt-4",
            model_id="gpt-4",
            decoded_token={},  # No 'sub' key
        )

@@ -136,7 +136,7 @@ class TestConversationServiceSave:
            sources=sources,
            tool_calls=[],
            llm=mock_llm,
            gpt_model="gpt-4",
            model_id="gpt-4",
            decoded_token={"sub": "user_123"},
        )

@@ -167,7 +167,7 @@ class TestConversationServiceSave:
            sources=[],
            tool_calls=[],
            llm=mock_llm,
            gpt_model="gpt-4",
            model_id="gpt-4",
            decoded_token={"sub": "user_123"},
        )

@@ -208,7 +208,7 @@ class TestConversationServiceSave:
            sources=[],
            tool_calls=[],
            llm=mock_llm,
            gpt_model="gpt-4",
            model_id="gpt-4",
            decoded_token={"sub": "user_123"},
        )

@@ -237,6 +237,6 @@ class TestConversationServiceSave:
            sources=[],
            tool_calls=[],
            llm=mock_llm,
            gpt_model="gpt-4",
            model_id="gpt-4",
            decoded_token={"sub": "hacker_456"},
        )

@@ -150,7 +150,7 @@ def agent_base_params(decoded_token):
    return {
        "endpoint": "https://api.example.com",
        "llm_name": "openai",
        "gpt_model": "gpt-4",
        "model_id": "gpt-4",
        "api_key": "test_api_key",
        "user_api_key": None,
        "prompt": "You are a helpful assistant.",

@@ -1,11 +1,14 @@
import sys
import types

import pytest


class _FakeCompletion:
    def __init__(self, text):
        self.completion = text


class _FakeCompletions:
    def __init__(self):
        self.last_kwargs = None
@@ -17,6 +20,7 @@ class _FakeCompletions:
        return self._stream
        return _FakeCompletion("final")


class _FakeAnthropic:
    def __init__(self, api_key=None):
        self.api_key = api_key
@@ -29,9 +33,19 @@ def patch_anthropic(monkeypatch):
    fake.Anthropic = _FakeAnthropic
    fake.HUMAN_PROMPT = "<HUMAN>"
    fake.AI_PROMPT = "<AI>"

    modules_to_remove = [key for key in sys.modules if key.startswith("anthropic")]
    for key in modules_to_remove:
        sys.modules.pop(key, None)
    sys.modules["anthropic"] = fake

    if "application.llm.anthropic" in sys.modules:
        del sys.modules["application.llm.anthropic"]
    yield

    sys.modules.pop("anthropic", None)
    if "application.llm.anthropic" in sys.modules:
        del sys.modules["application.llm.anthropic"]


def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
@@ -42,7 +56,9 @@ def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
        {"content": "ctx"},
        {"content": "q"},
    ]
    out = llm._raw_gen(llm, model="claude-2", messages=msgs, stream=False, max_tokens=55)
    out = llm._raw_gen(
        llm, model="claude-2", messages=msgs, stream=False, max_tokens=55
    )
    assert out == "final"
    last = llm.anthropic.completions.last_kwargs
    assert last["model"] == "claude-2"
@@ -59,7 +75,8 @@ def test_anthropic_raw_gen_stream_yields_chunks():
        {"content": "ctx"},
        {"content": "q"},
    ]
    gen = llm._raw_gen_stream(llm, model="claude", messages=msgs, stream=True, max_tokens=10)
    gen = llm._raw_gen_stream(
        llm, model="claude", messages=msgs, stream=True, max_tokens=10
    )
    chunks = list(gen)
    assert chunks == ["s1", "s2"]


@@ -91,7 +91,7 @@ def test_clean_messages_google_basic():
            {"function_call": {"name": "fn", "args": {"a": 1}}},
        ]},
    ]
    cleaned = llm._clean_messages_google(msgs)
    cleaned, system_instruction = llm._clean_messages_google(msgs)

    assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned)
    assert any(c.role == "model" for c in cleaned)
tests/test_agent_token_tracking.py (Normal file, 325 lines)
@@ -0,0 +1,325 @@
import pytest
from unittest.mock import Mock, patch

from application.agents.base import BaseAgent
from application.llm.handlers.base import LLMHandler, ToolCall


class MockAgent(BaseAgent):
    """Mock agent for testing"""

    def _gen_inner(self, query, log_context=None):
        yield {"answer": "test"}


@pytest.fixture
def mock_agent():
    """Create a mock agent for testing"""
    agent = MockAgent(
        endpoint="test",
        llm_name="openai",
        model_id="gpt-4o",
        api_key="test-key",
    )
    agent.llm = Mock()
    return agent


@pytest.fixture
def mock_llm_handler():
    """Create a mock LLM handler"""
    handler = Mock(spec=LLMHandler)
    handler.tool_calls = []
    return handler


class TestAgentTokenTracking:
    """Test suite for agent token tracking during execution"""

    def test_calculate_current_context_tokens(self, mock_agent):
        """Test token calculation for current context"""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you!"},
        ]

        tokens = mock_agent._calculate_current_context_tokens(messages)

        # Should count tokens from all messages
        assert tokens > 0
        # Rough estimate: ~20-40 tokens for this conversation
        assert 15 < tokens < 60

    def test_calculate_tokens_with_tool_calls(self, mock_agent):
        """Test token calculation includes tool call content"""
        messages = [
            {"role": "system", "content": "Test"},
            {
                "role": "assistant",
                "content": [
                    {
                        "function_call": {
                            "name": "search_tool",
                            "args": {"query": "test"},
                            "call_id": "123",
                        }
                    }
                ],
            },
            {
                "role": "tool",
                "content": [
                    {
                        "function_response": {
                            "name": "search_tool",
                            "response": {"result": "Found 10 results"},
                            "call_id": "123",
                        }
                    }
                ],
            },
        ]

        tokens = mock_agent._calculate_current_context_tokens(messages)

        # Should include tool call tokens
        assert tokens > 0

    @patch("application.core.model_utils.get_token_limit")
    @patch("application.core.settings.settings")
    def test_check_context_limit_below_threshold(
        self, mock_settings, mock_get_token_limit, mock_agent
    ):
        """Test context limit check when below threshold"""
        mock_get_token_limit.return_value = 128000
        mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8

        messages = [
            {"role": "system", "content": "Short message"},
            {"role": "user", "content": "Hello"},
        ]

        # Should return False for small conversation
        result = mock_agent._check_context_limit(messages)
        assert result is False

        # Should track current token count
        assert mock_agent.current_token_count > 0
        assert mock_agent.current_token_count < 128000 * 0.8

    @patch("application.core.model_utils.get_token_limit")
    @patch("application.core.settings.settings")
    def test_check_context_limit_above_threshold(
        self, mock_settings, mock_get_token_limit, mock_agent
    ):
        """Test context limit check when above threshold"""
        mock_get_token_limit.return_value = 100  # Very small limit for testing
        mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8

        # Create messages that will exceed 80 tokens (80% of 100)
        messages = [
            {"role": "system", "content": "a " * 50},  # ~50 tokens
            {"role": "user", "content": "b " * 50},  # ~50 tokens
        ]

        # Should return True when exceeding threshold
        result = mock_agent._check_context_limit(messages)
        assert result is True

    @patch("application.agents.base.logger")
    def test_check_context_limit_error_handling(self, mock_logger, mock_agent):
        """Test error handling in context limit check"""
        # Force an error by making get_token_limit fail
        with patch(
            "application.core.model_utils.get_token_limit", side_effect=Exception("Test error")
        ):
            messages = [{"role": "user", "content": "test"}]

            result = mock_agent._check_context_limit(messages)

            # Should return False on error (safe default)
            assert result is False
            # Should log the error
            assert mock_logger.error.called

    def test_context_limit_flag_initialization(self, mock_agent):
        """Test that context limit flag is initialized"""
        assert hasattr(mock_agent, "context_limit_reached")
        assert mock_agent.context_limit_reached is False

        assert hasattr(mock_agent, "current_token_count")
        assert mock_agent.current_token_count == 0


class TestLLMHandlerTokenTracking:
    """Test suite for LLM handler token tracking"""

    @patch("application.llm.handlers.base.logger")
    def test_handle_tool_calls_stops_at_limit(self, mock_logger):
        """Test that tool execution stops when context limit is reached"""
        from application.llm.handlers.base import LLMHandler

        # Create a concrete handler for testing
        class TestHandler(LLMHandler):
            def parse_response(self, response):
                pass

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                yield ""

        handler = TestHandler()

        # Create mock agent that hits limit on second tool
        mock_agent = Mock()
        mock_agent.context_limit_reached = False

        call_count = [0]

        def check_limit_side_effect(messages):
            call_count[0] += 1
            # Return True on second call (second tool)
            return call_count[0] >= 2

        mock_agent._check_context_limit = Mock(side_effect=check_limit_side_effect)
        mock_agent._execute_tool_action = Mock(
            return_value=iter([{"type": "tool_call", "data": {}}])
        )

        # Create multiple tool calls
        tool_calls = [
            ToolCall(id="1", name="tool1", arguments={}),
            ToolCall(id="2", name="tool2", arguments={}),
            ToolCall(id="3", name="tool3", arguments={}),
        ]

        messages = []
        tools_dict = {}

        # Execute tool calls
        results = list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages))

        # First tool should execute
        assert mock_agent._execute_tool_action.call_count == 1

        # Should have yielded skip messages for tools 2 and 3
        skip_messages = [r for r in results if r.get("type") == "tool_call" and r.get("data", {}).get("status") == "skipped"]
        assert len(skip_messages) == 2

        # Should have set the flag
        assert mock_agent.context_limit_reached is True

        # Should have logged warning
        assert mock_logger.warning.called

    def test_handle_tool_calls_all_execute_when_no_limit(self):
        """Test that all tools execute when under limit"""
        from application.llm.handlers.base import LLMHandler

        class TestHandler(LLMHandler):
            def parse_response(self, response):
                pass

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                yield ""

        handler = TestHandler()

        # Create mock agent that never hits limit
        mock_agent = Mock()
        mock_agent.context_limit_reached = False
        mock_agent._check_context_limit = Mock(return_value=False)
        mock_agent._execute_tool_action = Mock(
            return_value=iter([{"type": "tool_call", "data": {}}])
        )

        tool_calls = [
            ToolCall(id="1", name="tool1", arguments={}),
            ToolCall(id="2", name="tool2", arguments={}),
            ToolCall(id="3", name="tool3", arguments={}),
        ]

        messages = []
        tools_dict = {}

        # Execute tool calls
        list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages))

        # All 3 tools should execute
        assert mock_agent._execute_tool_action.call_count == 3

        # Should not have set the flag
        assert mock_agent.context_limit_reached is False

    @patch("application.llm.handlers.base.logger")
    def test_handle_streaming_adds_warning_message(self, mock_logger):
        """Test that streaming handler adds warning when limit reached"""
        from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall

        class TestHandler(LLMHandler):
            def parse_response(self, response):
                if isinstance(response, dict) and response.get("type") == "tool_call":
                    return LLMResponse(
                        content="",
                        tool_calls=[ToolCall(id="1", name="test", arguments={}, index=0)],
                        finish_reason="tool_calls",
                        raw_response=None,
                    )
                else:
                    return LLMResponse(
                        content="Done",
                        tool_calls=[],
                        finish_reason="stop",
                        raw_response=None,
                    )

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                if response == "first":
                    yield {"type": "tool_call"}  # Object to be parsed, not string
                else:
                    yield {"type": "stop"}  # Object to be parsed, not string

        handler = TestHandler()

        # Create mock agent with limit reached
        mock_agent = Mock()
        mock_agent.context_limit_reached = True
        mock_agent.model_id = "gpt-4o"
        mock_agent.tools = []
        mock_agent.llm = Mock()
        mock_agent.llm.gen_stream = Mock(return_value="second")

        def tool_handler_gen(*args):
            yield {"type": "tool", "data": {}}
            return []

        # Mock handle_tool_calls to return messages and set flag
        with patch.object(
            handler, "handle_tool_calls", return_value=tool_handler_gen()
        ):
            messages = []
            tools_dict = {}

            # Execute streaming
            list(handler.handle_streaming(mock_agent, "first", tools_dict, messages))

            # Should have called gen_stream with tools=None (disabled)
            mock_agent.llm.gen_stream.assert_called()
            call_kwargs = mock_agent.llm.gen_stream.call_args.kwargs
            assert call_kwargs.get("tools") is None

            # Should have logged the warning
            assert mock_logger.info.called


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/test_compression_service.py (Normal file, 1082 lines)
File diff suppressed because it is too large

tests/test_integration.py (Executable file, 1287 lines)
File diff suppressed because it is too large

tests/test_model_validation.py (Normal file, 106 lines)
@@ -0,0 +1,106 @@
"""
|
||||
Tests for model validation and base_url functionality
|
||||
"""
|
||||
import pytest
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
ModelRegistry,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_base_url_for_model,
|
||||
validate_model_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_model_with_base_url():
|
||||
"""Test that AvailableModel can store and retrieve base_url"""
|
||||
model = AvailableModel(
|
||||
id="test-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Test Model",
|
||||
description="Test model with custom base URL",
|
||||
base_url="https://custom-endpoint.com/v1",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=8192,
|
||||
),
|
||||
)
|
||||
|
||||
assert model.base_url == "https://custom-endpoint.com/v1"
|
||||
assert model.id == "test-model"
|
||||
assert model.provider == ModelProvider.OPENAI
|
||||
|
||||
# Test to_dict includes base_url
|
||||
model_dict = model.to_dict()
|
||||
assert "base_url" in model_dict
|
||||
assert model_dict["base_url"] == "https://custom-endpoint.com/v1"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_model_without_base_url():
|
||||
"""Test that models without base_url still work"""
|
||||
model = AvailableModel(
|
||||
id="test-model-no-url",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Test Model",
|
||||
description="Test model without base URL",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=8192,
|
||||
),
|
||||
)
|
||||
|
||||
assert model.base_url is None
|
||||
|
||||
# Test to_dict doesn't include base_url when None
|
||||
model_dict = model.to_dict()
|
||||
assert "base_url" not in model_dict
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_model_id():
|
||||
"""Test model_id validation"""
|
||||
# Get the registry instance to check what models are available
|
||||
ModelRegistry.get_instance()
|
||||
|
||||
# Test with a model that should exist (docsgpt-local is always added)
|
||||
assert validate_model_id("docsgpt-local") is True
|
||||
|
||||
# Test with invalid model_id
|
||||
assert validate_model_id("invalid-model-xyz-123") is False
|
||||
|
||||
# Test with None
|
||||
assert validate_model_id(None) is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_base_url_for_model():
|
||||
"""Test retrieving base_url for a model"""
|
||||
# Test with a model that doesn't have base_url
|
||||
result = get_base_url_for_model("docsgpt-local")
|
||||
assert result is None # docsgpt-local doesn't have custom base_url
|
||||
|
||||
# Test with invalid model
|
||||
result = get_base_url_for_model("invalid-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_model_validation_error_message():
|
||||
"""Test that validation provides helpful error messages"""
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
# Create processor with invalid model_id
|
||||
data = {"model_id": "invalid-model-xyz"}
|
||||
processor = StreamProcessor(data, None)
|
||||
|
||||
# Should raise ValueError with helpful message
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
processor._validate_and_set_model()
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Invalid model_id 'invalid-model-xyz'" in error_msg
|
||||
assert "Available models:" in error_msg
|
||||
tests/test_token_management.py (Normal file, 314 lines)
@@ -0,0 +1,314 @@
"""
|
||||
Tests for token management and compression features.
|
||||
|
||||
NOTE: These tests are for future planned features that are not yet implemented.
|
||||
They are skipped until the following modules are created:
|
||||
- application.compression (DocumentCompressor, HistoryCompressor, etc.)
|
||||
- application.core.token_budget (TokenBudgetManager)
|
||||
"""
|
||||
# ruff: noqa: F821
|
||||
import pytest
|
||||
|
||||
pytest.skip(
|
||||
"Token management features not yet implemented - planned for future release",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenBudgetManager:
    """Test TokenBudgetManager functionality"""

    def test_calculate_budget(self):
        """Test budget calculation"""
        manager = TokenBudgetManager(model_id="gpt-4o")
        budget = manager.calculate_budget()

        assert budget.total_budget > 0
        assert budget.system_prompt > 0
        assert budget.chat_history > 0
        assert budget.retrieved_docs > 0

    def test_measure_usage(self):
        """Test token usage measurement"""
        manager = TokenBudgetManager(model_id="gpt-4o")

        usage = manager.measure_usage(
            system_prompt="You are a helpful assistant.",
            current_query="What is Python?",
            chat_history=[
                {"prompt": "Hello", "response": "Hi there!"},
                {"prompt": "How are you?", "response": "I'm doing well, thanks!"},
            ],
        )

        assert usage.total > 0
        assert usage.system_prompt > 0
        assert usage.current_query > 0
        assert usage.chat_history > 0

    def test_compression_recommendation(self):
        """Test compression recommendation generation"""
        manager = TokenBudgetManager(model_id="gpt-4o")

        # Create scenario with excessive history
        large_history = [
            {"prompt": f"Question {i}" * 100, "response": f"Answer {i}" * 100}
            for i in range(100)
        ]

        budget, usage, recommendation = manager.check_and_recommend(
            system_prompt="You are a helpful assistant.",
            current_query="What is Python?",
            chat_history=large_history,
        )

        # Should recommend compression
        assert recommendation.needs_compression()
        assert recommendation.compress_history
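The recommendation object these assertions imply can be very small. A hypothetical shape, with the target_*_tokens fields drawn from the integration test at the end of this file:

from dataclasses import dataclass
from typing import Optional


@dataclass
class CompressionRecommendation:
    compress_history: bool = False
    compress_docs: bool = False
    target_history_tokens: Optional[int] = None
    target_docs_tokens: Optional[int] = None

    def needs_compression(self) -> bool:
        # True if any input category was flagged for compression.
        return self.compress_history or self.compress_docs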
class TestHistoryCompressor:
    """Test HistoryCompressor functionality"""

    def test_sliding_window_compression(self):
        """Test sliding window compression strategy"""
        compressor = HistoryCompressor()

        history = [
            {"prompt": f"Question {i}", "response": f"Answer {i}"} for i in range(20)
        ]

        compressed, metadata = compressor.compress(
            history, target_tokens=500, strategy="sliding_window"
        )

        assert len(compressed) < len(history)
        assert metadata["original_messages"] == 20
        assert metadata["compressed_messages"] < 20
        assert metadata["strategy"] == "sliding_window"

    def test_preserve_tool_calls(self):
        """Test that tool calls are preserved during compression"""
        compressor = HistoryCompressor()

        history = [
            {"prompt": "Question 1", "response": "Answer 1"},
            {
                "prompt": "Use a tool",
                "response": "Tool used",
                "tool_calls": [{"tool_name": "search", "result": "Found something"}],
            },
            {"prompt": "Question 3", "response": "Answer 3"},
        ]

        compressed, metadata = compressor.compress(
            history, target_tokens=200, strategy="sliding_window", preserve_tool_calls=True
        )

        # Tool call message should be preserved
        has_tool_calls = any("tool_calls" in msg for msg in compressed)
        assert has_tool_calls
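A plausible reading of the sliding_window strategy exercised above: keep the most recent turns that fit the budget, but never drop turns carrying tool calls when preserve_tool_calls is set. The sketch below is hypothetical; it approximates tokens at roughly four characters each, where a real implementation would use the model's tokenizer.

def sliding_window(history, target_tokens, preserve_tool_calls=False):
    def est(msg):
        # Crude token estimate: ~4 characters per token.
        return (len(msg.get("prompt", "")) + len(str(msg.get("response", "")))) // 4

    kept, used = [], 0
    for msg in reversed(history):  # walk newest-first
        keep = used + est(msg) <= target_tokens
        if preserve_tool_calls and "tool_calls" in msg:
            keep = True  # tool-call turns survive regardless of budget
        if keep:
            kept.append(msg)
            used += est(msg)
    kept.reverse()  # restore chronological order
    metadata = {
        "original_messages": len(history),
        "compressed_messages": len(kept),
        "strategy": "sliding_window",
    }
    return kept, metadata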
class TestDocumentCompressor:
    """Test DocumentCompressor functionality"""

    def test_rerank_compression(self):
        """Test re-ranking compression strategy"""
        compressor = DocumentCompressor()

        docs = [
            {"text": f"Document {i} with some content here" * 20, "title": f"Doc {i}"}
            for i in range(10)
        ]

        compressed, metadata = compressor.compress(
            docs, target_tokens=500, query="Document 5", strategy="rerank"
        )

        assert len(compressed) < len(docs)
        assert metadata["original_docs"] == 10
        assert metadata["strategy"] == "rerank"

    def test_excerpt_extraction(self):
        """Test excerpt extraction strategy"""
        compressor = DocumentCompressor()

        docs = [
            {
                "text": "This is a long document. " * 100
                + "Python is great. "
                + "More text here. " * 100,
                "title": "Python Guide",
            }
        ]

        compressed, metadata = compressor.compress(
            docs, target_tokens=300, query="Python", strategy="excerpt"
        )

        assert metadata["excerpts_created"] > 0
        # Excerpt should contain the query term
        assert "python" in compressed[0]["text"].lower()
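The excerpt strategy tested above evidently pulls text from around query matches rather than truncating from the front, since the assertion requires the query term to survive. A hypothetical single-match sketch:

def extract_excerpt(text, query, radius=200):
    # Return a window of `text` centered on the first occurrence of `query`.
    idx = text.lower().find(query.lower())
    if idx == -1:
        return text[: 2 * radius]  # no match: fall back to the document head
    start = max(0, idx - radius)
    return text[start : idx + len(query) + radius]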
class TestToolResultCompressor:
    """Test ToolResultCompressor functionality"""

    def test_truncate_large_results(self):
        """Test truncation of large tool results"""
        compressor = ToolResultCompressor()

        tool_results = [
            {
                "tool_name": "search",
                "result": "Very long result " * 1000,
                "arguments": {},
            }
        ]

        compressed, metadata = compressor.compress(
            tool_results, target_tokens=100, strategy="truncate"
        )

        assert metadata["results_truncated"] > 0
        # Result should be shorter
        compressed_result_len = len(str(compressed[0]["result"]))
        original_result_len = len(tool_results[0]["result"])
        assert compressed_result_len < original_result_len

    def test_extract_json_fields(self):
        """Test extraction of key fields from JSON results"""
        compressor = ToolResultCompressor()

        tool_results = [
            {
                "tool_name": "api_call",
                "result": {
                    "data": {"important": "value"},
                    "metadata": {"verbose": "information" * 100},
                    "debug": {"lots": "of data" * 100},
                },
                "arguments": {},
            }
        ]

        compressed, metadata = compressor.compress(
            tool_results, target_tokens=100, strategy="extract"
        )

        # Should keep important fields, discard verbose ones
        assert "data" in compressed[0]["result"]
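The two strategies tested above suggest complementary behavior: truncate caps the serialized result, while extract keeps the compact fields of a dict result and drops verbose ones. A hypothetical sketch of both, again using the ~4-chars-per-token approximation:

def truncate_result(result, target_tokens):
    # Cap the serialized result at roughly target_tokens worth of characters.
    s = str(result)
    limit = target_tokens * 4
    return s if len(s) <= limit else s[:limit] + "[truncated]"


def extract_fields(result, target_tokens):
    # Keep the smallest fields first until the token budget is spent.
    if not isinstance(result, dict):
        return truncate_result(result, target_tokens)
    kept, used = {}, 0
    for key, value in sorted(result.items(), key=lambda kv: len(str(kv[1]))):
        size = len(str(value)) // 4
        if used + size <= target_tokens:
            kept[key] = value
            used += size
    return kept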
class TestPromptOptimizer:
    """Test PromptOptimizer functionality"""

    def test_compress_tool_descriptions(self):
        """Test compression of tool descriptions"""
        optimizer = PromptOptimizer()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": f"tool_{i}",
                    "description": "This is a very long description " * 50,
                    "parameters": {},
                },
            }
            for i in range(10)
        ]

        optimized, metadata = optimizer.optimize_tools(
            tools, target_tokens=500, strategy="compress"
        )

        assert metadata["optimized_tokens"] < metadata["original_tokens"]
        assert metadata["descriptions_compressed"] > 0

    def test_lazy_load_tools(self):
        """Test lazy loading of tools based on query"""
        optimizer = PromptOptimizer()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_tool",
                    "description": "Search for information",
                    "parameters": {},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "calculate_tool",
                    "description": "Perform calculations",
                    "parameters": {},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "other_tool",
                    "description": "Do something else",
                    "parameters": {},
                },
            },
        ]

        optimized, metadata = optimizer.optimize_tools(
            tools,
            target_tokens=200,
            query="I want to search for something",
            strategy="lazy_load",
        )

        # Should prefer search tool
        assert len(optimized) < len(tools)
        tool_names = [t["function"]["name"] for t in optimized]
        # Search tool should be included due to query relevance
        assert any("search" in name for name in tool_names)
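For lazy_load, the test only requires that the search tool survive a search-flavored query, which a simple keyword-overlap ranking already delivers. A hypothetical sketch:

def lazy_load_tools(tools, query, max_tools=2):
    # Rank tools by how many query words appear in their name/description.
    words = set(query.lower().split())

    def score(tool):
        fn = tool["function"]
        text = (fn["name"].replace("_", " ") + " " + fn["description"]).lower()
        return sum(1 for word in words if word in text)

    ranked = sorted(tools, key=score, reverse=True)
    return ranked[:max_tools]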
def test_integration_compression_workflow():
    """Test complete compression workflow"""
    # Simulate a scenario with large inputs
    manager = TokenBudgetManager(model_id="gpt-4o")
    history_compressor = HistoryCompressor()
    doc_compressor = DocumentCompressor()

    # Large chat history
    history = [
        {"prompt": f"Question {i}" * 50, "response": f"Answer {i}" * 50}
        for i in range(50)
    ]

    # Large documents
    docs = [
        {"text": f"Document {i} content" * 100, "title": f"Doc {i}"} for i in range(20)
    ]

    # Check budget
    budget, usage, recommendation = manager.check_and_recommend(
        system_prompt="You are a helpful assistant.",
        current_query="What is Python?",
        chat_history=history,
        retrieved_docs=docs,
    )

    # Should need compression
    assert recommendation.needs_compression()

    # Apply compression
    if recommendation.compress_history:
        compressed_history, hist_meta = history_compressor.compress(
            history, recommendation.target_history_tokens or budget.chat_history
        )
        assert len(compressed_history) < len(history)

    if recommendation.compress_docs:
        compressed_docs, doc_meta = doc_compressor.compress(
            docs,
            recommendation.target_docs_tokens or budget.retrieved_docs,
            query="Python",
        )
        assert len(compressed_docs) < len(docs)