diff --git a/.env-template b/.env-template index e93f0363..13575fc3 100644 --- a/.env-template +++ b/.env-template @@ -1,9 +1,28 @@ API_KEY= LLM_NAME=docsgpt VITE_API_STREAMING=true +INTERNAL_KEY= + +# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer) +# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM) +EMBEDDINGS_BASE_URL= +EMBEDDINGS_KEY= #For Azure (you can delete it if you don't use Azure) OPENAI_API_BASE= OPENAI_API_VERSION= AZURE_DEPLOYMENT_NAME= -AZURE_EMBEDDINGS_DEPLOYMENT_NAME= \ No newline at end of file +AZURE_EMBEDDINGS_DEPLOYMENT_NAME= + +#Azure AD Application (client) ID +MICROSOFT_CLIENT_ID=your-azure-ad-client-id +#Azure AD Application client secret +MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret +#Azure AD Tenant ID (or 'common' for multi-tenant) +MICROSOFT_TENANT_ID=your-azure-ad-tenant-id +#If you are using a Microsoft Entra ID tenant, +#configure the AUTHORITY variable as +#"https://login.microsoftonline.com/TENANT_GUID" +#or "https://login.microsoftonline.com/contoso.onmicrosoft.com". +#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app. +MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a36f529b..ec9fdbdf 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,9 @@ on: pull_request: types: [ opened, synchronize ] +permissions: + contents: read + jobs: ruff: runs-on: ubuntu-latest diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 8b85366a..d7a66bdc 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -1,5 +1,9 @@ name: Run python tests with pytest on: [push, pull_request] + +permissions: + contents: read + jobs: pytest_and_coverage: name: Run tests and count coverage diff --git a/.github/workflows/vale.yml b/.github/workflows/vale.yml index a0f8167c..48b25b49 100644 --- a/.github/workflows/vale.yml +++ b/.github/workflows/vale.yml @@ -9,6 +9,10 @@ on: - '.vale.ini' - '.github/styles/**' +permissions: + contents: read + pull-requests: write + jobs: vale: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 91abeca1..9b09303d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +experiments/ experiments # C extensions @@ -70,6 +71,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/public/_pagefind/ # PyBuilder target/ @@ -146,6 +148,10 @@ frontend/yarn-error.log* frontend/pnpm-debug.log* frontend/lerna-debug.log* +# Keep frontend utility helpers tracked (overrides global lib/ ignore) +!frontend/src/lib/ +!frontend/src/lib/** + frontend/node_modules frontend/dist frontend/dist-ssr diff --git a/.ruff.toml b/.ruff.toml index 857f8153..8d9833ff 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,2 +1,6 @@ # Allow lines to be as long as 120 characters. -line-length = 120 \ No newline at end of file +line-length = 120 + +[lint.per-file-ignores] +# Integration tests use sys.path.insert() before imports for standalone execution +"tests/integration/*.py" = ["E402"] \ No newline at end of file diff --git a/README.md b/README.md index a341a273..b93c7b30 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,6 @@ -
-
-🎃 Hacktoberfest Prizes, Rules & Q&A 🎃
-
-
-
-

@@ -53,24 +46,11 @@ ## Roadmap
-
-- [x] Full GoogleAI compatibility (Jan 2025)
-- [x] Add tools (Jan 2025)
-- [x] Manually updating chunks in the app UI (Feb 2025)
-- [x] Devcontainer for easy development (Feb 2025)
-- [x] ReACT agent (March 2025)
-- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
-- [x] New input box in the conversation menu (April 2025)
-- [x] Add triggerable actions / tools (webhook) (April 2025)
-- [x] Agent optimisations (May 2025)
-- [x] Filesystem sources update (July 2025)
-- [x] Json Responses (August 2025)
-- [x] MCP support (August 2025)
-- [x] Google Drive integration (September 2025)
-- [x] Add OAuth 2.0 authentication for MCP (September 2025)
-- [ ] SharePoint integration (October 2025)
-- [ ] Deep Agents (October 2025)
-- [ ] Agent scheduling
+- [x] Add OAuth 2.0 authentication for MCP (September 2025)
+- [x] Deep Agents (October 2025)
+- [x] Prompt Templating (October 2025)
+- [x] Full API tooling (Dec 2025)
+- [ ] Agent scheduling (Jan 2026)
 
 You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -165,9 +145,17 @@ We as members, contributors, and leaders, pledge to make participation in our co
 The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
-

This project is supported by:

+## This project is supported by: +

+

+ + color + + +

+ diff --git a/application/.env_sample b/application/.env_sample deleted file mode 100644 index 8ab24d2a..00000000 --- a/application/.env_sample +++ /dev/null @@ -1,11 +0,0 @@ -API_KEY=your_api_key -EMBEDDINGS_KEY=your_api_key -API_URL=http://localhost:7091 -FLASK_APP=application/app.py -FLASK_DEBUG=true - -#For OPENAI on Azure -OPENAI_API_BASE= -OPENAI_API_VERSION= -AZURE_DEPLOYMENT_NAME= -AZURE_EMBEDDINGS_DEPLOYMENT_NAME= \ No newline at end of file diff --git a/application/Dockerfile b/application/Dockerfile index e33721a2..48d29e57 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -7,7 +7,7 @@ RUN apt-get update && \ apt-get install -y software-properties-common && \ add-apt-repository ppa:deadsnakes/ppa && \ apt-get update && \ - apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \ + apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \ rm -rf /var/lib/apt/lists/* # Verify Python installation and setup symlink @@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final RUN apt-get update && \ apt-get install -y software-properties-common && \ add-apt-repository ppa:deadsnakes/ppa && \ - apt-get update && apt-get install -y --no-install-recommends python3.12 && \ + apt-get update && apt-get install -y --no-install-recommends \ + python3.12 \ + libgl1 \ + libglib2.0-0 \ + poppler-utils \ + && \ ln -s /usr/bin/python3.12 /usr/bin/python && \ rm -rf /var/lib/apt/lists/* diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py index bf37d4ec..e15165bf 100644 --- a/application/agents/agent_creator.py +++ b/application/agents/agent_creator.py @@ -1,11 +1,17 @@ +import logging + from application.agents.classic_agent import ClassicAgent from application.agents.react_agent import ReActAgent +from application.agents.workflow_agent import WorkflowAgent + +logger = logging.getLogger(__name__) class AgentCreator: agents = { "classic": ClassicAgent, "react": ReActAgent, + "workflow": WorkflowAgent, } @classmethod diff --git a/application/agents/base.py b/application/agents/base.py index 27428fc3..ee55a449 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -7,11 +7,16 @@ from bson.objectid import ObjectId from application.agents.tools.tool_action_parser import ToolActionParser from application.agents.tools.tool_manager import ToolManager +from application.core.json_schema_utils import ( + JsonSchemaValidationError, + normalize_json_schema_payload, +) from application.core.mongo_db import MongoDB from application.core.settings import settings from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext +from application.security.encryption import decrypt_credentials logger = logging.getLogger(__name__) @@ -21,8 +26,9 @@ class BaseAgent(ABC): self, endpoint: str, llm_name: str, - gpt_model: str, + model_id: str, api_key: str, + agent_id: Optional[str] = None, user_api_key: Optional[str] = None, prompt: str = "", chat_history: Optional[List[Dict]] = None, @@ -34,11 +40,13 @@ class BaseAgent(ABC): token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"], limited_request_mode: Optional[bool] = False, request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"], + compressed_summary: Optional[str] = None, ): self.endpoint = endpoint self.llm_name = llm_name - 
self.gpt_model = gpt_model + self.model_id = model_id self.api_key = api_key + self.agent_id = agent_id self.user_api_key = user_api_key self.prompt = prompt self.decoded_token = decoded_token or {} @@ -52,17 +60,27 @@ class BaseAgent(ABC): api_key=api_key, user_api_key=user_api_key, decoded_token=decoded_token, + model_id=model_id, + agent_id=agent_id, ) self.retrieved_docs = retrieved_docs or [] self.llm_handler = LLMHandlerCreator.create_handler( llm_name if llm_name else "default" ) self.attachments = attachments or [] - self.json_schema = json_schema + self.json_schema = None + if json_schema is not None: + try: + self.json_schema = normalize_json_schema_payload(json_schema) + except JsonSchemaValidationError as exc: + logger.warning("Ignoring invalid JSON schema payload: %s", exc) self.limited_token_mode = limited_token_mode self.token_limit = token_limit self.limited_request_mode = limited_request_mode self.request_limit = request_limit + self.compressed_summary = compressed_summary + self.current_token_count = 0 + self.context_limit_reached = False @log_activity() def gen( @@ -115,10 +133,10 @@ class BaseAgent(ABC): params["properties"][k] = { key: value for key, value in v.items() - if key != "filled_by_llm" and key != "value" + if key not in ("filled_by_llm", "value", "required") } - - params["required"].append(k) + if v.get("required", False): + params["required"].append(k) return params def _prepare_tools(self, tools_dict): @@ -214,7 +232,11 @@ class BaseAgent(ABC): for param_type, target_dict in param_types.items(): if param_type in action_data and action_data[param_type].get("properties"): for param, details in action_data[param_type]["properties"].items(): - if param not in call_args and "value" in details: + if ( + param not in call_args + and "value" in details + and details["value"] + ): target_dict[param] = details["value"] for param, value in call_args.items(): for param_type, target_dict in param_types.items(): @@ -227,36 +249,86 @@ class BaseAgent(ABC): # Prepare tool_config and add tool_id for memory tools if tool_data["name"] == "api_tool": + action_config = tool_data["config"]["actions"][action_name] tool_config = { - "url": tool_data["config"]["actions"][action_name]["url"], - "method": tool_data["config"]["actions"][action_name]["method"], + "url": action_config["url"], + "method": action_config["method"], "headers": headers, "query_params": query_params, } + if "body_content_type" in action_config: + tool_config["body_content_type"] = action_config.get( + "body_content_type", "application/json" + ) + tool_config["body_encoding_rules"] = action_config.get( + "body_encoding_rules", {} + ) else: tool_config = tool_data["config"].copy() if tool_data["config"] else {} - # Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool) - # Use MongoDB _id if available, otherwise fall back to enumerated tool_id - + if tool_config.get("encrypted_credentials") and self.user: + decrypted = decrypt_credentials( + tool_config["encrypted_credentials"], self.user + ) + tool_config.update(decrypted) + tool_config["auth_credentials"] = decrypted + tool_config.pop("encrypted_credentials", None) tool_config["tool_id"] = str(tool_data.get("_id", tool_id)) + if hasattr(self, "conversation_id") and self.conversation_id: + tool_config["conversation_id"] = self.conversation_id + if tool_data["name"] == "mcp_tool": + tool_config["query_mode"] = True tool = tm.load_tool( tool_data["name"], tool_config=tool_config, - user_id=self.user, # Pass user ID for MCP tools 
credential decryption + user_id=self.user, + ) + resolved_arguments = ( + {"query_params": query_params, "headers": headers, "body": body} + if tool_data["name"] == "api_tool" + else parameters ) if tool_data["name"] == "api_tool": - print( + logger.debug( f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" ) result = tool.execute_action(action_name, **body) else: - print(f"Executing tool: {action_name} with args: {call_args}") + logger.debug(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) - tool_call_data["result"] = ( - f"{str(result)[:50]}..." if len(str(result)) > 50 else result + + get_artifact_id = ( + getattr(tool, "get_artifact_id", None) + if tool_data["name"] != "api_tool" + else None ) - yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} + artifact_id = None + if callable(get_artifact_id): + try: + artifact_id = get_artifact_id(action_name, **parameters) + except Exception: + logger.exception( + "Failed to extract artifact_id from tool %s for action %s", + tool_data["name"], + action_name, + ) + + artifact_id = str(artifact_id).strip() if artifact_id is not None else "" + if artifact_id: + tool_call_data["artifact_id"] = artifact_id + result_full = str(result) + tool_call_data["resolved_arguments"] = resolved_arguments + tool_call_data["result_full"] = result_full + tool_call_data["result"] = ( + f"{result_full[:50]}..." if len(result_full) > 50 else result_full + ) + + stream_tool_call_data = { + key: value + for key, value in tool_call_data.items() + if key not in {"result_full", "resolved_arguments"} + } + yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) return result, call_id @@ -264,7 +336,11 @@ class BaseAgent(ABC): def _get_truncated_tool_calls(self): return [ { - **tool_call, + "tool_name": tool_call.get("tool_name"), + "call_id": tool_call.get("call_id"), + "action_name": tool_call.get("action_name"), + "arguments": tool_call.get("arguments"), + "artifact_id": tool_call.get("artifact_id"), "result": ( f"{str(tool_call['result'])[:50]}..." if len(str(tool_call["result"])) > 50 @@ -275,15 +351,176 @@ class BaseAgent(ABC): for tool_call in self.tool_calls ] + def _calculate_current_context_tokens(self, messages: List[Dict]) -> int: + """ + Calculate total tokens in current context (messages). + + Args: + messages: List of message dicts + + Returns: + Total token count + """ + from application.api.answer.services.compression.token_counter import ( + TokenCounter, + ) + + return TokenCounter.count_message_tokens(messages) + + def _check_context_limit(self, messages: List[Dict]) -> bool: + """ + Check if we're approaching context limit (80%). 
+ + Args: + messages: Current message list + + Returns: + True if at or above 80% of context limit + """ + from application.core.model_utils import get_token_limit + from application.core.settings import settings + + try: + # Calculate current tokens + current_tokens = self._calculate_current_context_tokens(messages) + self.current_token_count = current_tokens + + # Get context limit for model + context_limit = get_token_limit(self.model_id) + + # Calculate threshold (80%) + threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE) + + # Check if we've reached the limit + if current_tokens >= threshold: + logger.warning( + f"Context limit approaching: {current_tokens}/{context_limit} tokens " + f"({(current_tokens/context_limit)*100:.1f}%)" + ) + return True + + return False + + except Exception as e: + logger.error(f"Error checking context limit: {str(e)}", exc_info=True) + return False + + def _validate_context_size(self, messages: List[Dict]) -> None: + """ + Pre-flight validation before calling LLM. Logs warnings but never raises errors. + + Args: + messages: Messages to be sent to LLM + """ + from application.core.model_utils import get_token_limit + + current_tokens = self._calculate_current_context_tokens(messages) + self.current_token_count = current_tokens + context_limit = get_token_limit(self.model_id) + + percentage = (current_tokens / context_limit) * 100 + + # Log based on usage level + if current_tokens >= context_limit: + logger.warning( + f"Context at limit: {current_tokens:,}/{context_limit:,} tokens " + f"({percentage:.1f}%). Model: {self.model_id}" + ) + elif current_tokens >= int( + context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE + ): + logger.info( + f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens " + f"({percentage:.1f}%)" + ) + + def _truncate_text_middle(self, text: str, max_tokens: int) -> str: + """ + Truncate text by removing content from the middle, preserving start and end. + + Args: + text: Text to truncate + max_tokens: Maximum tokens allowed + + Returns: + Truncated text with middle removed if needed + """ + from application.utils import num_tokens_from_string + + current_tokens = num_tokens_from_string(text) + if current_tokens <= max_tokens: + return text + + # Estimate chars per token (roughly 4 chars per token for English) + chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4 + target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin + + if target_chars <= 0: + return "" + + # Split: keep 40% from start, 40% from end, remove middle + start_chars = int(target_chars * 0.4) + end_chars = int(target_chars * 0.4) + + truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n" + + truncated = text[:start_chars] + truncation_marker + text[-end_chars:] + + logger.info( + f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens " + f"(removed middle section)" + ) + + return truncated + def _build_messages( self, system_prompt: str, query: str, ) -> List[Dict]: """Build messages using pre-rendered system prompt""" + from application.core.model_utils import get_token_limit + from application.utils import num_tokens_from_string + + # Append compression summary to system prompt if present + if self.compressed_summary: + compression_context = ( + "\n\n---\n\n" + "This session is being continued from a previous conversation that " + "has been compressed to fit within context limits. 
" + "The conversation is summarized below:\n\n" + f"{self.compressed_summary}" + ) + system_prompt = system_prompt + compression_context + + context_limit = get_token_limit(self.model_id) + system_tokens = num_tokens_from_string(system_prompt) + + # Reserve 10% for response/tools + safety_buffer = int(context_limit * 0.1) + available_after_system = context_limit - system_tokens - safety_buffer + + # Max tokens for query: 80% of available space (leave room for history) + max_query_tokens = int(available_after_system * 0.8) + query_tokens = num_tokens_from_string(query) + + # Truncate query from middle if it exceeds 80% of available context + if query_tokens > max_query_tokens: + query = self._truncate_text_middle(query, max_query_tokens) + query_tokens = num_tokens_from_string(query) + + # Calculate remaining budget for chat history + available_for_history = max(available_after_system - query_tokens, 0) + + # Truncate chat history to fit within available budget + working_history = self._truncate_history_to_fit( + self.chat_history, + available_for_history, + ) + messages = [{"role": "system", "content": system_prompt}] - for i in self.chat_history: + for i in working_history: if "prompt" in i and "response" in i: messages.append({"role": "user", "content": i["prompt"]}) messages.append({"role": "assistant", "content": i["response"]}) @@ -315,8 +552,69 @@ class BaseAgent(ABC): messages.append({"role": "user", "content": query}) return messages + def _truncate_history_to_fit( + self, + history: List[Dict], + max_tokens: int, + ) -> List[Dict]: + """ + Truncate chat history to fit within token budget, keeping most recent messages. + + Args: + history: Full chat history + max_tokens: Maximum tokens allowed for history + + Returns: + Truncated history (most recent messages that fit) + """ + from application.utils import num_tokens_from_string + + if not history or max_tokens <= 0: + return [] + + truncated = [] + current_tokens = 0 + + # Iterate from newest to oldest + for message in reversed(history): + message_tokens = 0 + + if "prompt" in message and "response" in message: + message_tokens += num_tokens_from_string(message["prompt"]) + message_tokens += num_tokens_from_string(message["response"]) + + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + tool_str = ( + f"Tool: {tool_call.get('tool_name')} | " + f"Action: {tool_call.get('action_name')} | " + f"Args: {tool_call.get('arguments')} | " + f"Response: {tool_call.get('result')}" + ) + message_tokens += num_tokens_from_string(tool_str) + + if current_tokens + message_tokens <= max_tokens: + current_tokens += message_tokens + truncated.insert(0, message) # Maintain chronological order + else: + break + + if len(truncated) < len(history): + logger.info( + f"Truncated chat history from {len(history)} to {len(truncated)} messages " + f"to fit within {max_tokens:,} token budget" + ) + + return truncated + def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): - gen_kwargs = {"model": self.gpt_model, "messages": messages} + # Pre-flight context validation - fail fast if over limit + self._validate_context_size(messages) + + gen_kwargs = {"model": self.model_id, "messages": messages} + if self.attachments: + # Usage accounting only; stripped before provider invocation. 
+ gen_kwargs["_usage_attachments"] = self.attachments if ( hasattr(self.llm, "_supports_tools") diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py index 49dd29d8..92be75f6 100644 --- a/application/agents/react_agent.py +++ b/application/agents/react_agent.py @@ -86,7 +86,7 @@ class ReActAgent(BaseAgent): messages = [{"role": "user", "content": plan_prompt}] plan_stream = self.llm.gen_stream( - model=self.gpt_model, + model=self.model_id, messages=messages, tools=self.tools if self.tools else None, ) @@ -151,7 +151,7 @@ class ReActAgent(BaseAgent): messages = [{"role": "user", "content": final_prompt}] final_stream = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=None + model=self.model_id, messages=messages, tools=None ) if log_context: @@ -235,4 +235,4 @@ class ReActAgent(BaseAgent): ) except Exception as e: logger.error(f"Error extracting content: {e}") - return "".join(collected) + return "".join(collected) \ No newline at end of file diff --git a/application/agents/tools/api_body_serializer.py b/application/agents/tools/api_body_serializer.py new file mode 100644 index 00000000..d23d1fcf --- /dev/null +++ b/application/agents/tools/api_body_serializer.py @@ -0,0 +1,323 @@ +import base64 +import json +import logging +from enum import Enum +from typing import Any, Dict, Optional, Union +from urllib.parse import quote, urlencode + +logger = logging.getLogger(__name__) + + +class ContentType(str, Enum): + """Supported content types for request bodies.""" + + JSON = "application/json" + FORM_URLENCODED = "application/x-www-form-urlencoded" + MULTIPART_FORM_DATA = "multipart/form-data" + TEXT_PLAIN = "text/plain" + XML = "application/xml" + OCTET_STREAM = "application/octet-stream" + + +class RequestBodySerializer: + """Serializes request bodies according to content-type and OpenAPI 3.1 spec.""" + + @staticmethod + def serialize( + body_data: Dict[str, Any], + content_type: str = ContentType.JSON, + encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> tuple[Union[str, bytes], Dict[str, str]]: + """ + Serialize body data to appropriate format. 
+ + Args: + body_data: Dictionary of body parameters + content_type: Content-Type header value + encoding_rules: OpenAPI Encoding Object rules per field + + Returns: + Tuple of (serialized_body, updated_headers_dict) + + Raises: + ValueError: If serialization fails + """ + if not body_data: + return None, {} + + try: + content_type_lower = content_type.lower().split(";")[0].strip() + + if content_type_lower == ContentType.JSON: + return RequestBodySerializer._serialize_json(body_data) + + elif content_type_lower == ContentType.FORM_URLENCODED: + return RequestBodySerializer._serialize_form_urlencoded( + body_data, encoding_rules + ) + + elif content_type_lower == ContentType.MULTIPART_FORM_DATA: + return RequestBodySerializer._serialize_multipart_form_data( + body_data, encoding_rules + ) + + elif content_type_lower == ContentType.TEXT_PLAIN: + return RequestBodySerializer._serialize_text_plain(body_data) + + elif content_type_lower == ContentType.XML: + return RequestBodySerializer._serialize_xml(body_data) + + elif content_type_lower == ContentType.OCTET_STREAM: + return RequestBodySerializer._serialize_octet_stream(body_data) + + else: + logger.warning( + f"Unknown content type: {content_type}, treating as JSON" + ) + return RequestBodySerializer._serialize_json(body_data) + + except Exception as e: + logger.error(f"Error serializing body: {str(e)}", exc_info=True) + raise ValueError(f"Failed to serialize request body: {str(e)}") + + @staticmethod + def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]: + """Serialize body as JSON per OpenAPI spec.""" + try: + serialized = json.dumps( + body_data, separators=(",", ":"), ensure_ascii=False + ) + headers = {"Content-Type": ContentType.JSON.value} + return serialized, headers + except (TypeError, ValueError) as e: + raise ValueError(f"Failed to serialize JSON body: {str(e)}") + + @staticmethod + def _serialize_form_urlencoded( + body_data: Dict[str, Any], + encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> tuple[str, Dict[str, str]]: + """Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986.""" + encoding_rules = encoding_rules or {} + params = [] + + for key, value in body_data.items(): + if value is None: + continue + + rule = encoding_rules.get(key, {}) + style = rule.get("style", "form") + explode = rule.get("explode", style == "form") + content_type = rule.get("contentType", "text/plain") + + serialized_value = RequestBodySerializer._serialize_form_value( + value, style, explode, content_type, key + ) + + if isinstance(serialized_value, list): + for sv in serialized_value: + params.append((key, sv)) + else: + params.append((key, serialized_value)) + + # Use standard urlencode (replaces space with +) + serialized = urlencode(params, safe="") + headers = {"Content-Type": ContentType.FORM_URLENCODED.value} + return serialized, headers + + @staticmethod + def _serialize_form_value( + value: Any, style: str, explode: bool, content_type: str, key: str + ) -> Union[str, list]: + """Serialize individual form value with encoding rules.""" + if isinstance(value, dict): + if content_type == "application/json": + return json.dumps(value, separators=(",", ":")) + elif content_type == "application/xml": + return RequestBodySerializer._dict_to_xml(value) + else: + if style == "deepObject" and explode: + return [ + f"{RequestBodySerializer._percent_encode(str(v))}" + for v in value.values() + ] + elif explode: + return [ + f"{RequestBodySerializer._percent_encode(str(v))}" + for v in 
value.values() + ] + else: + pairs = [f"{k},{v}" for k, v in value.items()] + return RequestBodySerializer._percent_encode(",".join(pairs)) + + elif isinstance(value, (list, tuple)): + if explode: + return [ + RequestBodySerializer._percent_encode(str(item)) for item in value + ] + else: + return RequestBodySerializer._percent_encode( + ",".join(str(v) for v in value) + ) + + else: + return RequestBodySerializer._percent_encode(str(value)) + + @staticmethod + def _serialize_multipart_form_data( + body_data: Dict[str, Any], + encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> tuple[bytes, Dict[str, str]]: + """ + Serialize body as multipart/form-data per RFC7578. + + Supports file uploads and encoding rules. + """ + import secrets + + encoding_rules = encoding_rules or {} + boundary = f"----DocsGPT{secrets.token_hex(16)}" + parts = [] + + for key, value in body_data.items(): + if value is None: + continue + + rule = encoding_rules.get(key, {}) + content_type = rule.get("contentType", "text/plain") + headers_rule = rule.get("headers", {}) + + part = RequestBodySerializer._create_multipart_part( + key, value, content_type, headers_rule + ) + parts.append(part) + + body_bytes = f"--{boundary}\r\n".encode("utf-8") + body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8") + body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8") + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary}", + } + return body_bytes, headers + + @staticmethod + def _create_multipart_part( + name: str, value: Any, content_type: str, headers_rule: Dict[str, Any] + ) -> str: + """Create a single multipart/form-data part.""" + headers = [ + f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"' + ] + + if isinstance(value, bytes): + if content_type == "application/octet-stream": + value_encoded = base64.b64encode(value).decode("utf-8") + else: + value_encoded = value.decode("utf-8", errors="replace") + headers.append(f"Content-Type: {content_type}") + headers.append("Content-Transfer-Encoding: base64") + elif isinstance(value, dict): + if content_type == "application/json": + value_encoded = json.dumps(value, separators=(",", ":")) + elif content_type == "application/xml": + value_encoded = RequestBodySerializer._dict_to_xml(value) + else: + value_encoded = str(value) + headers.append(f"Content-Type: {content_type}") + elif isinstance(value, str) and content_type != "text/plain": + try: + if content_type == "application/json": + json.loads(value) + value_encoded = value + elif content_type == "application/xml": + value_encoded = value + else: + value_encoded = str(value) + except json.JSONDecodeError: + value_encoded = str(value) + headers.append(f"Content-Type: {content_type}") + else: + value_encoded = str(value) + if content_type != "text/plain": + headers.append(f"Content-Type: {content_type}") + + part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n" + return part + + @staticmethod + def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]: + """Serialize body as plain text.""" + if len(body_data) == 1: + value = list(body_data.values())[0] + return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value} + else: + text = "\n".join(f"{k}: {v}" for k, v in body_data.items()) + return text, {"Content-Type": ContentType.TEXT_PLAIN.value} + + @staticmethod + def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]: + """Serialize body as XML.""" + xml_str = 
RequestBodySerializer._dict_to_xml(body_data)
+        return xml_str, {"Content-Type": ContentType.XML.value}
+
+    @staticmethod
+    def _serialize_octet_stream(
+        body_data: Dict[str, Any],
+    ) -> tuple[bytes, Dict[str, str]]:
+        """Serialize body as binary octet stream."""
+        if isinstance(body_data, bytes):
+            return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
+        elif isinstance(body_data, str):
+            return body_data.encode("utf-8"), {
+                "Content-Type": ContentType.OCTET_STREAM.value
+            }
+        else:
+            serialized = json.dumps(body_data)
+            return serialized.encode("utf-8"), {
+                "Content-Type": ContentType.OCTET_STREAM.value
+            }
+
+    @staticmethod
+    def _percent_encode(value: str, safe_chars: str = "") -> str:
+        """
+        Percent-encode per RFC3986.
+
+        Args:
+            value: String to encode
+            safe_chars: Additional characters to not encode
+        """
+        return quote(value, safe=safe_chars)
+
+    @staticmethod
+    def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
+        """
+        Convert dict to simple XML format.
+        """
+
+        def build_xml(obj: Any, name: str) -> str:
+            if isinstance(obj, dict):
+                inner = "".join(build_xml(v, k) for k, v in obj.items())
+                return f"<{name}>{inner}</{name}>"
+            elif isinstance(obj, (list, tuple)):
+                items = "".join(
+                    build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
+                    for item in obj
+                )
+                return items
+            else:
+                return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
+
+        root = build_xml(data, root_name)
+        return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
+
+    @staticmethod
+    def _escape_xml(value: str) -> str:
+        """Escape XML special characters."""
+        return (
+            value.replace("&", "&amp;")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace('"', "&quot;")
+            .replace("'", "&apos;")
+        )
diff --git a/application/agents/tools/api_tool.py b/application/agents/tools/api_tool.py
index 063313c4..e010b51b 100644
--- a/application/agents/tools/api_tool.py
+++ b/application/agents/tools/api_tool.py
@@ -1,72 +1,280 @@
 import json
+import logging
+import re
+from typing import Any, Dict, Optional
+from urllib.parse import urlencode
 
 import requests
+
+from application.agents.tools.api_body_serializer import (
+    ContentType,
+    RequestBodySerializer,
+)
 from application.agents.tools.base import Tool
+from application.core.url_validation import validate_url, SSRFError
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TIMEOUT = 90  # seconds
 
 
 class APITool(Tool):
     """
     API Tool
-    A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
+    A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
""" def __init__(self, config): self.config = config self.url = config.get("url", "") self.method = config.get("method", "GET") - self.headers = config.get("headers", {"Content-Type": "application/json"}) + self.headers = config.get("headers", {}) self.query_params = config.get("query_params", {}) + self.body_content_type = config.get("body_content_type", ContentType.JSON) + self.body_encoding_rules = config.get("body_encoding_rules", {}) def execute_action(self, action_name, **kwargs): + """Execute an API action with the given arguments.""" return self._make_api_call( - self.url, self.method, self.headers, self.query_params, kwargs + self.url, + self.method, + self.headers, + self.query_params, + kwargs, + self.body_content_type, + self.body_encoding_rules, ) - def _make_api_call(self, url, method, headers, query_params, body): - if query_params: - url = f"{url}?{requests.compat.urlencode(query_params)}" - # if isinstance(body, dict): - # body = json.dumps(body) + def _make_api_call( + self, + url: str, + method: str, + headers: Dict[str, str], + query_params: Dict[str, Any], + body: Dict[str, Any], + content_type: str = ContentType.JSON, + encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + """ + Make an API call with proper body serialization and error handling. + + Args: + url: API endpoint URL + method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS) + headers: Request headers dict + query_params: URL query parameters + body: Request body as dict + content_type: Content-Type for serialization + encoding_rules: OpenAPI encoding rules + + Returns: + Dict with status_code, data, and message + """ + request_url = url + request_headers = headers.copy() if headers else {} + response = None + + # Validate URL to prevent SSRF attacks try: - print(f"Making API call: {method} {url} with body: {body}") - if body == "{}": - body = None - response = requests.request(method, url, headers=headers, data=body) - response.raise_for_status() - content_type = response.headers.get( - "Content-Type", "application/json" - ).lower() - if "application/json" in content_type: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + + try: + path_params_used = set() + if query_params: + for match in re.finditer(r"\{([^}]+)\}", request_url): + param_name = match.group(1) + if param_name in query_params: + request_url = request_url.replace( + f"{{{param_name}}}", str(query_params[param_name]) + ) + path_params_used.add(param_name) + remaining_params = { + k: v for k, v in query_params.items() if k not in path_params_used + } + if remaining_params: + query_string = urlencode(remaining_params) + separator = "&" if "?" in request_url else "?" + request_url = f"{request_url}{separator}{query_string}" + + # Re-validate URL after parameter substitution to prevent SSRF via path params + try: + validate_url(request_url) + except SSRFError as e: + logger.error(f"URL validation failed after parameter substitution: {e}") + return { + "status_code": None, + "message": f"URL validation error: {e}", + "data": None, + } + + # Serialize body based on content type + + if body and body != {}: try: - data = response.json() - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}. 
Raw response: {response.text}") + serialized_body, body_headers = RequestBodySerializer.serialize( + body, content_type, encoding_rules + ) + request_headers.update(body_headers) + except ValueError as e: + logger.error(f"Body serialization failed: {str(e)}") return { - "status_code": response.status_code, - "message": f"API call returned invalid JSON. Error: {e}", - "data": response.text, + "status_code": None, + "message": f"Body serialization error: {str(e)}", + "data": None, } - elif "text/" in content_type or "application/xml" in content_type: - data = response.text - elif not response.content: - data = None else: - print(f"Unsupported content type: {content_type}") - data = response.content + serialized_body = None + if "Content-Type" not in request_headers and method not in [ + "GET", + "HEAD", + "DELETE", + ]: + request_headers["Content-Type"] = ContentType.JSON + logger.debug( + f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}" + ) + + if method.upper() == "GET": + response = requests.get( + request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT + ) + elif method.upper() == "POST": + response = requests.post( + request_url, + data=serialized_body, + headers=request_headers, + timeout=DEFAULT_TIMEOUT, + ) + elif method.upper() == "PUT": + response = requests.put( + request_url, + data=serialized_body, + headers=request_headers, + timeout=DEFAULT_TIMEOUT, + ) + elif method.upper() == "DELETE": + response = requests.delete( + request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT + ) + elif method.upper() == "PATCH": + response = requests.patch( + request_url, + data=serialized_body, + headers=request_headers, + timeout=DEFAULT_TIMEOUT, + ) + elif method.upper() == "HEAD": + response = requests.head( + request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT + ) + elif method.upper() == "OPTIONS": + response = requests.options( + request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT + ) + else: + return { + "status_code": None, + "message": f"Unsupported HTTP method: {method}", + "data": None, + } + response.raise_for_status() + + data = self._parse_response(response) return { "status_code": response.status_code, "data": data, "message": "API call successful.", } + except requests.exceptions.Timeout: + logger.error(f"Request timeout for {request_url}") + return { + "status_code": None, + "message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)", + "data": None, + } + except requests.exceptions.ConnectionError as e: + logger.error(f"Connection error: {str(e)}") + return { + "status_code": None, + "message": f"Connection error: {str(e)}", + "data": None, + } + except requests.exceptions.HTTPError as e: + logger.error(f"HTTP error {response.status_code}: {str(e)}") + try: + error_data = response.json() + except (json.JSONDecodeError, ValueError): + error_data = response.text + return { + "status_code": response.status_code, + "message": f"HTTP Error {response.status_code}", + "data": error_data, + } except requests.exceptions.RequestException as e: + logger.error(f"Request failed: {str(e)}") return { "status_code": response.status_code if response else None, "message": f"API call failed: {str(e)}", + "data": None, + } + except Exception as e: + logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True) + return { + "status_code": None, + "message": f"Unexpected error: {str(e)}", + "data": None, } + def _parse_response(self, response: requests.Response) -> Any: + """ + Parse response based on 
Content-Type header. + + Supports: JSON, XML, plain text, binary data. + """ + content_type = response.headers.get("Content-Type", "").lower() + + if not response.content: + return None + # JSON response + + if "application/json" in content_type: + try: + return response.json() + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON response: {str(e)}") + return response.text + # XML response + + elif "application/xml" in content_type or "text/xml" in content_type: + return response.text + # Plain text response + + elif "text/plain" in content_type or "text/html" in content_type: + return response.text + # Binary/unknown response + + else: + # Try to decode as text first, fall back to base64 + + try: + return response.text + except (UnicodeDecodeError, AttributeError): + import base64 + + return base64.b64encode(response.content).decode("utf-8") + def get_actions_metadata(self): + """Return metadata for available actions (none for API Tool - actions are user-defined).""" return [] def get_config_requirements(self): + """Return configuration requirements for the tool.""" return {} diff --git a/application/agents/tools/brave.py b/application/agents/tools/brave.py index 33843ac0..66b21b10 100644 --- a/application/agents/tools/brave.py +++ b/application/agents/tools/brave.py @@ -1,6 +1,11 @@ +import logging + import requests + from application.agents.tools.base import Tool +logger = logging.getLogger(__name__) + class BraveSearchTool(Tool): """ @@ -41,7 +46,7 @@ class BraveSearchTool(Tool): """ Performs a web search using the Brave Search API. """ - print(f"Performing Brave web search for: {query}") + logger.debug("Performing Brave web search for: %s", query) url = f"{self.base_url}/web/search" @@ -94,7 +99,7 @@ class BraveSearchTool(Tool): """ Performs an image search using the Brave Search API. 
""" - print(f"Performing Brave image search for: {query}") + logger.debug("Performing Brave image search for: %s", query) url = f"{self.base_url}/images/search" @@ -177,6 +182,10 @@ class BraveSearchTool(Tool): return { "token": { "type": "string", + "label": "API Key", "description": "Brave Search API key for authentication", + "required": True, + "secret": True, + "order": 1, }, } diff --git a/application/agents/tools/duckduckgo.py b/application/agents/tools/duckduckgo.py index 87c1bc7e..19f59069 100644 --- a/application/agents/tools/duckduckgo.py +++ b/application/agents/tools/duckduckgo.py @@ -1,5 +1,14 @@ +import logging +import time +from typing import Any, Dict, Optional + from application.agents.tools.base import Tool -from duckduckgo_search import DDGS + +logger = logging.getLogger(__name__) + +MAX_RETRIES = 3 +RETRY_DELAY = 2.0 +DEFAULT_TIMEOUT = 15 class DuckDuckGoSearchTool(Tool): @@ -10,71 +19,123 @@ class DuckDuckGoSearchTool(Tool): def __init__(self, config): self.config = config + self.timeout = config.get("timeout", DEFAULT_TIMEOUT) + + def _get_ddgs_client(self): + from ddgs import DDGS + + return DDGS(timeout=self.timeout) + + def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]: + last_error = None + for attempt in range(1, MAX_RETRIES + 1): + try: + results = operation() + return { + "status_code": 200, + "results": list(results) if results else [], + "message": f"{operation_name} completed successfully.", + } + except Exception as e: + last_error = e + error_str = str(e).lower() + if "ratelimit" in error_str or "429" in error_str: + if attempt < MAX_RETRIES: + delay = RETRY_DELAY * attempt + logger.warning( + f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})" + ) + time.sleep(delay) + continue + logger.error(f"{operation_name} failed: {e}") + break + return { + "status_code": 500, + "results": [], + "message": f"{operation_name} failed: {str(last_error)}", + } def execute_action(self, action_name, **kwargs): actions = { "ddg_web_search": self._web_search, "ddg_image_search": self._image_search, + "ddg_news_search": self._news_search, } - - if action_name in actions: - return actions[action_name](**kwargs) - else: + if action_name not in actions: raise ValueError(f"Unknown action: {action_name}") + return actions[action_name](**kwargs) def _web_search( self, - query, - max_results=5, - ): - print(f"Performing DuckDuckGo web search for: {query}") + query: str, + max_results: int = 5, + region: str = "wt-wt", + safesearch: str = "moderate", + timelimit: Optional[str] = None, + ) -> Dict[str, Any]: + logger.info(f"DuckDuckGo web search: {query}") - try: - results = DDGS().text( + def operation(): + client = self._get_ddgs_client() + return client.text( query, - max_results=max_results, + region=region, + safesearch=safesearch, + timelimit=timelimit, + max_results=min(max_results, 20), ) - return { - "status_code": 200, - "results": results, - "message": "Web search completed successfully.", - } - except Exception as e: - return { - "status_code": 500, - "message": f"Web search failed: {str(e)}", - } + return self._execute_with_retry(operation, "Web search") def _image_search( self, - query, - max_results=5, - ): - print(f"Performing DuckDuckGo image search for: {query}") + query: str, + max_results: int = 5, + region: str = "wt-wt", + safesearch: str = "moderate", + timelimit: Optional[str] = None, + ) -> Dict[str, Any]: + logger.info(f"DuckDuckGo image search: {query}") - try: - results = DDGS().images( - 
keywords=query, - max_results=max_results, + def operation(): + client = self._get_ddgs_client() + return client.images( + query, + region=region, + safesearch=safesearch, + timelimit=timelimit, + max_results=min(max_results, 50), ) - return { - "status_code": 200, - "results": results, - "message": "Image search completed successfully.", - } - except Exception as e: - return { - "status_code": 500, - "message": f"Image search failed: {str(e)}", - } + return self._execute_with_retry(operation, "Image search") + + def _news_search( + self, + query: str, + max_results: int = 5, + region: str = "wt-wt", + safesearch: str = "moderate", + timelimit: Optional[str] = None, + ) -> Dict[str, Any]: + logger.info(f"DuckDuckGo news search: {query}") + + def operation(): + client = self._get_ddgs_client() + return client.news( + query, + region=region, + safesearch=safesearch, + timelimit=timelimit, + max_results=min(max_results, 20), + ) + + return self._execute_with_retry(operation, "News search") def get_actions_metadata(self): return [ { "name": "ddg_web_search", - "description": "Perform a web search using DuckDuckGo.", + "description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.", "parameters": { "type": "object", "properties": { @@ -84,7 +145,15 @@ class DuckDuckGoSearchTool(Tool): }, "max_results": { "type": "integer", - "description": "Number of results to return (default: 5)", + "description": "Number of results (default: 5, max: 20)", + }, + "region": { + "type": "string", + "description": "Region code (default: wt-wt for worldwide, us-en for US)", + }, + "timelimit": { + "type": "string", + "description": "Time filter: d (day), w (week), m (month), y (year)", }, }, "required": ["query"], @@ -92,17 +161,43 @@ class DuckDuckGoSearchTool(Tool): }, { "name": "ddg_image_search", - "description": "Perform an image search using DuckDuckGo.", + "description": "Search for images using DuckDuckGo. Returns image URLs and metadata.", "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": "Search query", + "description": "Image search query", }, "max_results": { "type": "integer", - "description": "Number of results to return (default: 5, max: 50)", + "description": "Number of results (default: 5, max: 50)", + }, + "region": { + "type": "string", + "description": "Region code (default: wt-wt for worldwide)", + }, + }, + "required": ["query"], + }, + }, + { + "name": "ddg_news_search", + "description": "Search for news articles using DuckDuckGo. 
Returns recent news.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "News search query", + }, + "max_results": { + "type": "integer", + "description": "Number of results (default: 5, max: 20)", + }, + "timelimit": { + "type": "string", + "description": "Time filter: d (day), w (week), m (month)", }, }, "required": ["query"], diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index b21e1363..265688ea 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -1,20 +1,12 @@ import asyncio import base64 +import concurrent.futures import json import logging import time from typing import Any, Dict, List, Optional from urllib.parse import parse_qs, urlparse -from application.agents.tools.base import Tool -from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task -from application.cache import get_redis_instance - -from application.core.mongo_db import MongoDB - -from application.core.settings import settings - -from application.security.encryption import decrypt_credentials from fastmcp import Client from fastmcp.client.auth import BearerAuth from fastmcp.client.transports import ( @@ -24,10 +16,18 @@ from fastmcp.client.transports import ( ) from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken - from pydantic import AnyHttpUrl, ValidationError from redis import Redis +from application.agents.tools.base import Tool +from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task +from application.cache import get_redis_instance +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.security.encryption import decrypt_credentials + +logger = logging.getLogger(__name__) + mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] @@ -56,6 +56,7 @@ class MCPTool(Tool): - args: Arguments for STDIO transport - oauth_scopes: OAuth scopes for oauth auth type - oauth_client_name: OAuth client name for oauth auth type + - query_mode: If True, use non-interactive OAuth (fail-fast on 401) user_id: User ID for decrypting credentials (required if encrypted_credentials exist) """ self.config = config @@ -76,23 +77,40 @@ class MCPTool(Tool): self.oauth_scopes = config.get("oauth_scopes", []) self.oauth_task_id = config.get("oauth_task_id", None) self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP") - self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback" + self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri")) self.available_tools = [] self._cache_key = self._generate_cache_key() self._client = None - - # Only validate and setup if server_url is provided and not OAuth + self.query_mode = config.get("query_mode", False) if self.server_url and self.auth_type != "oauth": self._setup_client() + def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str: + if configured_redirect_uri: + return configured_redirect_uri.rstrip("/") + + explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None) + if explicit: + return explicit.rstrip("/") + + connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None) + if connector_base: + parsed = urlparse(connector_base) + if parsed.scheme and parsed.netloc: + return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback" + + return 
f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback" + def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" auth_key = "" if self.auth_type == "oauth": scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none" - auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}" + auth_key = ( + f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}" + ) elif self.auth_type in ["bearer"]: token = self.auth_credentials.get( "bearer_token", "" @@ -109,11 +127,10 @@ class MCPTool(Tool): return f"{self.server_url}#{self.transport_type}#{auth_key}" def _setup_client(self): - """Setup FastMCP client with proper transport and authentication.""" global _mcp_clients_cache if self._cache_key in _mcp_clients_cache: cached_data = _mcp_clients_cache[self._cache_key] - if time.time() - cached_data["created_at"] < 1800: + if time.time() - cached_data["created_at"] < 300: self._client = cached_data["client"] return else: @@ -123,15 +140,25 @@ class MCPTool(Tool): if self.auth_type == "oauth": redis_client = get_redis_instance() - auth = DocsGPTOAuth( - mcp_url=self.server_url, - scopes=self.oauth_scopes, - redis_client=redis_client, - redirect_uri=self.redirect_uri, - task_id=self.oauth_task_id, - db=db, - user_id=self.user_id, - ) + if self.query_mode: + auth = NonInteractiveOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + db=db, + user_id=self.user_id, + ) + else: + auth = DocsGPTOAuth( + mcp_url=self.server_url, + scopes=self.oauth_scopes, + redis_client=redis_client, + redirect_uri=self.redirect_uri, + task_id=self.oauth_task_id, + db=db, + user_id=self.user_id, + ) elif self.auth_type == "bearer": token = self.auth_credentials.get( "bearer_token", "" @@ -169,6 +196,8 @@ class MCPTool(Tool): transport_type = "http" else: transport_type = self.transport_type + if transport_type == "stdio": + raise ValueError("STDIO transport is disabled") if transport_type == "sse": headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"}) return SSETransport(url=self.server_url, headers=headers) @@ -231,38 +260,53 @@ class MCPTool(Tool): else: raise Exception(f"Unknown operation: {operation}") + _ERROR_MAP = [ + (concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"), + (ConnectionRefusedError, lambda *_: "Connection refused"), + ] + + _ERROR_PATTERNS = { + ("403", "Forbidden"): "Access denied (403 Forbidden)", + ("401", "Unauthorized"): "Authentication failed (401 Unauthorized)", + ("ECONNREFUSED",): "Connection refused", + ("SSL", "certificate"): "SSL/TLS error", + } + def _run_async_operation(self, operation: str, *args, **kwargs): - """Run async operation in sync context.""" try: try: - loop = asyncio.get_running_loop() - import concurrent.futures - - def run_in_thread(): - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - try: - return new_loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - new_loop.close() - + asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) + future = executor.submit( + self._run_in_new_loop, operation, *args, **kwargs + ) return future.result(timeout=self.timeout) except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - self._execute_with_client(operation, *args, **kwargs) - ) - finally: - 
loop.close() + return self._run_in_new_loop(operation, *args, **kwargs) except Exception as e: - print(f"Error occurred while running async operation: {e}") - raise + raise self._map_error(operation, e) from e + raise self._map_error(operation, e) from e + + def _run_in_new_loop(self, operation, *args, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._execute_with_client(operation, *args, **kwargs) + ) + finally: + loop.close() + + def _map_error(self, operation: str, exc: Exception) -> Exception: + for exc_type, msg_fn in self._ERROR_MAP: + if isinstance(exc, exc_type): + return Exception(msg_fn(operation, self.timeout, exc)) + error_msg = str(exc) + for patterns, friendly in self._ERROR_PATTERNS.items(): + if any(p.lower() in error_msg.lower() for p in patterns): + return Exception(friendly) + logger.error("MCP %s failed: %s", operation, exc) + return exc def discover_tools(self) -> List[Dict]: """ @@ -283,16 +327,6 @@ class MCPTool(Tool): raise Exception(f"Failed to discover tools from MCP server: {str(e)}") def execute_action(self, action_name: str, **kwargs) -> Any: - """ - Execute an action on the remote MCP server using FastMCP. - - Args: - action_name: Name of the action to execute - **kwargs: Parameters for the action - - Returns: - Result from the MCP server - """ if not self.server_url: raise Exception("No MCP server configured") if not self._client: @@ -308,7 +342,37 @@ class MCPTool(Tool): ) return self._format_result(result) except Exception as e: - raise Exception(f"Failed to execute action '{action_name}': {str(e)}") + error_msg = str(e) + lower_msg = error_msg.lower() + is_auth_error = ( + "401" in error_msg + or "unauthorized" in lower_msg + or "session expired" in lower_msg + or "re-authorize" in lower_msg + ) + if is_auth_error: + if self.auth_type == "oauth": + raise Exception( + f"Action '{action_name}' failed: OAuth session expired. " + "Please re-authorize this MCP server in tool settings." + ) from e + global _mcp_clients_cache + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + self._setup_client() + try: + result = self._run_async_operation( + "call_tool", action_name, **cleaned_kwargs + ) + return self._format_result(result) + except Exception as retry_e: + raise Exception( + f"Action '{action_name}' failed after re-auth attempt: {retry_e}. " + "Your credentials may have expired — please re-authorize in tool settings." + ) from retry_e + raise Exception( + f"Failed to execute action '{action_name}': {error_msg}" + ) from e def _format_result(self, result) -> Dict: """Format FastMCP result to match expected format.""" @@ -331,23 +395,35 @@ class MCPTool(Tool): return result def test_connection(self) -> Dict: - """ - Test the connection to the MCP server and validate functionality. 
- - Returns: - Dictionary with connection test results including tool count - """ if not self.server_url: return { "success": False, - "message": "No MCP server URL configured", + "message": "No server URL configured", + "tools_count": 0, + } + try: + parsed = urlparse(self.server_url) + if parsed.scheme not in ("http", "https"): + return { + "success": False, + "message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://", + "tools_count": 0, + } + except Exception: + return { + "success": False, + "message": "Invalid URL format", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": "ConfigurationError", } if not self._client: - self._setup_client() + try: + self._setup_client() + except Exception as e: + return { + "success": False, + "message": f"Client init failed: {str(e)}", + "tools_count": 0, + } try: if self.auth_type == "oauth": return self._test_oauth_connection() @@ -358,56 +434,94 @@ class MCPTool(Tool): "success": False, "message": f"Connection failed: {str(e)}", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": type(e).__name__, } def _test_regular_connection(self) -> Dict: - """Test connection for non-OAuth auth types.""" + ping_ok = False + ping_error = None try: self._run_async_operation("ping") - ping_success = True - except Exception: - ping_success = False - tools = self.discover_tools() + ping_ok = True + except Exception as e: + ping_error = str(e) - message = f"Successfully connected to MCP server. Found {len(tools)} tools." - if not ping_success: - message += " (Ping not supported, but tool discovery worked)" - return { - "success": True, - "message": message, - "tools_count": len(tools), - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "ping_supported": ping_success, - "tools": [tool.get("name", "unknown") for tool in tools], - } - - def _test_oauth_connection(self) -> Dict: - """Test connection for OAuth auth type with proper async handling.""" try: - task = mcp_oauth_task.delay(config=self.config, user=self.user_id) - if not task: - raise Exception("Failed to start OAuth authentication") - return { - "success": True, - "requires_oauth": True, - "task_id": task.id, - "status": "pending", - "message": "OAuth flow started", - } + tools = self.discover_tools() except Exception as e: return { "success": False, - "message": f"OAuth connection failed: {str(e)}", + "message": f"Connection failed: {ping_error or str(e)}", "tools_count": 0, - "transport_type": self.transport_type, - "auth_type": self.auth_type, - "error_type": type(e).__name__, } + if not tools and not ping_ok: + return { + "success": False, + "message": f"Connection failed: {ping_error or 'No tools found'}", + "tools_count": 0, + } + + return { + "success": True, + "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.", + "tools_count": len(tools), + "tools": [ + { + "name": tool.get("name", "unknown"), + "description": tool.get("description", ""), + } + for tool in tools + ], + } + + def _test_oauth_connection(self) -> Dict: + storage = DBTokenStorage( + server_url=self.server_url, user_id=self.user_id, db_client=db + ) + loop = asyncio.new_event_loop() + try: + tokens = loop.run_until_complete(storage.get_tokens()) + finally: + loop.close() + + if tokens and tokens.access_token: + self.query_mode = True + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + self._setup_client() + try: + tools = 
self.discover_tools() + return { + "success": True, + "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.", + "tools_count": len(tools), + "tools": [ + { + "name": t.get("name", "unknown"), + "description": t.get("description", ""), + } + for t in tools + ], + } + except Exception as e: + logger.warning("OAuth token validation failed: %s", e) + _mcp_clients_cache.pop(self._cache_key, None) + self._client = None + + return self._start_oauth_task() + + def _start_oauth_task(self) -> Dict: + task_config = self.config.copy() + task_config.pop("query_mode", None) + result = mcp_oauth_task.delay(task_config, self.user_id) + return { + "success": False, + "requires_oauth": True, + "task_id": result.id, + "message": "OAuth authorization required.", + "tools_count": 0, + } + def get_actions_metadata(self) -> List[Dict]: """ Get metadata for all available actions. @@ -453,107 +567,88 @@ class MCPTool(Tool): return actions def get_config_requirements(self) -> Dict: - """Get configuration requirements for the MCP tool.""" return { "server_url": { "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)", + "label": "Server URL", + "description": "URL of the remote MCP server", "required": True, - }, - "transport_type": { - "type": "string", - "description": "Transport type for connection", - "enum": ["auto", "sse", "http", "stdio"], - "default": "auto", - "required": False, - "help": { - "auto": "Automatically detect best transport", - "sse": "Server-Sent Events (for real-time streaming)", - "http": "HTTP streaming (recommended for production)", - "stdio": "Standard I/O (for local servers)", - }, + "secret": False, + "order": 1, }, "auth_type": { "type": "string", - "description": "Authentication type", + "label": "Authentication Type", + "description": "Authentication method for the MCP server", "enum": ["none", "bearer", "oauth", "api_key", "basic"], "default": "none", "required": True, - "help": { - "none": "No authentication", - "bearer": "Bearer token authentication", - "oauth": "OAuth 2.1 authentication (with frontend integration)", - "api_key": "API key authentication", - "basic": "Basic authentication", - }, + "secret": False, + "order": 2, }, - "auth_credentials": { - "type": "object", - "description": "Authentication credentials (varies by auth_type)", + "api_key": { + "type": "string", + "label": "API Key", + "description": "API key for authentication", "required": False, - "properties": { - "bearer_token": { - "type": "string", - "description": "Bearer token for bearer auth", - }, - "access_token": { - "type": "string", - "description": "Access token for OAuth (if pre-obtained)", - }, - "api_key": { - "type": "string", - "description": "API key for api_key auth", - }, - "api_key_header": { - "type": "string", - "description": "Header name for API key (default: X-API-Key)", - }, - "username": { - "type": "string", - "description": "Username for basic auth", - }, - "password": { - "type": "string", - "description": "Password for basic auth", - }, - }, + "secret": True, + "order": 3, + "depends_on": {"auth_type": "api_key"}, + }, + "api_key_header": { + "type": "string", + "label": "API Key Header", + "description": "Header name for API key (default: X-API-Key)", + "default": "X-API-Key", + "required": False, + "secret": False, + "order": 4, + "depends_on": {"auth_type": "api_key"}, + }, + "bearer_token": { + "type": "string", + "label": "Bearer Token", + "description": "Bearer 
token for authentication", + "required": False, + "secret": True, + "order": 3, + "depends_on": {"auth_type": "bearer"}, + }, + "username": { + "type": "string", + "label": "Username", + "description": "Username for basic authentication", + "required": False, + "secret": False, + "order": 3, + "depends_on": {"auth_type": "basic"}, + }, + "password": { + "type": "string", + "label": "Password", + "description": "Password for basic authentication", + "required": False, + "secret": True, + "order": 4, + "depends_on": {"auth_type": "basic"}, }, "oauth_scopes": { - "type": "array", - "description": "OAuth scopes to request (for oauth auth_type)", - "items": {"type": "string"}, - "required": False, - "default": [], - }, - "oauth_client_name": { "type": "string", - "description": "Client name for OAuth registration (for oauth auth_type)", - "default": "DocsGPT-MCP", - "required": False, - }, - "headers": { - "type": "object", - "description": "Custom headers to send with requests", + "label": "OAuth Scopes", + "description": "Comma-separated OAuth scopes to request", "required": False, + "secret": False, + "order": 3, + "depends_on": {"auth_type": "oauth"}, }, "timeout": { - "type": "integer", - "description": "Request timeout in seconds", + "type": "number", + "label": "Timeout (seconds)", + "description": "Request timeout in seconds (1-300)", "default": 30, - "minimum": 1, - "maximum": 300, - "required": False, - }, - "command": { - "type": "string", - "description": "Command to run for STDIO transport (e.g., 'python')", - "required": False, - }, - "args": { - "type": "array", - "description": "Arguments for STDIO command", - "items": {"type": "string"}, "required": False, + "secret": False, + "order": 10, }, } @@ -575,23 +670,8 @@ class DocsGPTOAuth(OAuthClientProvider): user_id=None, db=None, additional_client_metadata: dict[str, Any] | None = None, + skip_redirect_validation: bool = False, ): - """ - Initialize custom OAuth client provider for DocsGPT. 
- - Args: - mcp_url: Full URL to the MCP endpoint - redirect_uri: Custom redirect URI for DocsGPT frontend - redis_client: Redis client for storing auth state - redis_prefix: Prefix for Redis keys - task_id: Task ID for tracking auth status - scopes: OAuth scopes to request - client_name: Name for this client during registration - user_id: User ID for token storage - db: Database instance for token storage - additional_client_metadata: Extra fields for OAuthClientMetadata - """ - self.redirect_uri = redirect_uri self.redis_client = redis_client self.redis_prefix = redis_prefix @@ -614,7 +694,10 @@ class DocsGPTOAuth(OAuthClientProvider): ) storage = DBTokenStorage( - server_url=self.server_base_url, user_id=self.user_id, db_client=self.db + server_url=self.server_base_url, + user_id=self.user_id, + db_client=self.db, + expected_redirect_uri=None if skip_redirect_validation else redirect_uri, ) super().__init__( @@ -646,22 +729,20 @@ class DocsGPTOAuth(OAuthClientProvider): async def redirect_handler(self, authorization_url: str) -> None: """Store auth URL and state in Redis for frontend to use.""" auth_url, state = self._process_auth_url(authorization_url) - logging.info( - "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state - ) + logger.info("Processed auth_url: %s, state: %s", auth_url, state) self.auth_url = auth_url self.extracted_state = state if self.redis_client and self.extracted_state: key = f"{self.redis_prefix}auth_url:{self.extracted_state}" self.redis_client.setex(key, 600, auth_url) - logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key) + logger.info("Stored auth_url in Redis: %s", key) if self.task_id: status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "requires_redirect", - "message": "OAuth authorization required", + "message": "Authorization required", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, @@ -681,7 +762,7 @@ class DocsGPTOAuth(OAuthClientProvider): status_key = f"mcp_oauth_status:{self.task_id}" status_data = { "status": "awaiting_callback", - "message": "Waiting for OAuth callback...", + "message": "Waiting for authorization...", "authorization_url": self.auth_url, "state": self.extracted_state, "requires_oauth": True, @@ -706,7 +787,7 @@ class DocsGPTOAuth(OAuthClientProvider): if self.task_id: status_data = { "status": "callback_received", - "message": "OAuth callback received, completing authentication...", + "message": "Completing authentication...", "task_id": self.task_id, } self.redis_client.setex(status_key, 600, json.dumps(status_data)) @@ -726,14 +807,44 @@ class DocsGPTOAuth(OAuthClientProvider): await asyncio.sleep(poll_interval) self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}") self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}") - raise Exception("OAuth callback timeout: no code received within 5 minutes") + raise Exception("OAuth timeout: no code received within 5 minutes") + + +class NonInteractiveOAuth(DocsGPTOAuth): + """OAuth provider that fails fast on 401 instead of starting interactive auth. + + Used during query execution to prevent the streaming response from blocking + while waiting for user authorization that will never come. 
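+
+    Selected at query time by MCPTool._setup_client; the sketch below mirrors the
+    call in that method (same argument names), included for orientation only:
+
+        if self.query_mode:
+            auth = NonInteractiveOAuth(
+                mcp_url=self.server_url,
+                scopes=self.oauth_scopes,
+                redis_client=redis_client,
+                redirect_uri=self.redirect_uri,
+                db=db,
+                user_id=self.user_id,
+            )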
+ """ + + def __init__(self, **kwargs): + kwargs.setdefault("task_id", None) + kwargs["skip_redirect_validation"] = True + super().__init__(**kwargs) + + async def redirect_handler(self, authorization_url: str) -> None: + raise Exception( + "OAuth session expired — please re-authorize this MCP server in tool settings." + ) + + async def callback_handler(self) -> tuple[str, str | None]: + raise Exception( + "OAuth session expired — please re-authorize this MCP server in tool settings." + ) class DBTokenStorage(TokenStorage): - def __init__(self, server_url: str, user_id: str, db_client): + def __init__( + self, + server_url: str, + user_id: str, + db_client, + expected_redirect_uri: Optional[str] = None, + ): self.server_url = server_url self.user_id = user_id self.db_client = db_client + self.expected_redirect_uri = expected_redirect_uri self.collection = db_client["connector_sessions"] @staticmethod @@ -752,10 +863,9 @@ class DBTokenStorage(TokenStorage): if not doc or "tokens" not in doc: return None try: - tokens = OAuthToken.model_validate(doc["tokens"]) - return tokens + return OAuthToken.model_validate(doc["tokens"]) except ValidationError as e: - logging.error(f"Could not load tokens: {e}") + logger.error("Could not load tokens: %s", e) return None async def set_tokens(self, tokens: OAuthToken) -> None: @@ -765,28 +875,38 @@ class DBTokenStorage(TokenStorage): {"$set": {"tokens": tokens.model_dump()}}, True, ) - logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}") + logger.info("Saved tokens for %s", self.get_base_url(self.server_url)) async def get_client_info(self) -> OAuthClientInformationFull | None: doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key()) if not doc or "client_info" not in doc: + logger.debug( + "No client_info in DB for %s", self.get_base_url(self.server_url) + ) return None try: client_info = OAuthClientInformationFull.model_validate(doc["client_info"]) - tokens = await self.get_tokens() - if tokens is None: - logging.debug( - "No tokens found, clearing client info to force fresh registration." 
- ) - await asyncio.to_thread( - self.collection.update_one, - self.get_db_key(), - {"$unset": {"client_info": ""}}, - ) - return None + if self.expected_redirect_uri: + stored_uris = [ + str(uri).rstrip("/") for uri in client_info.redirect_uris + ] + expected_uri = self.expected_redirect_uri.rstrip("/") + if expected_uri not in stored_uris: + logger.warning( + "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.", + self.get_base_url(self.server_url), + expected_uri, + stored_uris, + ) + await asyncio.to_thread( + self.collection.update_one, + self.get_db_key(), + {"$unset": {"client_info": "", "tokens": ""}}, + ) + return None return client_info except ValidationError as e: - logging.error(f"Could not load client info: {e}") + logger.error("Could not load client info: %s", e) return None def _serialize_client_info(self, info: dict) -> dict: @@ -802,17 +922,17 @@ class DBTokenStorage(TokenStorage): {"$set": {"client_info": serialized_info}}, True, ) - logging.info(f"Saved client info for {self.get_base_url(self.server_url)}") + logger.info("Saved client info for %s", self.get_base_url(self.server_url)) async def clear(self) -> None: await asyncio.to_thread(self.collection.delete_one, self.get_db_key()) - logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}") + logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url)) @classmethod async def clear_all(cls, db_client) -> None: collection = db_client["connector_sessions"] await asyncio.to_thread(collection.delete_many, {}) - logging.info("Cleared all OAuth client cache data.") + logger.info("Cleared all OAuth client cache data.") class MCPOAuthManager: @@ -851,7 +971,7 @@ class MCPOAuthManager: return True except Exception as e: - logging.error(f"Error handling OAuth callback: {e}") + logger.error("Error handling OAuth callback: %s", e) return False def get_oauth_status(self, task_id: str) -> Dict[str, Any]: diff --git a/application/agents/tools/notes.py b/application/agents/tools/notes.py index 3d7ced85..8afdd071 100644 --- a/application/agents/tools/notes.py +++ b/application/agents/tools/notes.py @@ -38,6 +38,8 @@ class NotesTool(Tool): db = MongoDB.get_client()[settings.MONGO_DB_NAME] self.collection = db["notes"] + self._last_artifact_id: Optional[str] = None + # ----------------------------- # Action implementations # ----------------------------- @@ -54,6 +56,8 @@ class NotesTool(Tool): if not self.user_id: return "Error: NotesTool requires a valid user_id." + self._last_artifact_id = None + if action_name == "view": return self._get_note() @@ -125,6 +129,9 @@ class NotesTool(Tool): """Return configuration requirements (none for now).""" return {} + def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]: + return self._last_artifact_id + # ----------------------------- # Internal helpers (single-note) # ----------------------------- @@ -132,17 +139,22 @@ class NotesTool(Tool): doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id}) if not doc or not doc.get("note"): return "No note found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) return str(doc["note"]) def _overwrite_note(self, content: str) -> str: content = (content or "").strip() if not content: return "Note content required." 
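+        # The switch to find_one_and_update below lets the note's _id be captured
+        # for artifact tracking; return_document=True maps to pymongo's
+        # ReturnDocument.AFTER, so the returned document reflects the upsert.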
- self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": content, "updated_at": datetime.utcnow()}}, - upsert=True, # ✅ create if missing + upsert=True, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Note saved." def _str_replace(self, old_str: str, new_str: str) -> str: @@ -163,10 +175,13 @@ class NotesTool(Tool): import re updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE) - self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Note updated." def _insert(self, line_number: int, text: str) -> str: @@ -188,12 +203,21 @@ class NotesTool(Tool): lines.insert(index, text) updated_note = "\n".join(lines) - self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Text inserted." def _delete_note(self) -> str: - res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id}) - return "Note deleted." if res.deleted_count else "No note found to delete." + doc = self.collection.find_one_and_delete( + {"user_id": self.user_id, "tool_id": self.tool_id} + ) + if not doc: + return "No note found to delete." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return "Note deleted." diff --git a/application/agents/tools/ntfy.py b/application/agents/tools/ntfy.py index e968dfc4..9a2d44ca 100644 --- a/application/agents/tools/ntfy.py +++ b/application/agents/tools/ntfy.py @@ -116,12 +116,13 @@ class NtfyTool(Tool): ] def get_config_requirements(self): - """ - Specify the configuration requirements. - - Returns: - dict: Dictionary describing required config parameters. - """ return { - "token": {"type": "string", "description": "Access token for authentication"}, + "token": { + "type": "string", + "label": "Access Token", + "description": "Ntfy access token for authentication", + "required": True, + "secret": True, + "order": 1, + }, } \ No newline at end of file diff --git a/application/agents/tools/postgres.py b/application/agents/tools/postgres.py index 2877ebad..d9d5a2b4 100644 --- a/application/agents/tools/postgres.py +++ b/application/agents/tools/postgres.py @@ -1,6 +1,12 @@ +import logging + import psycopg2 + from application.agents.tools.base import Tool +logger = logging.getLogger(__name__) + + class PostgresTool(Tool): """ PostgreSQL Database Tool @@ -17,17 +23,15 @@ class PostgresTool(Tool): "postgres_execute_sql": self._execute_sql, "postgres_get_schema": self._get_schema, } - - if action_name in actions: - return actions[action_name](**kwargs) - else: + if action_name not in actions: raise ValueError(f"Unknown action: {action_name}") + return actions[action_name](**kwargs) def _execute_sql(self, sql_query): """ Executes an SQL query against the PostgreSQL database using a connection string. 
""" - conn = None # Initialize conn to None for error handling + conn = None try: conn = psycopg2.connect(self.connection_string) cur = conn.cursor() @@ -35,7 +39,9 @@ class PostgresTool(Tool): conn.commit() if sql_query.strip().lower().startswith("select"): - column_names = [desc[0] for desc in cur.description] if cur.description else [] + column_names = ( + [desc[0] for desc in cur.description] if cur.description else [] + ) results = [] rows = cur.fetchall() for row in rows: @@ -43,7 +49,9 @@ class PostgresTool(Tool): response_data = {"data": results, "column_names": column_names} else: row_count = cur.rowcount - response_data = {"message": f"Query executed successfully, {row_count} rows affected."} + response_data = { + "message": f"Query executed successfully, {row_count} rows affected." + } cur.close() return { @@ -54,26 +62,27 @@ class PostgresTool(Tool): except psycopg2.Error as e: error_message = f"Database error: {e}" - print(f"Database error: {e}") + logger.error("PostgreSQL execute_sql error: %s", e) return { "status_code": 500, "message": "Failed to execute SQL query.", "error": error_message, } finally: - if conn: # Ensure connection is closed even if errors occur + if conn: conn.close() def _get_schema(self, db_name): """ Retrieves the schema of the PostgreSQL database using a connection string. """ - conn = None # Initialize conn to None for error handling + conn = None try: conn = psycopg2.connect(self.connection_string) cur = conn.cursor() - cur.execute(""" + cur.execute( + """ SELECT table_name, column_name, @@ -87,19 +96,22 @@ class PostgresTool(Tool): ORDER BY table_name, ordinal_position; - """) + """ + ) schema_data = {} for row in cur.fetchall(): table_name, column_name, data_type, column_default, is_nullable = row if table_name not in schema_data: schema_data[table_name] = [] - schema_data[table_name].append({ - "column_name": column_name, - "data_type": data_type, - "column_default": column_default, - "is_nullable": is_nullable - }) + schema_data[table_name].append( + { + "column_name": column_name, + "data_type": data_type, + "column_default": column_default, + "is_nullable": is_nullable, + } + ) cur.close() return { @@ -110,14 +122,14 @@ class PostgresTool(Tool): except psycopg2.Error as e: error_message = f"Database error: {e}" - print(f"Database error: {e}") + logger.error("PostgreSQL get_schema error: %s", e) return { "status_code": 500, "message": "Failed to retrieve database schema.", "error": error_message, } finally: - if conn: # Ensure connection is closed even if errors occur + if conn: conn.close() def get_actions_metadata(self): @@ -158,6 +170,10 @@ class PostgresTool(Tool): return { "token": { "type": "string", - "description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')", + "label": "Connection String", + "description": "PostgreSQL database connection string", + "required": True, + "secret": True, + "order": 1, }, - } \ No newline at end of file + } diff --git a/application/agents/tools/read_webpage.py b/application/agents/tools/read_webpage.py index e87c79e3..f0321a5a 100644 --- a/application/agents/tools/read_webpage.py +++ b/application/agents/tools/read_webpage.py @@ -1,7 +1,7 @@ import requests from markdownify import markdownify from application.agents.tools.base import Tool -from urllib.parse import urlparse +from application.core.url_validation import validate_url, SSRFError class ReadWebpageTool(Tool): """ @@ -31,11 +31,12 @@ class ReadWebpageTool(Tool): if not url: return "Error: URL 
parameter is missing." - # Ensure the URL has a scheme (if not, default to http) - parsed_url = urlparse(url) - if not parsed_url.scheme: - url = "http://" + url - + # Validate URL to prevent SSRF attacks + try: + url = validate_url(url) + except SSRFError as e: + return f"Error: URL validation failed - {e}" + try: response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'}) response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx) diff --git a/application/agents/tools/spec_parser.py b/application/agents/tools/spec_parser.py new file mode 100644 index 00000000..336f00f8 --- /dev/null +++ b/application/agents/tools/spec_parser.py @@ -0,0 +1,342 @@ +""" +API Specification Parser + +Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them +to API Tool action definitions for use in DocsGPT. +""" + +import json +import logging +import re +from typing import Any, Dict, List, Optional, Tuple + +import yaml + +logger = logging.getLogger(__name__) + +SUPPORTED_METHODS = frozenset( + {"get", "post", "put", "delete", "patch", "head", "options"} +) + + +def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + """ + Parse an API specification and convert operations to action definitions. + + Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML. + + Args: + spec_content: Raw specification content as string + + Returns: + Tuple of (metadata dict, list of action dicts) + + Raises: + ValueError: If the spec is invalid or uses an unsupported format + """ + spec = _load_spec(spec_content) + _validate_spec(spec) + + is_swagger = "swagger" in spec + metadata = _extract_metadata(spec, is_swagger) + actions = _extract_actions(spec, is_swagger) + + return metadata, actions + + +def _load_spec(content: str) -> Dict[str, Any]: + """Parse spec content from JSON or YAML string.""" + content = content.strip() + if not content: + raise ValueError("Empty specification content") + try: + if content.startswith("{"): + return json.loads(content) + return yaml.safe_load(content) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format: {e.msg}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML format: {e}") + + +def _validate_spec(spec: Dict[str, Any]) -> None: + """Validate spec version and required fields.""" + if not isinstance(spec, dict): + raise ValueError("Specification must be a valid object") + openapi_version = spec.get("openapi", "") + swagger_version = spec.get("swagger", "") + + if not (openapi_version.startswith("3.") or swagger_version == "2.0"): + raise ValueError( + "Unsupported specification version. 
Expected OpenAPI 3.x or Swagger 2.0" + ) + if "paths" not in spec or not spec["paths"]: + raise ValueError("No API paths defined in the specification") + + +def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]: + """Extract API metadata from specification.""" + info = spec.get("info", {}) + base_url = _get_base_url(spec, is_swagger) + + return { + "title": info.get("title", "Untitled API"), + "description": (info.get("description", "") or "")[:500], + "version": info.get("version", ""), + "base_url": base_url, + } + + +def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str: + """Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0).""" + if is_swagger: + schemes = spec.get("schemes", ["https"]) + host = spec.get("host", "") + base_path = spec.get("basePath", "") + if host: + scheme = schemes[0] if schemes else "https" + return f"{scheme}://{host}{base_path}".rstrip("/") + return "" + servers = spec.get("servers", []) + if servers and isinstance(servers, list) and servers[0].get("url"): + return servers[0]["url"].rstrip("/") + return "" + + +def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]: + """Extract all API operations as action definitions.""" + actions = [] + paths = spec.get("paths", {}) + base_url = _get_base_url(spec, is_swagger) + + components = spec.get("components", {}) + definitions = spec.get("definitions", {}) + + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + path_params = path_item.get("parameters", []) + + for method in SUPPORTED_METHODS: + operation = path_item.get(method) + if not isinstance(operation, dict): + continue + try: + action = _build_action( + path=path, + method=method, + operation=operation, + path_params=path_params, + base_url=base_url, + components=components, + definitions=definitions, + is_swagger=is_swagger, + ) + actions.append(action) + except Exception as e: + logger.warning( + f"Failed to parse operation {method.upper()} {path}: {e}" + ) + continue + return actions + + +def _build_action( + path: str, + method: str, + operation: Dict[str, Any], + path_params: List[Dict], + base_url: str, + components: Dict[str, Any], + definitions: Dict[str, Any], + is_swagger: bool, +) -> Dict[str, Any]: + """Build a single action from an API operation.""" + action_name = _generate_action_name(operation, method, path) + full_url = f"{base_url}{path}" if base_url else path + + all_params = path_params + operation.get("parameters", []) + query_params, headers = _categorize_parameters(all_params, components, definitions) + + body, body_content_type = _extract_request_body( + operation, components, definitions, is_swagger + ) + + description = operation.get("summary", "") or operation.get("description", "") + + return { + "name": action_name, + "url": full_url, + "method": method.upper(), + "description": (description or "")[:500], + "query_params": {"type": "object", "properties": query_params}, + "headers": {"type": "object", "properties": headers}, + "body": {"type": "object", "properties": body}, + "body_content_type": body_content_type, + "active": True, + } + + +def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str: + """Generate a valid action name from operationId or method+path.""" + if operation.get("operationId"): + name = operation["operationId"] + else: + path_slug = re.sub(r"[{}]", "", path) + path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug) + path_slug = re.sub(r"_+", "_", 
path_slug).strip("_") + name = f"{method}_{path_slug}" + name = re.sub(r"[^a-zA-Z0-9_-]", "_", name) + return name[:64] + + +def _categorize_parameters( + parameters: List[Dict], + components: Dict[str, Any], + definitions: Dict[str, Any], +) -> Tuple[Dict, Dict]: + """Categorize parameters into query params and headers.""" + query_params = {} + headers = {} + + for param in parameters: + resolved = _resolve_ref(param, components, definitions) + if not resolved or "name" not in resolved: + continue + location = resolved.get("in", "query") + prop = _param_to_property(resolved) + + if location in ("query", "path"): + query_params[resolved["name"]] = prop + elif location == "header": + headers[resolved["name"]] = prop + return query_params, headers + + +def _param_to_property(param: Dict) -> Dict[str, Any]: + """Convert an API parameter to an action property definition.""" + schema = param.get("schema", {}) + param_type = schema.get("type", param.get("type", "string")) + + mapped_type = "integer" if param_type in ("integer", "number") else "string" + + return { + "type": mapped_type, + "description": (param.get("description", "") or "")[:200], + "value": "", + "filled_by_llm": param.get("required", False), + "required": param.get("required", False), + } + + +def _extract_request_body( + operation: Dict[str, Any], + components: Dict[str, Any], + definitions: Dict[str, Any], + is_swagger: bool, +) -> Tuple[Dict, str]: + """Extract request body schema and content type.""" + content_types = [ + "application/json", + "application/x-www-form-urlencoded", + "multipart/form-data", + "text/plain", + "application/xml", + ] + + if is_swagger: + consumes = operation.get("consumes", []) + body_param = next( + (p for p in operation.get("parameters", []) if p.get("in") == "body"), None + ) + if not body_param: + return {}, "application/json" + selected_type = consumes[0] if consumes else "application/json" + schema = body_param.get("schema", {}) + else: + request_body = operation.get("requestBody", {}) + if not request_body: + return {}, "application/json" + request_body = _resolve_ref(request_body, components, definitions) + content = request_body.get("content", {}) + + selected_type = "application/json" + schema = {} + + for ct in content_types: + if ct in content: + selected_type = ct + schema = content[ct].get("schema", {}) + break + if not schema and content: + first_type = next(iter(content)) + selected_type = first_type + schema = content[first_type].get("schema", {}) + properties = _schema_to_properties(schema, components, definitions) + return properties, selected_type + + +def _schema_to_properties( + schema: Dict, + components: Dict[str, Any], + definitions: Dict[str, Any], + depth: int = 0, +) -> Dict[str, Any]: + """Convert schema to action body properties (limited depth to prevent recursion).""" + if depth > 3: + return {} + schema = _resolve_ref(schema, components, definitions) + if not schema or not isinstance(schema, dict): + return {} + properties = {} + schema_type = schema.get("type", "object") + + if schema_type == "object": + required_fields = set(schema.get("required", [])) + for prop_name, prop_schema in schema.get("properties", {}).items(): + resolved = _resolve_ref(prop_schema, components, definitions) + if not isinstance(resolved, dict): + continue + prop_type = resolved.get("type", "string") + mapped_type = "integer" if prop_type in ("integer", "number") else "string" + + properties[prop_name] = { + "type": mapped_type, + "description": (resolved.get("description", "") or 
"")[:200], + "value": "", + "filled_by_llm": prop_name in required_fields, + "required": prop_name in required_fields, + } + return properties + + +def _resolve_ref( + obj: Any, + components: Dict[str, Any], + definitions: Dict[str, Any], +) -> Optional[Dict]: + """Resolve $ref references in the specification.""" + if not isinstance(obj, dict): + return obj if isinstance(obj, dict) else None + if "$ref" not in obj: + return obj + ref_path = obj["$ref"] + + if ref_path.startswith("#/components/"): + parts = ref_path.replace("#/components/", "").split("/") + return _traverse_path(components, parts) + elif ref_path.startswith("#/definitions/"): + parts = ref_path.replace("#/definitions/", "").split("/") + return _traverse_path(definitions, parts) + logger.debug(f"Unsupported ref path: {ref_path}") + return None + + +def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]: + """Traverse a nested dictionary using path parts.""" + try: + for part in parts: + obj = obj[part] + return obj if isinstance(obj, dict) else None + except (KeyError, TypeError): + return None diff --git a/application/agents/tools/telegram.py b/application/agents/tools/telegram.py index 06350ae9..d4381370 100644 --- a/application/agents/tools/telegram.py +++ b/application/agents/tools/telegram.py @@ -1,6 +1,11 @@ +import logging + import requests + from application.agents.tools.base import Tool +logger = logging.getLogger(__name__) + class TelegramTool(Tool): """ @@ -18,21 +23,19 @@ class TelegramTool(Tool): "telegram_send_message": self._send_message, "telegram_send_image": self._send_image, } - - if action_name in actions: - return actions[action_name](**kwargs) - else: + if action_name not in actions: raise ValueError(f"Unknown action: {action_name}") + return actions[action_name](**kwargs) def _send_message(self, text, chat_id): - print(f"Sending message: {text}") + logger.debug("Sending Telegram message to chat_id=%s", chat_id) url = f"https://api.telegram.org/bot{self.token}/sendMessage" payload = {"chat_id": chat_id, "text": text} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Message sent"} def _send_image(self, image_url, chat_id): - print(f"Sending image: {image_url}") + logger.debug("Sending Telegram image to chat_id=%s", chat_id) url = f"https://api.telegram.org/bot{self.token}/sendPhoto" payload = {"chat_id": chat_id, "photo": image_url} response = requests.post(url, data=payload) @@ -82,5 +85,12 @@ class TelegramTool(Tool): def get_config_requirements(self): return { - "token": {"type": "string", "description": "Bot token for authentication"}, + "token": { + "type": "string", + "label": "Bot Token", + "description": "Telegram bot token for authentication", + "required": True, + "secret": True, + "order": 1, + }, } diff --git a/application/agents/tools/todo_list.py b/application/agents/tools/todo_list.py index 87a3e969..b515ad56 100644 --- a/application/agents/tools/todo_list.py +++ b/application/agents/tools/todo_list.py @@ -38,6 +38,8 @@ class TodoListTool(Tool): db = MongoDB.get_client()[settings.MONGO_DB_NAME] self.collection = db["todos"] + self._last_artifact_id: Optional[str] = None + # ----------------------------- # Action implementations # ----------------------------- @@ -54,6 +56,8 @@ class TodoListTool(Tool): if not self.user_id: return "Error: TodoListTool requires a valid user_id." 
+ self._last_artifact_id = None + if action_name == "list": return self._list() @@ -165,6 +169,9 @@ class TodoListTool(Tool): """Return configuration requirements.""" return {} + def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]: + return self._last_artifact_id + # ----------------------------- # Internal helpers # ----------------------------- @@ -190,11 +197,8 @@ class TodoListTool(Tool): Returns a simple integer (1, 2, 3, ...) scoped to this user/tool. With 5-10 todos max, scanning is negligible. """ - # Find all todos for this user/tool and get their IDs - todos = list(self.collection.find( - {"user_id": self.user_id, "tool_id": self.tool_id}, - {"todo_id": 1} - )) + query = {"user_id": self.user_id, "tool_id": self.tool_id} + todos = list(self.collection.find(query, {"todo_id": 1})) # Find the maximum todo_id max_id = 0 @@ -207,8 +211,8 @@ class TodoListTool(Tool): def _list(self) -> str: """List all todos for the user.""" - cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id}) - todos = list(cursor) + query = {"user_id": self.user_id, "tool_id": self.tool_id} + todos = list(self.collection.find(query)) if not todos: return "No todos found." @@ -242,7 +246,10 @@ class TodoListTool(Tool): "created_at": now, "updated_at": now, } - self.collection.insert_one(doc) + insert_result = self.collection.insert_one(doc) + inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id") + if inserted_id is not None: + self._last_artifact_id = str(inserted_id) return f"Todo created with ID {todo_id}: {title}" def _get(self, todo_id: Optional[Any]) -> str: @@ -251,15 +258,15 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." - doc = self.collection.find_one({ - "user_id": self.user_id, - "tool_id": self.tool_id, - "todo_id": parsed_todo_id - }) + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one(query) if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + title = doc.get("title", "Untitled") status = doc.get("status", "open") @@ -277,14 +284,17 @@ class TodoListTool(Tool): if not title: return "Error: Title is required." - result = self.collection.update_one( - {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}, - {"$set": {"title": title, "updated_at": datetime.now()}} + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_update( + query, + {"$set": {"title": title, "updated_at": datetime.now()}}, ) - - if result.matched_count == 0: + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} updated to: {title}" def _complete(self, todo_id: Optional[Any]) -> str: @@ -293,14 +303,17 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." 
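+        # find_one_and_update below defaults to ReturnDocument.BEFORE; that is
+        # sufficient here because only the matched document's _id is needed for
+        # artifact tracking, and a None result still signals "not found".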
- result = self.collection.update_one( - {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}, - {"$set": {"status": "completed", "updated_at": datetime.now()}} + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_update( + query, + {"$set": {"status": "completed", "updated_at": datetime.now()}}, ) - - if result.matched_count == 0: + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} marked as completed." def _delete(self, todo_id: Optional[Any]) -> str: @@ -309,13 +322,12 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." - result = self.collection.delete_one({ - "user_id": self.user_id, - "tool_id": self.tool_id, - "todo_id": parsed_todo_id - }) - - if result.deleted_count == 0: + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_delete(query) + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} deleted." diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index 855f1b53..08ef30a4 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -36,7 +36,7 @@ class ToolManager: def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") - if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id: + if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id: tool_config = self.config.get(tool_name, {}) tool = self.load_tool(tool_name, tool_config, user_id) return tool.execute_action(action_name, **kwargs) diff --git a/application/agents/workflow_agent.py b/application/agents/workflow_agent.py new file mode 100644 index 00000000..5c005df5 --- /dev/null +++ b/application/agents/workflow_agent.py @@ -0,0 +1,231 @@ +import logging +from datetime import datetime, timezone +from typing import Any, Dict, Generator, Optional + +from application.agents.base import BaseAgent +from application.agents.workflows.schemas import ( + ExecutionStatus, + Workflow, + WorkflowEdge, + WorkflowGraph, + WorkflowNode, + WorkflowRun, +) +from application.agents.workflows.workflow_engine import WorkflowEngine +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.logging import log_activity, LogContext + +logger = logging.getLogger(__name__) + + +class WorkflowAgent(BaseAgent): + """A specialized agent that executes predefined workflows.""" + + def __init__( + self, + *args, + workflow_id: Optional[str] = None, + workflow: Optional[Dict[str, Any]] = None, + workflow_owner: Optional[str] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.workflow_id = workflow_id + self.workflow_owner = workflow_owner + self._workflow_data = workflow + self._engine: Optional[WorkflowEngine] = None + + @log_activity() + def gen( + self, query: str, log_context: LogContext = None + ) -> Generator[Dict[str, str], None, None]: + yield from self._gen_inner(query, log_context) + + def _gen_inner( + self, query: str, log_context: LogContext + ) -> Generator[Dict[str, str], None, None]: 
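+        # Resolve the graph (embedded config first, otherwise a MongoDB lookup by
+        # workflow_id), stream engine events straight to the caller, then persist
+        # a WorkflowRun record once execution finishes.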
+ graph = self._load_workflow_graph() + if not graph: + yield {"type": "error", "error": "Failed to load workflow configuration."} + return + self._engine = WorkflowEngine(graph, self) + yield from self._engine.execute({}, query) + self._save_workflow_run(query) + + def _load_workflow_graph(self) -> Optional[WorkflowGraph]: + if self._workflow_data: + return self._parse_embedded_workflow() + if self.workflow_id: + return self._load_from_database() + return None + + def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]: + try: + nodes_data = self._workflow_data.get("nodes", []) + edges_data = self._workflow_data.get("edges", []) + + workflow = Workflow( + name=self._workflow_data.get("name", "Embedded Workflow"), + description=self._workflow_data.get("description"), + ) + + nodes = [] + for n in nodes_data: + node_config = n.get("data", {}) + nodes.append( + WorkflowNode( + id=n["id"], + workflow_id=self.workflow_id or "embedded", + type=n["type"], + title=n.get("title", "Node"), + description=n.get("description"), + position=n.get("position", {"x": 0, "y": 0}), + config=node_config, + ) + ) + edges = [] + for e in edges_data: + edges.append( + WorkflowEdge( + id=e["id"], + workflow_id=self.workflow_id or "embedded", + source=e.get("source") or e.get("source_id"), + target=e.get("target") or e.get("target_id"), + sourceHandle=e.get("sourceHandle") or e.get("source_handle"), + targetHandle=e.get("targetHandle") or e.get("target_handle"), + ) + ) + return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges) + except Exception as e: + logger.error(f"Invalid embedded workflow: {e}") + return None + + def _load_from_database(self) -> Optional[WorkflowGraph]: + try: + from bson.objectid import ObjectId + + if not self.workflow_id or not ObjectId.is_valid(self.workflow_id): + logger.error(f"Invalid workflow ID: {self.workflow_id}") + return None + owner_id = self.workflow_owner + if not owner_id and isinstance(self.decoded_token, dict): + owner_id = self.decoded_token.get("sub") + if not owner_id: + logger.error( + f"Workflow owner not available for workflow load: {self.workflow_id}" + ) + return None + + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + + workflows_coll = db["workflows"] + workflow_nodes_coll = db["workflow_nodes"] + workflow_edges_coll = db["workflow_edges"] + + workflow_doc = workflows_coll.find_one( + {"_id": ObjectId(self.workflow_id), "user": owner_id} + ) + if not workflow_doc: + logger.error( + f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}" + ) + return None + workflow = Workflow(**workflow_doc) + graph_version = workflow_doc.get("current_graph_version", 1) + try: + graph_version = int(graph_version) + if graph_version <= 0: + graph_version = 1 + except (ValueError, TypeError): + graph_version = 1 + + nodes_docs = list( + workflow_nodes_coll.find( + {"workflow_id": self.workflow_id, "graph_version": graph_version} + ) + ) + if not nodes_docs and graph_version == 1: + nodes_docs = list( + workflow_nodes_coll.find( + { + "workflow_id": self.workflow_id, + "graph_version": {"$exists": False}, + } + ) + ) + nodes = [WorkflowNode(**doc) for doc in nodes_docs] + + edges_docs = list( + workflow_edges_coll.find( + {"workflow_id": self.workflow_id, "graph_version": graph_version} + ) + ) + if not edges_docs and graph_version == 1: + edges_docs = list( + workflow_edges_coll.find( + { + "workflow_id": self.workflow_id, + "graph_version": {"$exists": False}, + } + ) + ) + edges = [WorkflowEdge(**doc) for doc in edges_docs] 
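+            # Nodes and edges written before graph versioning carry no
+            # graph_version field; the fallback queries above pick them up while
+            # the workflow is still on version 1.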
+ + return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges) + except Exception as e: + logger.error(f"Failed to load workflow from database: {e}") + return None + + def _save_workflow_run(self, query: str) -> None: + if not self._engine: + return + try: + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + workflow_runs_coll = db["workflow_runs"] + + run = WorkflowRun( + workflow_id=self.workflow_id or "unknown", + status=self._determine_run_status(), + inputs={"query": query}, + outputs=self._serialize_state(self._engine.state), + steps=self._engine.get_execution_summary(), + created_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + workflow_runs_coll.insert_one(run.to_mongo_doc()) + except Exception as e: + logger.error(f"Failed to save workflow run: {e}") + + def _determine_run_status(self) -> ExecutionStatus: + if not self._engine or not self._engine.execution_log: + return ExecutionStatus.COMPLETED + for log in self._engine.execution_log: + if log.get("status") == ExecutionStatus.FAILED.value: + return ExecutionStatus.FAILED + return ExecutionStatus.COMPLETED + + def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]: + serialized: Dict[str, Any] = {} + for key, value in state.items(): + serialized[key] = self._serialize_state_value(value) + return serialized + + def _serialize_state_value(self, value: Any) -> Any: + if isinstance(value, dict): + return { + str(dict_key): self._serialize_state_value(dict_value) + for dict_key, dict_value in value.items() + } + if isinstance(value, list): + return [self._serialize_state_value(item) for item in value] + if isinstance(value, tuple): + return [self._serialize_state_value(item) for item in value] + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (str, int, float, bool, type(None))): + return value + return str(value) diff --git a/application/agents/workflows/cel_evaluator.py b/application/agents/workflows/cel_evaluator.py new file mode 100644 index 00000000..5f8bab37 --- /dev/null +++ b/application/agents/workflows/cel_evaluator.py @@ -0,0 +1,64 @@ +from typing import Any, Dict + +import celpy +import celpy.celtypes + + +class CelEvaluationError(Exception): + pass + + +def _convert_value(value: Any) -> Any: + if isinstance(value, bool): + return celpy.celtypes.BoolType(value) + if isinstance(value, int): + return celpy.celtypes.IntType(value) + if isinstance(value, float): + return celpy.celtypes.DoubleType(value) + if isinstance(value, str): + return celpy.celtypes.StringType(value) + if isinstance(value, list): + return celpy.celtypes.ListType([_convert_value(item) for item in value]) + if isinstance(value, dict): + return celpy.celtypes.MapType( + {celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()} + ) + if value is None: + return celpy.celtypes.BoolType(False) + return celpy.celtypes.StringType(str(value)) + + +def build_activation(state: Dict[str, Any]) -> Dict[str, Any]: + return {k: _convert_value(v) for k, v in state.items()} + + +def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any: + if not expression or not expression.strip(): + raise CelEvaluationError("Empty expression") + try: + env = celpy.Environment() + ast = env.compile(expression) + program = env.program(ast) + activation = build_activation(state) + result = program.evaluate(activation) + except celpy.CELEvalError as exc: + raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc + except Exception as exc: + raise 
CelEvaluationError(f"CEL error: {exc}") from exc + return cel_to_python(result) + + +def cel_to_python(value: Any) -> Any: + if isinstance(value, celpy.celtypes.BoolType): + return bool(value) + if isinstance(value, celpy.celtypes.IntType): + return int(value) + if isinstance(value, celpy.celtypes.DoubleType): + return float(value) + if isinstance(value, celpy.celtypes.StringType): + return str(value) + if isinstance(value, celpy.celtypes.ListType): + return [cel_to_python(item) for item in value] + if isinstance(value, celpy.celtypes.MapType): + return {str(k): cel_to_python(v) for k, v in value.items()} + return value diff --git a/application/agents/workflows/node_agent.py b/application/agents/workflows/node_agent.py new file mode 100644 index 00000000..d3c9d607 --- /dev/null +++ b/application/agents/workflows/node_agent.py @@ -0,0 +1,109 @@ +"""Workflow Node Agents - defines specialized agents for workflow nodes.""" + +from typing import Any, Dict, List, Optional, Type + +from application.agents.base import BaseAgent +from application.agents.classic_agent import ClassicAgent +from application.agents.react_agent import ReActAgent +from application.agents.workflows.schemas import AgentType + + +class ToolFilterMixin: + """Mixin that filters fetched tools to only those specified in tool_ids.""" + + _allowed_tool_ids: List[str] + + def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]: + all_tools = super()._get_user_tools(user) + if not self._allowed_tool_ids: + return {} + filtered_tools = { + tool_id: tool + for tool_id, tool in all_tools.items() + if str(tool.get("_id", "")) in self._allowed_tool_ids + } + return filtered_tools + + def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]: + all_tools = super()._get_tools(api_key) + if not self._allowed_tool_ids: + return {} + filtered_tools = { + tool_id: tool + for tool_id, tool in all_tools.items() + if str(tool.get("_id", "")) in self._allowed_tool_ids + } + return filtered_tools + + +class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent): + + def __init__( + self, + endpoint: str, + llm_name: str, + model_id: str, + api_key: str, + tool_ids: Optional[List[str]] = None, + **kwargs, + ): + super().__init__( + endpoint=endpoint, + llm_name=llm_name, + model_id=model_id, + api_key=api_key, + **kwargs, + ) + self._allowed_tool_ids = tool_ids or [] + + +class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent): + + def __init__( + self, + endpoint: str, + llm_name: str, + model_id: str, + api_key: str, + tool_ids: Optional[List[str]] = None, + **kwargs, + ): + super().__init__( + endpoint=endpoint, + llm_name=llm_name, + model_id=model_id, + api_key=api_key, + **kwargs, + ) + self._allowed_tool_ids = tool_ids or [] + + +class WorkflowNodeAgentFactory: + + _agents: Dict[AgentType, Type[BaseAgent]] = { + AgentType.CLASSIC: WorkflowNodeClassicAgent, + AgentType.REACT: WorkflowNodeReActAgent, + } + + @classmethod + def create( + cls, + agent_type: AgentType, + endpoint: str, + llm_name: str, + model_id: str, + api_key: str, + tool_ids: Optional[List[str]] = None, + **kwargs, + ) -> BaseAgent: + agent_class = cls._agents.get(agent_type) + if not agent_class: + raise ValueError(f"Unsupported agent type: {agent_type}") + return agent_class( + endpoint=endpoint, + llm_name=llm_name, + model_id=model_id, + api_key=api_key, + tool_ids=tool_ids, + **kwargs, + ) diff --git a/application/agents/workflows/schemas.py b/application/agents/workflows/schemas.py new file mode 100644 index 00000000..5355b88e --- 
/dev/null +++ b/application/agents/workflows/schemas.py @@ -0,0 +1,235 @@ +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from bson import ObjectId +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class NodeType(str, Enum): + START = "start" + END = "end" + AGENT = "agent" + NOTE = "note" + STATE = "state" + CONDITION = "condition" + + +class AgentType(str, Enum): + CLASSIC = "classic" + REACT = "react" + + +class ExecutionStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class Position(BaseModel): + model_config = ConfigDict(extra="forbid") + x: float = 0.0 + y: float = 0.0 + + +class AgentNodeConfig(BaseModel): + model_config = ConfigDict(extra="allow") + agent_type: AgentType = AgentType.CLASSIC + llm_name: Optional[str] = None + system_prompt: str = "You are a helpful assistant." + prompt_template: str = "" + output_variable: Optional[str] = None + stream_to_user: bool = True + tools: List[str] = Field(default_factory=list) + sources: List[str] = Field(default_factory=list) + chunks: str = "2" + retriever: str = "" + model_id: Optional[str] = None + json_schema: Optional[Dict[str, Any]] = None + + +class ConditionCase(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + name: Optional[str] = None + expression: str = "" + source_handle: str = Field(..., alias="sourceHandle") + + +class ConditionNodeConfig(BaseModel): + model_config = ConfigDict(extra="allow") + mode: Literal["simple", "advanced"] = "simple" + cases: List[ConditionCase] = Field(default_factory=list) + + +class StateOperation(BaseModel): + model_config = ConfigDict(extra="forbid") + expression: str = "" + target_variable: str = "" + + +class WorkflowEdgeCreate(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str + workflow_id: str + source_id: str = Field(..., alias="source") + target_id: str = Field(..., alias="target") + source_handle: Optional[str] = Field(None, alias="sourceHandle") + target_handle: Optional[str] = Field(None, alias="targetHandle") + + +class WorkflowEdge(WorkflowEdgeCreate): + mongo_id: Optional[str] = Field(None, alias="_id") + + @field_validator("mongo_id", mode="before") + @classmethod + def convert_objectid(cls, v: Any) -> Optional[str]: + if isinstance(v, ObjectId): + return str(v) + return v + + def to_mongo_doc(self) -> Dict[str, Any]: + return { + "id": self.id, + "workflow_id": self.workflow_id, + "source_id": self.source_id, + "target_id": self.target_id, + "source_handle": self.source_handle, + "target_handle": self.target_handle, + } + + +class WorkflowNodeCreate(BaseModel): + model_config = ConfigDict(extra="allow") + id: str + workflow_id: str + type: NodeType + title: str = "Node" + description: Optional[str] = None + position: Position = Field(default_factory=Position) + config: Dict[str, Any] = Field(default_factory=dict) + + @field_validator("position", mode="before") + @classmethod + def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position: + if isinstance(v, dict): + return Position(**v) + return v + + +class WorkflowNode(WorkflowNodeCreate): + mongo_id: Optional[str] = Field(None, alias="_id") + + @field_validator("mongo_id", mode="before") + @classmethod + def convert_objectid(cls, v: Any) -> Optional[str]: + if isinstance(v, ObjectId): + return str(v) + return v + + def to_mongo_doc(self) -> Dict[str, Any]: + return { + "id": self.id, + 
"workflow_id": self.workflow_id, + "type": self.type.value, + "title": self.title, + "description": self.description, + "position": self.position.model_dump(), + "config": self.config, + } + + +class WorkflowCreate(BaseModel): + model_config = ConfigDict(extra="allow") + name: str = "New Workflow" + description: Optional[str] = None + user: Optional[str] = None + + +class Workflow(WorkflowCreate): + id: Optional[str] = Field(None, alias="_id") + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + @field_validator("id", mode="before") + @classmethod + def convert_objectid(cls, v: Any) -> Optional[str]: + if isinstance(v, ObjectId): + return str(v) + return v + + def to_mongo_doc(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "user": self.user, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + + +class WorkflowGraph(BaseModel): + workflow: Workflow + nodes: List[WorkflowNode] = Field(default_factory=list) + edges: List[WorkflowEdge] = Field(default_factory=list) + + def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]: + for node in self.nodes: + if node.id == node_id: + return node + return None + + def get_start_node(self) -> Optional[WorkflowNode]: + for node in self.nodes: + if node.type == NodeType.START: + return node + return None + + def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]: + return [edge for edge in self.edges if edge.source_id == node_id] + + +class NodeExecutionLog(BaseModel): + model_config = ConfigDict(extra="forbid") + node_id: str + node_type: str + status: ExecutionStatus + started_at: datetime + completed_at: Optional[datetime] = None + error: Optional[str] = None + state_snapshot: Dict[str, Any] = Field(default_factory=dict) + + +class WorkflowRunCreate(BaseModel): + workflow_id: str + inputs: Dict[str, str] = Field(default_factory=dict) + + +class WorkflowRun(BaseModel): + model_config = ConfigDict(extra="allow") + id: Optional[str] = Field(None, alias="_id") + workflow_id: str + status: ExecutionStatus = ExecutionStatus.PENDING + inputs: Dict[str, str] = Field(default_factory=dict) + outputs: Dict[str, Any] = Field(default_factory=dict) + steps: List[NodeExecutionLog] = Field(default_factory=list) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + completed_at: Optional[datetime] = None + + @field_validator("id", mode="before") + @classmethod + def convert_objectid(cls, v: Any) -> Optional[str]: + if isinstance(v, ObjectId): + return str(v) + return v + + def to_mongo_doc(self) -> Dict[str, Any]: + return { + "workflow_id": self.workflow_id, + "status": self.status.value, + "inputs": self.inputs, + "outputs": self.outputs, + "steps": [step.model_dump() for step in self.steps], + "created_at": self.created_at, + "completed_at": self.completed_at, + } diff --git a/application/agents/workflows/workflow_engine.py b/application/agents/workflows/workflow_engine.py new file mode 100644 index 00000000..5444458a --- /dev/null +++ b/application/agents/workflows/workflow_engine.py @@ -0,0 +1,455 @@ +import json +import logging +from datetime import datetime, timezone +from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING + +from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel +from application.agents.workflows.node_agent import WorkflowNodeAgentFactory +from 
application.agents.workflows.schemas import ( + AgentNodeConfig, + ConditionNodeConfig, + ExecutionStatus, + NodeExecutionLog, + NodeType, + WorkflowGraph, + WorkflowNode, +) +from application.core.json_schema_utils import ( + JsonSchemaValidationError, + normalize_json_schema_payload, +) +from application.error import sanitize_api_error +from application.templates.namespaces import NamespaceManager +from application.templates.template_engine import TemplateEngine, TemplateRenderError + +try: + import jsonschema +except ImportError: # pragma: no cover - optional dependency in some deployments. + jsonschema = None + +if TYPE_CHECKING: + from application.agents.base import BaseAgent +logger = logging.getLogger(__name__) + +StateValue = Any +WorkflowState = Dict[str, StateValue] +TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"} + + +class WorkflowEngine: + MAX_EXECUTION_STEPS = 50 + + def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"): + self.graph = graph + self.agent = agent + self.state: WorkflowState = {} + self.execution_log: List[Dict[str, Any]] = [] + self._condition_result: Optional[str] = None + self._template_engine = TemplateEngine() + self._namespace_manager = NamespaceManager() + + def execute( + self, initial_inputs: WorkflowState, query: str + ) -> Generator[Dict[str, str], None, None]: + self._initialize_state(initial_inputs, query) + + start_node = self.graph.get_start_node() + if not start_node: + yield {"type": "error", "error": "No start node found in workflow."} + return + current_node_id: Optional[str] = start_node.id + steps = 0 + + while current_node_id and steps < self.MAX_EXECUTION_STEPS: + node = self.graph.get_node_by_id(current_node_id) + if not node: + yield {"type": "error", "error": f"Node {current_node_id} not found."} + break + log_entry = self._create_log_entry(node) + + yield { + "type": "workflow_step", + "node_id": node.id, + "node_type": node.type.value, + "node_title": node.title, + "status": "running", + } + + try: + yield from self._execute_node(node) + log_entry["status"] = ExecutionStatus.COMPLETED.value + log_entry["completed_at"] = datetime.now(timezone.utc) + + output_key = f"node_{node.id}_output" + node_output = self.state.get(output_key) + + yield { + "type": "workflow_step", + "node_id": node.id, + "node_type": node.type.value, + "node_title": node.title, + "status": "completed", + "state_snapshot": dict(self.state), + "output": node_output, + } + except Exception as e: + logger.error(f"Error executing node {node.id}: {e}", exc_info=True) + log_entry["status"] = ExecutionStatus.FAILED.value + log_entry["error"] = str(e) + log_entry["completed_at"] = datetime.now(timezone.utc) + log_entry["state_snapshot"] = dict(self.state) + self.execution_log.append(log_entry) + + user_friendly_error = sanitize_api_error(e) + yield { + "type": "workflow_step", + "node_id": node.id, + "node_type": node.type.value, + "node_title": node.title, + "status": "failed", + "state_snapshot": dict(self.state), + "error": user_friendly_error, + } + yield {"type": "error", "error": user_friendly_error} + break + log_entry["state_snapshot"] = dict(self.state) + self.execution_log.append(log_entry) + + if node.type == NodeType.END: + break + current_node_id = self._get_next_node_id(current_node_id) + if current_node_id is None and node.type != NodeType.END: + logger.warning( + f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node" + ) + steps += 1 + if steps >= self.MAX_EXECUTION_STEPS: + 
logger.warning( + f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})" + ) + + def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None: + self.state.update(initial_inputs) + self.state["query"] = query + self.state["chat_history"] = str(self.agent.chat_history) + + def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]: + return { + "node_id": node.id, + "node_type": node.type.value, + "started_at": datetime.now(timezone.utc), + "completed_at": None, + "status": ExecutionStatus.RUNNING.value, + "error": None, + "state_snapshot": {}, + } + + def _get_next_node_id(self, current_node_id: str) -> Optional[str]: + node = self.graph.get_node_by_id(current_node_id) + edges = self.graph.get_outgoing_edges(current_node_id) + if not edges: + return None + + if node and node.type == NodeType.CONDITION and self._condition_result: + target_handle = self._condition_result + self._condition_result = None + for edge in edges: + if edge.source_handle == target_handle: + return edge.target_id + return None + + return edges[0].target_id + + def _execute_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + logger.info(f"Executing node {node.id} ({node.type.value})") + + node_handlers = { + NodeType.START: self._execute_start_node, + NodeType.NOTE: self._execute_note_node, + NodeType.AGENT: self._execute_agent_node, + NodeType.STATE: self._execute_state_node, + NodeType.CONDITION: self._execute_condition_node, + NodeType.END: self._execute_end_node, + } + + handler = node_handlers.get(node.type) + if handler: + yield from handler(node) + + def _execute_start_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + yield from () + + def _execute_note_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + yield from () + + def _execute_agent_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + from application.core.model_utils import ( + get_api_key_for_provider, + get_model_capabilities, + get_provider_from_model_id, + ) + + node_config = AgentNodeConfig(**node.config.get("config", node.config)) + + if node_config.prompt_template: + formatted_prompt = self._format_template(node_config.prompt_template) + else: + formatted_prompt = self.state.get("query", "") + node_json_schema = self._normalize_node_json_schema( + node_config.json_schema, node.title + ) + node_model_id = node_config.model_id or self.agent.model_id + node_llm_name = ( + node_config.llm_name + or get_provider_from_model_id(node_model_id or "") + or self.agent.llm_name + ) + node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key + + if node_json_schema and node_model_id: + model_capabilities = get_model_capabilities(node_model_id) + if model_capabilities and not model_capabilities.get( + "supports_structured_output", False + ): + raise ValueError( + f'Model "{node_model_id}" does not support structured output for node "{node.title}"' + ) + + node_agent = WorkflowNodeAgentFactory.create( + agent_type=node_config.agent_type, + endpoint=self.agent.endpoint, + llm_name=node_llm_name, + model_id=node_model_id, + api_key=node_api_key, + tool_ids=node_config.tools, + prompt=node_config.system_prompt, + chat_history=self.agent.chat_history, + decoded_token=self.agent.decoded_token, + json_schema=node_json_schema, + ) + + full_response_parts: List[str] = [] + structured_response_parts: List[str] = [] + has_structured_response = False + first_chunk = True + for event in 
node_agent.gen(formatted_prompt): + if "answer" in event: + chunk = str(event["answer"]) + full_response_parts.append(chunk) + if event.get("structured"): + has_structured_response = True + structured_response_parts.append(chunk) + if node_config.stream_to_user: + if first_chunk and hasattr(self, "_has_streamed"): + yield {"answer": "\n\n"} + first_chunk = False + yield event + + if node_config.stream_to_user: + self._has_streamed = True + + full_response = "".join(full_response_parts).strip() + output_value: Any = full_response + if has_structured_response: + structured_response = "".join(structured_response_parts).strip() + response_to_parse = structured_response or full_response + parsed_success, parsed_structured = self._parse_structured_output( + response_to_parse + ) + output_value = parsed_structured if parsed_success else response_to_parse + if node_json_schema: + self._validate_structured_output(node_json_schema, output_value) + elif node_json_schema: + parsed_success, parsed_structured = self._parse_structured_output( + full_response + ) + if not parsed_success: + raise ValueError( + "Structured output was expected but response was not valid JSON" + ) + output_value = parsed_structured + self._validate_structured_output(node_json_schema, output_value) + + default_output_key = f"node_{node.id}_output" + self.state[default_output_key] = output_value + + if node_config.output_variable: + self.state[node_config.output_variable] = output_value + + def _execute_state_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + config = node.config.get("config", node.config) + for op in config.get("operations", []): + expression = op.get("expression", "") + target_variable = op.get("target_variable", "") + if expression and target_variable: + self.state[target_variable] = evaluate_cel(expression, self.state) + yield from () + + def _execute_condition_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + config = ConditionNodeConfig(**node.config.get("config", node.config)) + matched_handle = None + + for case in config.cases: + if not case.expression.strip(): + continue + try: + if evaluate_cel(case.expression, self.state): + matched_handle = case.source_handle + break + except CelEvaluationError: + continue + + self._condition_result = matched_handle or "else" + yield from () + + def _execute_end_node( + self, node: WorkflowNode + ) -> Generator[Dict[str, str], None, None]: + config = node.config.get("config", node.config) + output_template = str(config.get("output_template", "")) + if output_template: + formatted_output = self._format_template(output_template) + yield {"answer": formatted_output} + + def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]: + normalized_response = raw_response.strip() + if not normalized_response: + return False, None + + try: + return True, json.loads(normalized_response) + except json.JSONDecodeError: + logger.warning( + "Workflow agent returned structured output that was not valid JSON" + ) + return False, None + + def _normalize_node_json_schema( + self, schema: Optional[Dict[str, Any]], node_title: str + ) -> Optional[Dict[str, Any]]: + if schema is None: + return None + try: + return normalize_json_schema_payload(schema) + except JsonSchemaValidationError as exc: + raise ValueError( + f'Invalid JSON schema for node "{node_title}": {exc}' + ) from exc + + def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None: + if jsonschema is None: + 
logger.warning( + "jsonschema package is not available, skipping structured output validation" + ) + return + + try: + normalized_schema = normalize_json_schema_payload(schema) + except JsonSchemaValidationError as exc: + raise ValueError(f"Invalid JSON schema: {exc}") from exc + + try: + jsonschema.validate(instance=output_value, schema=normalized_schema) + except jsonschema.exceptions.ValidationError as exc: + raise ValueError(f"Structured output did not match schema: {exc.message}") from exc + except jsonschema.exceptions.SchemaError as exc: + raise ValueError(f"Invalid JSON schema: {exc.message}") from exc + + def _format_template(self, template: str) -> str: + context = self._build_template_context() + try: + return self._template_engine.render(template, context) + except TemplateRenderError as e: + logger.warning( + "Workflow template rendering failed, using raw template: %s", str(e) + ) + return template + + def _build_template_context(self) -> Dict[str, Any]: + docs, docs_together = self._get_source_template_data() + passthrough_data = ( + self.state.get("passthrough") + if isinstance(self.state.get("passthrough"), dict) + else None + ) + tools_data = ( + self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None + ) + + context = self._namespace_manager.build_context( + user_id=getattr(self.agent, "user", None), + request_id=getattr(self.agent, "request_id", None), + passthrough_data=passthrough_data, + docs=docs, + docs_together=docs_together, + tools_data=tools_data, + ) + + agent_context: Dict[str, Any] = {} + for key, value in self.state.items(): + if not isinstance(key, str): + continue + normalized_key = key.strip() + if not normalized_key: + continue + agent_context[normalized_key] = value + + context["agent"] = agent_context + + # Keep legacy top-level variables working while namespaced variables are adopted. 
+ for key, value in agent_context.items(): + if key in TEMPLATE_RESERVED_NAMESPACES: + context[f"agent_{key}"] = value + continue + if key not in context: + context[key] = value + + return context + + def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]: + docs = getattr(self.agent, "retrieved_docs", None) + if not isinstance(docs, list) or len(docs) == 0: + return None, None + + docs_together_parts: List[str] = [] + for doc in docs: + if not isinstance(doc, dict): + continue + text = doc.get("text") + if not isinstance(text, str): + continue + + filename = doc.get("filename") or doc.get("title") or doc.get("source") + if isinstance(filename, str) and filename.strip(): + docs_together_parts.append(f"{filename}\n{text}") + else: + docs_together_parts.append(text) + + docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None + return docs, docs_together + + def get_execution_summary(self) -> List[NodeExecutionLog]: + return [ + NodeExecutionLog( + node_id=log["node_id"], + node_type=log["node_type"], + status=ExecutionStatus(log["status"]), + started_at=log["started_at"], + completed_at=log.get("completed_at"), + error=log.get("error"), + state_snapshot=log.get("state_snapshot", {}), + ) + for log in self.execution_log + ] diff --git a/application/api/answer/__init__.py b/application/api/answer/__init__.py index 861c922d..a10b9b5f 100644 --- a/application/api/answer/__init__.py +++ b/application/api/answer/__init__.py @@ -3,6 +3,7 @@ from flask import Blueprint from application.api import api from application.api.answer.routes.answer import AnswerResource from application.api.answer.routes.base import answer_ns +from application.api.answer.routes.search import SearchResource from application.api.answer.routes.stream import StreamResource @@ -14,6 +15,7 @@ api.add_namespace(answer_ns) def init_answer_routes(): api.add_resource(StreamResource, "/stream") api.add_resource(AnswerResource, "/api/answer") + api.add_resource(SearchResource, "/api/search") init_answer_routes() diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py index 87d80059..b90ffa15 100644 --- a/application/api/answer/routes/answer.py +++ b/application/api/answer/routes/answer.py @@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource): "chunks": fields.Integer( required=False, default=2, description="Number of chunks" ), - "token_limit": fields.Integer(required=False, description="Token limit"), "retriever": fields.String(required=False, description="Retriever type"), "api_key": fields.String(required=False, description="API key"), + "agent_id": fields.String(required=False, description="Agent ID"), "active_docs": fields.String( required=False, description="Active documents" ), @@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource): default=True, description="Whether to save the conversation", ), + "model_id": fields.String( + required=False, + description="Model ID to use for this request", + ), "passthrough": fields.Raw( required=False, description="Dynamic parameters to inject into prompt template", @@ -97,6 +101,10 @@ class AnswerResource(Resource, BaseAnswerResource): isNoneDoc=data.get("isNoneDoc"), index=None, should_save_conversation=data.get("save_conversation", True), + agent_id=processor.agent_id, + is_shared_usage=processor.is_shared_usage, + shared_token=processor.shared_token, + model_id=processor.model_id, ) stream_result = self.process_response_stream(stream) @@ -133,5 +141,5 @@ 
class AnswerResource(Resource, BaseAnswerResource): f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", extra={"error": str(e), "traceback": traceback.format_exc()}, ) - return make_response({"error": str(e)}, 500) + return make_response({"error": "An error occurred processing your request"}, 500) return make_response(result, 200) diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index 43e83ed2..2f9c9950 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -7,11 +7,17 @@ from flask import jsonify, make_response, Response from flask_restx import Namespace from application.api.answer.services.conversation_service import ConversationService +from application.core.model_utils import ( + get_api_key_for_provider, + get_default_model_id, + get_provider_from_model_id, +) from application.core.mongo_db import MongoDB from application.core.settings import settings +from application.error import sanitize_api_error from application.llm.llm_creator import LLMCreator -from application.utils import check_required_fields, get_gpt_model +from application.utils import check_required_fields logger = logging.getLogger(__name__) @@ -27,7 +33,7 @@ class BaseAnswerResource: db = mongo[settings.MONGO_DB_NAME] self.db = db self.user_logs_collection = db["user_logs"] - self.gpt_model = get_gpt_model() + self.default_model_id = get_default_model_id() self.conversation_service = ConversationService() def validate_request( @@ -41,6 +47,27 @@ class BaseAnswerResource: return missing_fields return None + @staticmethod + def _prepare_tool_calls_for_logging( + tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000 + ) -> List[Dict[str, Any]]: + if not tool_calls: + return [] + + prepared = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + prepared.append({"result": str(tool_call)[:max_chars]}) + continue + + item = dict(tool_call) + for key in ("result", "result_full"): + value = item.get(key) + if isinstance(value, str) and len(value) > max_chars: + item[key] = value[:max_chars] + prepared.append(item) + return prepared + def check_usage(self, agent_config: Dict) -> Optional[Response]: """Check if there is a usage limit and if it is exceeded @@ -54,7 +81,6 @@ class BaseAnswerResource: api_key = agent_config.get("user_api_key") if not api_key: return None - agents_collection = self.db["agents"] agent = agents_collection.find_one({"key": api_key}) @@ -62,7 +88,6 @@ class BaseAnswerResource: return make_response( jsonify({"success": False, "message": "Invalid API key."}), 401 ) - limited_token_mode_raw = agent.get("limited_token_mode", False) limited_request_mode_raw = agent.get("limited_request_mode", False) @@ -110,15 +135,12 @@ class BaseAnswerResource: daily_token_usage = token_result[0]["total_tokens"] if token_result else 0 else: daily_token_usage = 0 - if limited_request_mode: daily_request_usage = token_usage_collection.count_documents(match_query) else: daily_request_usage = 0 - if not limited_token_mode and not limited_request_mode: return None - token_exceeded = ( limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit ) @@ -138,7 +160,6 @@ class BaseAnswerResource: ), 429, ) - return None def complete_stream( @@ -155,6 +176,7 @@ class BaseAnswerResource: agent_id: Optional[str] = None, is_shared_usage: bool = False, shared_token: Optional[str] = None, + model_id: Optional[str] = None, ) -> Generator[str, None, None]: """ Generator function that streams 
the complete conversation response. @@ -173,6 +195,7 @@ class BaseAnswerResource: agent_id: ID of agent used is_shared_usage: Flag for shared agent usage shared_token: Token for shared agent + model_id: Model ID used for the request retrieved_docs: Pre-fetched documents for sources (optional) Yields: @@ -218,9 +241,15 @@ class BaseAnswerResource: data = json.dumps({"type": "thought", "thought": line["thought"]}) yield f"data: {data}\n\n" elif "type" in line: - data = json.dumps(line) + if line.get("type") == "error": + sanitized_error = { + "type": "error", + "error": sanitize_api_error(line.get("error", "An error occurred")) + } + data = json.dumps(sanitized_error) + else: + data = json.dumps(line) yield f"data: {data}\n\n" - if is_structured and structured_chunks: structured_data = { "type": "structured_answer", @@ -230,15 +259,23 @@ class BaseAnswerResource: } data = json.dumps(structured_data) yield f"data: {data}\n\n" - if isNoneDoc: for doc in source_log_docs: doc["source"] = "None" + provider = ( + get_provider_from_model_id(model_id) + if model_id + else settings.LLM_PROVIDER + ) + system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER) + llm = LLMCreator.create_llm( - settings.LLM_PROVIDER, - api_key=settings.API_KEY, + provider or settings.LLM_PROVIDER, + api_key=system_api_key, user_api_key=user_api_key, decoded_token=decoded_token, + model_id=model_id, + agent_id=agent_id, ) if should_save_conversation: @@ -250,7 +287,7 @@ class BaseAnswerResource: source_log_docs, tool_calls, llm, - self.gpt_model, + model_id or self.default_model_id, decoded_token, index=index, api_key=user_api_key, @@ -259,20 +296,46 @@ class BaseAnswerResource: shared_token=shared_token, attachment_ids=attachment_ids, ) + # Persist compression metadata/summary if it exists and wasn't saved mid-execution + compression_meta = getattr(agent, "compression_metadata", None) + compression_saved = getattr(agent, "compression_saved", False) + if conversation_id and compression_meta and not compression_saved: + try: + self.conversation_service.update_compression_metadata( + conversation_id, compression_meta + ) + self.conversation_service.append_compression_message( + conversation_id, compression_meta + ) + agent.compression_saved = True + logger.info( + f"Persisted compression metadata for conversation {conversation_id}" + ) + except Exception as e: + logger.error( + f"Failed to persist compression metadata: {str(e)}", + exc_info=True, + ) else: conversation_id = None id_data = {"type": "id", "id": str(conversation_id)} data = json.dumps(id_data) yield f"data: {data}\n\n" + tool_calls_for_logging = self._prepare_tool_calls_for_logging( + getattr(agent, "tool_calls", tool_calls) or tool_calls + ) + log_data = { "action": "stream_answer", "level": "info", "user": decoded_token.get("sub"), "api_key": user_api_key, + "agent_id": agent_id, "question": question, "response": response_full, "sources": source_log_docs, + "tool_calls": tool_calls_for_logging, "attachments": attachment_ids, "timestamp": datetime.datetime.now(datetime.timezone.utc), } @@ -280,12 +343,11 @@ class BaseAnswerResource: log_data["structured_output"] = True if schema_info: log_data["schema"] = schema_info - # Clean up text fields to be no longer than 10000 characters + for key, value in log_data.items(): if isinstance(value, str) and len(value) > 10000: log_data[key] = value[:10000] - self.user_logs_collection.insert_one(log_data) data = json.dumps({"type": "end"}) @@ -293,6 +355,7 @@ class BaseAnswerResource: except 
GeneratorExit: logger.info(f"Stream aborted by client for question: {question[:50]}... ") # Save partial response + if should_save_conversation and response_full: try: if isNoneDoc: @@ -303,6 +366,7 @@ class BaseAnswerResource: api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, + agent_id=agent_id, ) self.conversation_service.save_conversation( conversation_id, @@ -312,7 +376,7 @@ class BaseAnswerResource: source_log_docs, tool_calls, llm, - self.gpt_model, + model_id or self.default_model_id, decoded_token, index=index, api_key=user_api_key, @@ -321,6 +385,25 @@ class BaseAnswerResource: shared_token=shared_token, attachment_ids=attachment_ids, ) + compression_meta = getattr(agent, "compression_metadata", None) + compression_saved = getattr(agent, "compression_saved", False) + if conversation_id and compression_meta and not compression_saved: + try: + self.conversation_service.update_compression_metadata( + conversation_id, compression_meta + ) + self.conversation_service.append_compression_message( + conversation_id, compression_meta + ) + agent.compression_saved = True + logger.info( + f"Persisted compression metadata for conversation {conversation_id} (partial stream)" + ) + except Exception as e: + logger.error( + f"Failed to persist compression metadata (partial stream): {str(e)}", + exc_info=True, + ) except Exception as e: logger.error( f"Error saving partial response: {str(e)}", exc_info=True @@ -369,7 +452,7 @@ class BaseAnswerResource: thought = event["thought"] elif event["type"] == "error": logger.error(f"Error from stream: {event['error']}") - return None, None, None, None, event["error"] + return None, None, None, None, event["error"], None elif event["type"] == "end": stream_ended = True except (json.JSONDecodeError, KeyError) as e: @@ -377,8 +460,7 @@ class BaseAnswerResource: continue if not stream_ended: logger.error("Stream ended unexpectedly without an 'end' event.") - return None, None, None, None, "Stream ended unexpectedly" - + return None, None, None, None, "Stream ended unexpectedly", None result = ( conversation_id, response_full, @@ -390,7 +472,6 @@ class BaseAnswerResource: if is_structured: result = result + ({"structured": True, "schema": schema_info},) - return result def error_stream_generate(self, err_response): diff --git a/application/api/answer/routes/search.py b/application/api/answer/routes/search.py new file mode 100644 index 00000000..16ebdb82 --- /dev/null +++ b/application/api/answer/routes/search.py @@ -0,0 +1,186 @@ +import logging +from typing import Any, Dict, List + +from flask import make_response, request +from flask_restx import fields, Resource + +from bson.dbref import DBRef + +from application.api.answer.routes.base import answer_ns +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.vectorstore.vector_creator import VectorCreator + +logger = logging.getLogger(__name__) + + +@answer_ns.route("/api/search") +class SearchResource(Resource): + """Fast search endpoint for retrieving relevant documents""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + mongo = MongoDB.get_client() + self.db = mongo[settings.MONGO_DB_NAME] + self.agents_collection = self.db["agents"] + + search_model = answer_ns.model( + "SearchModel", + { + "question": fields.String( + required=True, description="Search query" + ), + "api_key": fields.String( + required=True, description="API key for authentication" + ), + "chunks": fields.Integer( + 
required=False, default=5, description="Number of results to return" + ), + }, + ) + + def _get_sources_from_api_key(self, api_key: str) -> List[str]: + """Get source IDs connected to the API key/agent. + + """ + agent_data = self.agents_collection.find_one({"key": api_key}) + if not agent_data: + return [] + + source_ids = [] + + # Handle multiple sources (only if non-empty) + sources = agent_data.get("sources", []) + if sources and isinstance(sources, list) and len(sources) > 0: + for source_ref in sources: + # Skip "default" - it's a placeholder, not an actual vectorstore + if source_ref == "default": + continue + elif isinstance(source_ref, DBRef): + source_doc = self.db.dereference(source_ref) + if source_doc: + source_ids.append(str(source_doc["_id"])) + + # Handle single source (legacy) - check if sources was empty or didn't yield results + if not source_ids: + source = agent_data.get("source") + if isinstance(source, DBRef): + source_doc = self.db.dereference(source) + if source_doc: + source_ids.append(str(source_doc["_id"])) + # Skip "default" - it's a placeholder, not an actual vectorstore + elif source and source != "default": + source_ids.append(source) + + return source_ids + + def _search_vectorstores( + self, query: str, source_ids: List[str], chunks: int + ) -> List[Dict[str, Any]]: + """Search across vectorstores and return results""" + if not source_ids: + return [] + + results = [] + chunks_per_source = max(1, chunks // len(source_ids)) + seen_texts = set() + + for source_id in source_ids: + if not source_id or not source_id.strip(): + continue + + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY + ) + docs = docsearch.search(query, k=chunks_per_source * 2) + + for doc in docs: + if len(results) >= chunks: + break + + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + + # Skip duplicates + text_hash = hash(page_content[:200]) + if text_hash in seen_texts: + continue + seen_texts.add(text_hash) + + title = metadata.get( + "title", metadata.get("post_title", "") + ) + if not isinstance(title, str): + title = str(title) if title else "" + + # Clean up title + if title: + title = title.split("/")[-1] + else: + # Use filename or first part of content as title + title = metadata.get("filename", page_content[:50] + "...") + + source = metadata.get("source", source_id) + + results.append({ + "text": page_content, + "title": title, + "source": source, + }) + + if len(results) >= chunks: + break + + except Exception as e: + logger.error( + f"Error searching vectorstore {source_id}: {e}", + exc_info=True, + ) + continue + + return results[:chunks] + + @answer_ns.expect(search_model) + @answer_ns.doc(description="Search for relevant documents based on query") + def post(self): + data = request.get_json() + + question = data.get("question") + api_key = data.get("api_key") + chunks = data.get("chunks", 5) + + if not question: + return make_response({"error": "question is required"}, 400) + + if not api_key: + return make_response({"error": "api_key is required"}, 400) + + # Validate API key + agent = self.agents_collection.find_one({"key": api_key}) + if not agent: + return make_response({"error": "Invalid API key"}, 401) + + try: + # Get sources connected to this API key + source_ids = self._get_sources_from_api_key(api_key) + + if not 
source_ids: + return make_response([], 200) + + # Perform search + results = self._search_vectorstores(question, source_ids, chunks) + + return make_response(results, 200) + + except Exception as e: + logger.error( + f"/api/search - error: {str(e)}", + extra={"error": str(e)}, + exc_info=True, + ) + return make_response({"error": "Search failed"}, 500) diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py index 92e41c14..d1a71b25 100644 --- a/application/api/answer/routes/stream.py +++ b/application/api/answer/routes/stream.py @@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource): "chunks": fields.Integer( required=False, default=2, description="Number of chunks" ), - "token_limit": fields.Integer(required=False, description="Token limit"), "retriever": fields.String(required=False, description="Retriever type"), "api_key": fields.String(required=False, description="API key"), + "agent_id": fields.String(required=False, description="Agent ID"), "active_docs": fields.String( required=False, description="Active documents" ), @@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource): default=True, description="Whether to save the conversation", ), + "model_id": fields.String( + required=False, + description="Model ID to use for this request", + ), "attachments": fields.List( fields.String, required=False, description="List of attachment IDs" ), @@ -77,6 +81,12 @@ class StreamResource(Resource, BaseAnswerResource): processor = StreamProcessor(data, decoded_token) try: processor.initialize() + if not processor.decoded_token: + return Response( + self.error_stream_generate("Unauthorized"), + status=401, + mimetype="text/event-stream", + ) docs_together, docs_list = processor.pre_fetch_docs(data["question"]) tools_data = processor.pre_fetch_tools() @@ -98,9 +108,10 @@ class StreamResource(Resource, BaseAnswerResource): index=data.get("index"), should_save_conversation=data.get("save_conversation", True), attachment_ids=data.get("attachments", []), - agent_id=data.get("agent_id"), + agent_id=processor.agent_id, is_shared_usage=processor.is_shared_usage, shared_token=processor.shared_token, + model_id=processor.model_id, ), mimetype="text/event-stream", ) diff --git a/application/api/answer/services/compression/__init__.py b/application/api/answer/services/compression/__init__.py new file mode 100644 index 00000000..4cbdb910 --- /dev/null +++ b/application/api/answer/services/compression/__init__.py @@ -0,0 +1,20 @@ +""" +Compression module for managing conversation context compression. 
+ +""" + +from application.api.answer.services.compression.orchestrator import ( + CompressionOrchestrator, +) +from application.api.answer.services.compression.service import CompressionService +from application.api.answer.services.compression.types import ( + CompressionResult, + CompressionMetadata, +) + +__all__ = [ + "CompressionOrchestrator", + "CompressionService", + "CompressionResult", + "CompressionMetadata", +] diff --git a/application/api/answer/services/compression/message_builder.py b/application/api/answer/services/compression/message_builder.py new file mode 100644 index 00000000..93772fe5 --- /dev/null +++ b/application/api/answer/services/compression/message_builder.py @@ -0,0 +1,234 @@ +"""Message reconstruction utilities for compression.""" + +import logging +import uuid +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class MessageBuilder: + """Builds message arrays from compressed context.""" + + @staticmethod + def build_from_compressed_context( + system_prompt: str, + compressed_summary: Optional[str], + recent_queries: List[Dict], + include_tool_calls: bool = False, + context_type: str = "pre_request", + ) -> List[Dict]: + """ + Build messages from compressed context. + + Args: + system_prompt: Original system prompt + compressed_summary: Compressed summary (if any) + recent_queries: Recent uncompressed queries + include_tool_calls: Whether to include tool calls from history + context_type: Type of context ('pre_request' or 'mid_execution') + + Returns: + List of message dicts ready for LLM + """ + # Append compression summary to system prompt if present + if compressed_summary: + system_prompt = MessageBuilder._append_compression_context( + system_prompt, compressed_summary, context_type + ) + + messages = [{"role": "system", "content": system_prompt}] + + # Add recent history + for query in recent_queries: + if "prompt" in query and "response" in query: + messages.append({"role": "user", "content": query["prompt"]}) + messages.append({"role": "assistant", "content": query["response"]}) + + # Add tool calls from history if present + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + call_id = tool_call.get("call_id") or str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + + # If no recent queries (everything was compressed), add a continuation user message + if len(recent_queries) == 0 and compressed_summary: + messages.append({ + "role": "user", + "content": "Please continue with the remaining tasks based on the context above." + }) + logger.info("Added continuation user message to maintain proper turn-taking after full compression") + + return messages + + @staticmethod + def _append_compression_context( + system_prompt: str, compressed_summary: str, context_type: str = "pre_request" + ) -> str: + """ + Append compression context to system prompt. 
+ + Args: + system_prompt: Original system prompt + compressed_summary: Summary to append + context_type: Type of compression context + + Returns: + Updated system prompt + """ + # Remove existing compression context if present + if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt: + parts = system_prompt.split("\n\n---\n\n") + system_prompt = parts[0] + + # Build appropriate context message based on type + if context_type == "mid_execution": + context_message = ( + "\n\n---\n\n" + "Context window limit reached during execution. " + "Previous conversation has been compressed to fit within limits. " + "The conversation is summarized below:\n\n" + f"{compressed_summary}" + ) + else: # pre_request + context_message = ( + "\n\n---\n\n" + "This session is being continued from a previous conversation that " + "has been compressed to fit within context limits. " + "The conversation is summarized below:\n\n" + f"{compressed_summary}" + ) + + return system_prompt + context_message + + @staticmethod + def rebuild_messages_after_compression( + messages: List[Dict], + compressed_summary: Optional[str], + recent_queries: List[Dict], + include_current_execution: bool = False, + include_tool_calls: bool = False, + ) -> Optional[List[Dict]]: + """ + Rebuild the message list after compression so tool execution can continue. + + Args: + messages: Original message list + compressed_summary: Compressed summary + recent_queries: Recent uncompressed queries + include_current_execution: Whether to preserve current execution messages + include_tool_calls: Whether to include tool calls from history + + Returns: + Rebuilt message list or None if failed + """ + # Find the system message + system_message = next( + (msg for msg in messages if msg.get("role") == "system"), None + ) + if not system_message: + logger.warning("No system message found in messages list") + return None + + # Update system message with compressed summary + if compressed_summary: + content = system_message.get("content", "") + system_message["content"] = MessageBuilder._append_compression_context( + content, compressed_summary, "mid_execution" + ) + logger.info( + "Appended compression summary to system prompt (truncated): %s", + ( + compressed_summary[:500] + "..." 
+ if len(compressed_summary) > 500 + else compressed_summary + ), + ) + + rebuilt_messages = [system_message] + + # Add recent history from compressed context + for query in recent_queries: + if "prompt" in query and "response" in query: + rebuilt_messages.append({"role": "user", "content": query["prompt"]}) + rebuilt_messages.append( + {"role": "assistant", "content": query["response"]} + ) + + # Add tool calls from history if present + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + call_id = tool_call.get("call_id") or str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + rebuilt_messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + rebuilt_messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + + # If no recent queries (everything was compressed), add a continuation user message + if len(recent_queries) == 0 and compressed_summary: + rebuilt_messages.append({ + "role": "user", + "content": "Please continue with the remaining tasks based on the context above." + }) + logger.info("Added continuation user message to maintain proper turn-taking after full compression") + + if include_current_execution: + # Preserve any messages that were added during the current execution cycle + recent_msg_count = 1 # system message + for query in recent_queries: + if "prompt" in query and "response" in query: + recent_msg_count += 2 + if "tool_calls" in query: + recent_msg_count += len(query["tool_calls"]) * 2 + + if len(messages) > recent_msg_count: + current_execution_messages = messages[recent_msg_count:] + rebuilt_messages.extend(current_execution_messages) + logger.info( + f"Preserved {len(current_execution_messages)} messages from current execution cycle" + ) + + logger.info( + f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. " + f"Ready to continue tool execution." + ) + return rebuilt_messages diff --git a/application/api/answer/services/compression/orchestrator.py b/application/api/answer/services/compression/orchestrator.py new file mode 100644 index 00000000..11a9032c --- /dev/null +++ b/application/api/answer/services/compression/orchestrator.py @@ -0,0 +1,233 @@ +"""High-level compression orchestration.""" + +import logging +from typing import Any, Dict, Optional + +from application.api.answer.services.compression.service import CompressionService +from application.api.answer.services.compression.threshold_checker import ( + CompressionThresholdChecker, +) +from application.api.answer.services.compression.types import CompressionResult +from application.api.answer.services.conversation_service import ConversationService +from application.core.model_utils import ( + get_api_key_for_provider, + get_provider_from_model_id, +) +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator + +logger = logging.getLogger(__name__) + + +class CompressionOrchestrator: + """ + Facade for compression operations. + + Coordinates between all compression components and provides + a simple interface for callers. 
+ """ + + def __init__( + self, + conversation_service: ConversationService, + threshold_checker: Optional[CompressionThresholdChecker] = None, + ): + """ + Initialize orchestrator. + + Args: + conversation_service: Service for DB operations + threshold_checker: Custom threshold checker (optional) + """ + self.conversation_service = conversation_service + self.threshold_checker = threshold_checker or CompressionThresholdChecker() + + def compress_if_needed( + self, + conversation_id: str, + user_id: str, + model_id: str, + decoded_token: Dict[str, Any], + current_query_tokens: int = 500, + ) -> CompressionResult: + """ + Check if compression is needed and perform it if so. + + This is the main entry point for compression operations. + + Args: + conversation_id: Conversation ID + user_id: User ID + model_id: Model being used for conversation + decoded_token: User's decoded JWT token + current_query_tokens: Estimated tokens for current query + + Returns: + CompressionResult with summary and recent queries + """ + try: + # Load conversation + conversation = self.conversation_service.get_conversation( + conversation_id, user_id + ) + + if not conversation: + logger.warning( + f"Conversation {conversation_id} not found for user {user_id}" + ) + return CompressionResult.failure("Conversation not found") + + # Check if compression is needed + if not self.threshold_checker.should_compress( + conversation, model_id, current_query_tokens + ): + # No compression needed, return full history + queries = conversation.get("queries", []) + return CompressionResult.success_no_compression(queries) + + # Perform compression + return self._perform_compression( + conversation_id, conversation, model_id, decoded_token + ) + + except Exception as e: + logger.error( + f"Error in compress_if_needed: {str(e)}", exc_info=True + ) + return CompressionResult.failure(str(e)) + + def _perform_compression( + self, + conversation_id: str, + conversation: Dict[str, Any], + model_id: str, + decoded_token: Dict[str, Any], + ) -> CompressionResult: + """ + Perform the actual compression operation. 
+ + Args: + conversation_id: Conversation ID + conversation: Conversation document + model_id: Model ID for conversation + decoded_token: User token + + Returns: + CompressionResult + """ + try: + # Determine which model to use for compression + compression_model = ( + settings.COMPRESSION_MODEL_OVERRIDE + if settings.COMPRESSION_MODEL_OVERRIDE + else model_id + ) + + # Get provider and API key for compression model + provider = get_provider_from_model_id(compression_model) + api_key = get_api_key_for_provider(provider) + + # Create compression LLM + compression_llm = LLMCreator.create_llm( + provider, + api_key=api_key, + user_api_key=None, + decoded_token=decoded_token, + model_id=compression_model, + agent_id=conversation.get("agent_id"), + ) + + # Create compression service with DB update capability + compression_service = CompressionService( + llm=compression_llm, + model_id=compression_model, + conversation_service=self.conversation_service, + ) + + # Compress all queries up to the latest + queries_count = len(conversation.get("queries", [])) + compress_up_to = queries_count - 1 + + if compress_up_to < 0: + logger.warning("No queries to compress") + return CompressionResult.success_no_compression([]) + + logger.info( + f"Initiating compression for conversation {conversation_id}: " + f"compressing all {queries_count} queries (0-{compress_up_to})" + ) + + # Perform compression and save to DB + metadata = compression_service.compress_and_save( + conversation_id, conversation, compress_up_to + ) + + logger.info( + f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, " + f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens" + ) + + # Reload conversation with updated metadata + conversation = self.conversation_service.get_conversation( + conversation_id, user_id=decoded_token.get("sub") + ) + + # Get compressed context + compressed_summary, recent_queries = ( + compression_service.get_compressed_context(conversation) + ) + + return CompressionResult.success_with_compression( + compressed_summary, recent_queries, metadata + ) + + except Exception as e: + logger.error(f"Error performing compression: {str(e)}", exc_info=True) + return CompressionResult.failure(str(e)) + + def compress_mid_execution( + self, + conversation_id: str, + user_id: str, + model_id: str, + decoded_token: Dict[str, Any], + current_conversation: Optional[Dict[str, Any]] = None, + ) -> CompressionResult: + """ + Perform compression during tool execution. 
+ + Args: + conversation_id: Conversation ID + user_id: User ID + model_id: Model ID + decoded_token: User token + current_conversation: Pre-loaded conversation (optional) + + Returns: + CompressionResult + """ + try: + # Load conversation if not provided + if current_conversation: + conversation = current_conversation + else: + conversation = self.conversation_service.get_conversation( + conversation_id, user_id + ) + + if not conversation: + logger.warning( + f"Could not load conversation {conversation_id} for mid-execution compression" + ) + return CompressionResult.failure("Conversation not found") + + # Perform compression + return self._perform_compression( + conversation_id, conversation, model_id, decoded_token + ) + + except Exception as e: + logger.error( + f"Error in mid-execution compression: {str(e)}", exc_info=True + ) + return CompressionResult.failure(str(e)) diff --git a/application/api/answer/services/compression/prompt_builder.py b/application/api/answer/services/compression/prompt_builder.py new file mode 100644 index 00000000..d5ce3183 --- /dev/null +++ b/application/api/answer/services/compression/prompt_builder.py @@ -0,0 +1,149 @@ +"""Compression prompt building logic.""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class CompressionPromptBuilder: + """Builds prompts for LLM compression calls.""" + + def __init__(self, version: str = "v1.0"): + """ + Initialize prompt builder. + + Args: + version: Prompt template version to use + """ + self.version = version + self.system_prompt = self._load_prompt(version) + + def _load_prompt(self, version: str) -> str: + """ + Load prompt template from file. + + Args: + version: Version string (e.g., 'v1.0') + + Returns: + Prompt template content + + Raises: + FileNotFoundError: If prompt template file doesn't exist + """ + current_dir = Path(__file__).resolve().parents[4] + prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt" + + try: + with open(prompt_path, "r") as f: + return f.read() + except FileNotFoundError: + logger.error(f"Compression prompt template not found: {prompt_path}") + raise FileNotFoundError( + f"Compression prompt template '{version}' not found at {prompt_path}. " + f"Please ensure the template file exists." + ) + + def build_prompt( + self, + queries: List[Dict[str, Any]], + existing_compressions: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, str]]: + """ + Build messages for compression LLM call. + + Args: + queries: List of query objects to compress + existing_compressions: List of previous compression points + + Returns: + List of message dicts for LLM + """ + # Build conversation text + conversation_text = self._format_conversation(queries) + + # Add existing compression context if present + existing_compression_context = "" + if existing_compressions and len(existing_compressions) > 0: + existing_compression_context = ( + "\n\nIMPORTANT: This conversation has been compressed before. " + "Previous compression summaries:\n\n" + ) + for i, comp in enumerate(existing_compressions): + existing_compression_context += ( + f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n" + f"{comp.get('compressed_summary', '')}\n\n" + ) + existing_compression_context += ( + "Your task is to create a NEW summary that incorporates the context from " + "previous compressions AND the new messages below. 
The final summary should " + "be comprehensive and include all important information from both previous " + "compressions and new messages.\n\n" + ) + + user_prompt = ( + f"{existing_compression_context}" + f"Here is the conversation to summarize:\n\n" + f"{conversation_text}" + ) + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + return messages + + def _format_conversation(self, queries: List[Dict[str, Any]]) -> str: + """ + Format conversation queries into readable text for compression. + + Args: + queries: List of query objects + + Returns: + Formatted conversation text + """ + conversation_lines = [] + + for i, query in enumerate(queries): + conversation_lines.append(f"--- Message {i + 1} ---") + conversation_lines.append(f"User: {query.get('prompt', '')}") + + # Add tool calls if present + tool_calls = query.get("tool_calls", []) + if tool_calls: + conversation_lines.append("\nTool Calls:") + for tc in tool_calls: + tool_name = tc.get("tool_name", "unknown") + action_name = tc.get("action_name", "unknown") + arguments = tc.get("arguments", {}) + result = tc.get("result", "") + if result is None: + result = "" + status = tc.get("status", "unknown") + + # Include full tool result for complete compression context + conversation_lines.append( + f" - {tool_name}.{action_name}({arguments}) " + f"[{status}] → {result}" + ) + + # Add agent thought if present + thought = query.get("thought", "") + if thought: + conversation_lines.append(f"\nAgent Thought: {thought}") + + # Add assistant response + conversation_lines.append(f"\nAssistant: {query.get('response', '')}") + + # Add sources if present + sources = query.get("sources", []) + if sources: + conversation_lines.append(f"\nSources Used: {len(sources)} documents") + + conversation_lines.append("") # Empty line between messages + + return "\n".join(conversation_lines) diff --git a/application/api/answer/services/compression/service.py b/application/api/answer/services/compression/service.py new file mode 100644 index 00000000..ccf6f126 --- /dev/null +++ b/application/api/answer/services/compression/service.py @@ -0,0 +1,306 @@ +"""Core compression service with simplified responsibilities.""" + +import logging +import re +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from application.api.answer.services.compression.prompt_builder import ( + CompressionPromptBuilder, +) +from application.api.answer.services.compression.token_counter import TokenCounter +from application.api.answer.services.compression.types import ( + CompressionMetadata, +) +from application.core.settings import settings + +logger = logging.getLogger(__name__) + + +class CompressionService: + """ + Service for compressing conversation history. + + Handles DB updates. + """ + + def __init__( + self, + llm, + model_id: str, + conversation_service=None, + prompt_builder: Optional[CompressionPromptBuilder] = None, + ): + """ + Initialize compression service. 
+ + Args: + llm: LLM instance to use for compression + model_id: Model ID for compression + conversation_service: Service for DB operations (optional, for DB updates) + prompt_builder: Custom prompt builder (optional) + """ + self.llm = llm + self.model_id = model_id + self.conversation_service = conversation_service + self.prompt_builder = prompt_builder or CompressionPromptBuilder( + version=settings.COMPRESSION_PROMPT_VERSION + ) + + def compress_conversation( + self, + conversation: Dict[str, Any], + compress_up_to_index: int, + ) -> CompressionMetadata: + """ + Compress conversation history up to specified index. + + Args: + conversation: Full conversation document + compress_up_to_index: Last query index to include in compression + + Returns: + CompressionMetadata with compression details + + Raises: + ValueError: If compress_up_to_index is invalid + """ + try: + queries = conversation.get("queries", []) + + if compress_up_to_index < 0 or compress_up_to_index >= len(queries): + raise ValueError( + f"Invalid compress_up_to_index: {compress_up_to_index} " + f"(conversation has {len(queries)} queries)" + ) + + # Get queries to compress + queries_to_compress = queries[: compress_up_to_index + 1] + + # Check if there are existing compressions + existing_compressions = conversation.get("compression_metadata", {}).get( + "compression_points", [] + ) + + if existing_compressions: + logger.info( + f"Found {len(existing_compressions)} previous compression(s) - " + f"will incorporate into new summary" + ) + + # Calculate original token count + original_tokens = TokenCounter.count_query_tokens(queries_to_compress) + + # Log tool call stats + self._log_tool_call_stats(queries_to_compress) + + # Build compression prompt + messages = self.prompt_builder.build_prompt( + queries_to_compress, existing_compressions + ) + + # Call LLM to generate compression + logger.info( + f"Starting compression: {len(queries_to_compress)} queries " + f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) " + f"using model {self.model_id}" + ) + + response = self.llm.gen( + model=self.model_id, messages=messages, max_tokens=4000 + ) + + # Extract summary from response + compressed_summary = self._extract_summary(response) + + # Calculate compressed token count + compressed_tokens = TokenCounter.count_message_tokens( + [{"content": compressed_summary}] + ) + + # Calculate compression ratio + compression_ratio = ( + original_tokens / compressed_tokens if compressed_tokens > 0 else 0 + ) + + logger.info( + f"Compression complete: {original_tokens} → {compressed_tokens} tokens " + f"({compression_ratio:.1f}x compression)" + ) + + # Build compression metadata + compression_metadata = CompressionMetadata( + timestamp=datetime.now(timezone.utc), + query_index=compress_up_to_index, + compressed_summary=compressed_summary, + original_token_count=original_tokens, + compressed_token_count=compressed_tokens, + compression_ratio=compression_ratio, + model_used=self.model_id, + compression_prompt_version=self.prompt_builder.version, + ) + + return compression_metadata + + except Exception as e: + logger.error(f"Error compressing conversation: {str(e)}", exc_info=True) + raise + + def compress_and_save( + self, + conversation_id: str, + conversation: Dict[str, Any], + compress_up_to_index: int, + ) -> CompressionMetadata: + """ + Compress conversation and save to database. 
+ + Args: + conversation_id: Conversation ID + conversation: Full conversation document + compress_up_to_index: Last query index to include + + Returns: + CompressionMetadata + + Raises: + ValueError: If conversation_service not provided or invalid index + """ + if not self.conversation_service: + raise ValueError( + "conversation_service required for compress_and_save operation" + ) + + # Perform compression + metadata = self.compress_conversation(conversation, compress_up_to_index) + + # Save to database + self.conversation_service.update_compression_metadata( + conversation_id, metadata.to_dict() + ) + + logger.info(f"Compression metadata saved to database for {conversation_id}") + + return metadata + + def get_compressed_context( + self, conversation: Dict[str, Any] + ) -> tuple[Optional[str], List[Dict[str, Any]]]: + """ + Get compressed summary + recent uncompressed messages. + + Args: + conversation: Full conversation document + + Returns: + (compressed_summary, recent_messages) + """ + try: + compression_metadata = conversation.get("compression_metadata", {}) + + if not compression_metadata.get("is_compressed"): + logger.debug("No compression metadata found - using full history") + queries = conversation.get("queries", []) + if queries is None: + logger.error("Conversation queries is None - returning empty list") + return None, [] + return None, queries + + compression_points = compression_metadata.get("compression_points", []) + + if not compression_points: + logger.debug("No compression points found - using full history") + queries = conversation.get("queries", []) + if queries is None: + logger.error("Conversation queries is None - returning empty list") + return None, [] + return None, queries + + # Get the most recent compression point + latest_compression = compression_points[-1] + compressed_summary = latest_compression.get("compressed_summary") + last_compressed_index = latest_compression.get("query_index") + compressed_tokens = latest_compression.get("compressed_token_count", 0) + original_tokens = latest_compression.get("original_token_count", 0) + + # Get only messages after compression point + queries = conversation.get("queries", []) + total_queries = len(queries) + recent_queries = queries[last_compressed_index + 1 :] + + logger.info( + f"Using compressed context: summary ({compressed_tokens} tokens, " + f"compressed from {original_tokens}) + {len(recent_queries)} recent messages " + f"(messages {last_compressed_index + 1}-{total_queries - 1})" + ) + + return compressed_summary, recent_queries + + except Exception as e: + logger.error( + f"Error getting compressed context: {str(e)}", exc_info=True + ) + queries = conversation.get("queries", []) + if queries is None: + return None, [] + return None, queries + + def _extract_summary(self, llm_response: str) -> str: + """ + Extract clean summary from LLM response. 
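+
+        Illustrative behaviour (a sketch; assumes the model wraps its recap in
+        ``<summary>`` tags and its reasoning in ``<analysis>`` tags, matching the
+        "summary"/"analysis" tags the comments below refer to):
+
+            service._extract_summary(
+                "<analysis>working notes</analysis><summary>Key decisions so far</summary>"
+            )
+            # -> "Key decisions so far"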
+ + Args: + llm_response: Raw LLM response + + Returns: + Cleaned summary text + """ + try: + # Try to extract content within tags + summary_match = re.search( + r"(.*?)", llm_response, re.DOTALL + ) + + if summary_match: + summary = summary_match.group(1).strip() + else: + # If no summary tags, remove analysis tags and use the rest + summary = re.sub( + r".*?", "", llm_response, flags=re.DOTALL + ).strip() + + return summary + + except Exception as e: + logger.warning(f"Error extracting summary: {str(e)}, using full response") + return llm_response + + def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None: + """Log statistics about tool calls in queries.""" + total_tool_calls = 0 + total_tool_result_chars = 0 + tool_call_breakdown = {} + + for q in queries: + for tc in q.get("tool_calls", []): + total_tool_calls += 1 + tool_name = tc.get("tool_name", "unknown") + action_name = tc.get("action_name", "unknown") + key = f"{tool_name}.{action_name}" + tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1 + + # Track total tool result size + result = tc.get("result", "") + if result: + total_tool_result_chars += len(str(result)) + + if total_tool_calls > 0: + tool_breakdown_str = ", ".join( + f"{tool}({count})" + for tool, count in sorted(tool_call_breakdown.items()) + ) + tool_result_kb = total_tool_result_chars / 1024 + logger.info( + f"Tool call breakdown: {tool_breakdown_str} " + f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)" + ) diff --git a/application/api/answer/services/compression/threshold_checker.py b/application/api/answer/services/compression/threshold_checker.py new file mode 100644 index 00000000..15397018 --- /dev/null +++ b/application/api/answer/services/compression/threshold_checker.py @@ -0,0 +1,103 @@ +"""Compression threshold checking logic.""" + +import logging +from typing import Any, Dict + +from application.core.model_utils import get_token_limit +from application.core.settings import settings +from application.api.answer.services.compression.token_counter import TokenCounter + +logger = logging.getLogger(__name__) + + +class CompressionThresholdChecker: + """Determines if compression is needed based on token thresholds.""" + + def __init__(self, threshold_percentage: float = None): + """ + Initialize threshold checker. + + Args: + threshold_percentage: Percentage of context to use as threshold + (defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE) + """ + self.threshold_percentage = ( + threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE + ) + + def should_compress( + self, + conversation: Dict[str, Any], + model_id: str, + current_query_tokens: int = 500, + ) -> bool: + """ + Determine if compression is needed. 
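+
+        Illustrative check (a sketch; the model id and its 128,000-token context
+        window are assumptions made only for the arithmetic):
+
+            checker = CompressionThresholdChecker(threshold_percentage=0.8)
+            # threshold = int(128_000 * 0.8) = 102_400 tokens, so compression is
+            # requested once history plus the current query reaches that figure
+            needs_compression = checker.should_compress(conversation, "gpt-4o")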
+ + Args: + conversation: Full conversation document + model_id: Target model for this request + current_query_tokens: Estimated tokens for current query + + Returns: + True if tokens >= threshold% of context window + """ + try: + # Calculate total tokens in conversation + total_tokens = TokenCounter.count_conversation_tokens(conversation) + total_tokens += current_query_tokens + + # Get context window limit for model + context_limit = get_token_limit(model_id) + + # Calculate threshold + threshold = int(context_limit * self.threshold_percentage) + + compression_needed = total_tokens >= threshold + percentage_used = (total_tokens / context_limit) * 100 + + if compression_needed: + logger.warning( + f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit " + f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)" + ) + else: + logger.info( + f"Compression check: {total_tokens}/{context_limit} tokens " + f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed" + ) + + return compression_needed + + except Exception as e: + logger.error(f"Error checking compression need: {str(e)}", exc_info=True) + return False + + def check_message_tokens(self, messages: list, model_id: str) -> bool: + """ + Check if message list exceeds threshold. + + Args: + messages: List of message dicts + model_id: Target model + + Returns: + True if at or above threshold + """ + try: + current_tokens = TokenCounter.count_message_tokens(messages) + context_limit = get_token_limit(model_id) + threshold = int(context_limit * self.threshold_percentage) + + if current_tokens >= threshold: + logger.warning( + f"Message context limit approaching: {current_tokens}/{context_limit} tokens " + f"({(current_tokens/context_limit)*100:.1f}%)" + ) + return True + + return False + + except Exception as e: + logger.error(f"Error checking message tokens: {str(e)}", exc_info=True) + return False diff --git a/application/api/answer/services/compression/token_counter.py b/application/api/answer/services/compression/token_counter.py new file mode 100644 index 00000000..ac676cf0 --- /dev/null +++ b/application/api/answer/services/compression/token_counter.py @@ -0,0 +1,103 @@ +"""Token counting utilities for compression.""" + +import logging +from typing import Any, Dict, List + +from application.utils import num_tokens_from_string +from application.core.settings import settings + +logger = logging.getLogger(__name__) + + +class TokenCounter: + """Centralized token counting for conversations and messages.""" + + @staticmethod + def count_message_tokens(messages: List[Dict]) -> int: + """ + Calculate total tokens in a list of messages. + + Args: + messages: List of message dicts with 'content' field + + Returns: + Total token count + """ + total_tokens = 0 + for message in messages: + content = message.get("content", "") + if isinstance(content, str): + total_tokens += num_tokens_from_string(content) + elif isinstance(content, list): + # Handle structured content (tool calls, etc.) + for item in content: + if isinstance(item, dict): + total_tokens += num_tokens_from_string(str(item)) + return total_tokens + + @staticmethod + def count_query_tokens( + queries: List[Dict[str, Any]], include_tool_calls: bool = True + ) -> int: + """ + Count tokens across multiple query objects. 
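+
+        Illustrative call (a sketch; the dict mirrors the query fields read below,
+        with made-up values):
+
+            TokenCounter.count_query_tokens([
+                {
+                    "prompt": "What changed in the last release?",
+                    "response": "Mostly compression-related fixes.",
+                    "tool_calls": [],
+                }
+            ])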
+ + Args: + queries: List of query objects from conversation + include_tool_calls: Whether to count tool call tokens + + Returns: + Total token count + """ + total_tokens = 0 + + for query in queries: + # Count prompt and response tokens + if "prompt" in query: + total_tokens += num_tokens_from_string(query["prompt"]) + if "response" in query: + total_tokens += num_tokens_from_string(query["response"]) + if "thought" in query: + total_tokens += num_tokens_from_string(query.get("thought", "")) + + # Count tool call tokens + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + tool_call_string = ( + f"Tool: {tool_call.get('tool_name')} | " + f"Action: {tool_call.get('action_name')} | " + f"Args: {tool_call.get('arguments')} | " + f"Response: {tool_call.get('result')}" + ) + total_tokens += num_tokens_from_string(tool_call_string) + + return total_tokens + + @staticmethod + def count_conversation_tokens( + conversation: Dict[str, Any], include_system_prompt: bool = False + ) -> int: + """ + Calculate total tokens in a conversation. + + Args: + conversation: Conversation document + include_system_prompt: Whether to include system prompt in count + + Returns: + Total token count + """ + try: + queries = conversation.get("queries", []) + total_tokens = TokenCounter.count_query_tokens(queries) + + # Add system prompt tokens if requested + if include_system_prompt: + # Rough estimate for system prompt + total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500) + + return total_tokens + + except Exception as e: + logger.error(f"Error calculating conversation tokens: {str(e)}") + return 0 diff --git a/application/api/answer/services/compression/types.py b/application/api/answer/services/compression/types.py new file mode 100644 index 00000000..b71ab9ee --- /dev/null +++ b/application/api/answer/services/compression/types.py @@ -0,0 +1,83 @@ +"""Type definitions for compression module.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional + + +@dataclass +class CompressionMetadata: + """Metadata about a compression operation.""" + + timestamp: datetime + query_index: int + compressed_summary: str + original_token_count: int + compressed_token_count: int + compression_ratio: float + model_used: str + compression_prompt_version: str + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for DB storage.""" + return { + "timestamp": self.timestamp, + "query_index": self.query_index, + "compressed_summary": self.compressed_summary, + "original_token_count": self.original_token_count, + "compressed_token_count": self.compressed_token_count, + "compression_ratio": self.compression_ratio, + "model_used": self.model_used, + "compression_prompt_version": self.compression_prompt_version, + } + + +@dataclass +class CompressionResult: + """Result of a compression operation.""" + + success: bool + compressed_summary: Optional[str] = None + recent_queries: List[Dict[str, Any]] = field(default_factory=list) + metadata: Optional[CompressionMetadata] = None + error: Optional[str] = None + compression_performed: bool = False + + @classmethod + def success_with_compression( + cls, summary: str, queries: List[Dict], metadata: CompressionMetadata + ) -> "CompressionResult": + """Create a successful result with compression.""" + return cls( + success=True, + compressed_summary=summary, + recent_queries=queries, + metadata=metadata, + compression_performed=True, + ) + + @classmethod + def 
success_no_compression(cls, queries: List[Dict]) -> "CompressionResult": + """Create a successful result without compression needed.""" + return cls( + success=True, + recent_queries=queries, + compression_performed=False, + ) + + @classmethod + def failure(cls, error: str) -> "CompressionResult": + """Create a failure result.""" + return cls(success=False, error=error, compression_performed=False) + + def as_history(self) -> List[Dict[str, str]]: + """ + Convert recent queries to history format. + + Returns: + List of prompt/response dicts + """ + return [ + {"prompt": q["prompt"], "response": q["response"]} + for q in self.recent_queries + ] diff --git a/application/api/answer/services/conversation_service.py b/application/api/answer/services/conversation_service.py index eca842d6..09baf3d1 100644 --- a/application/api/answer/services/conversation_service.py +++ b/application/api/answer/services/conversation_service.py @@ -52,7 +52,7 @@ class ConversationService: sources: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]], llm: Any, - gpt_model: str, + model_id: str, decoded_token: Dict[str, Any], index: Optional[int] = None, api_key: Optional[str] = None, @@ -62,11 +62,13 @@ class ConversationService: attachment_ids: Optional[List[str]] = None, ) -> str: """Save or update a conversation in the database""" + if decoded_token is None: + raise ValueError("Invalid or missing authentication token") user_id = decoded_token.get("sub") if not user_id: raise ValueError("User ID not found in token") current_time = datetime.now(timezone.utc) - + # clean up in sources array such that we save max 1k characters for text part for source in sources: if "text" in source and isinstance(source["text"], str): @@ -90,6 +92,7 @@ class ConversationService: f"queries.{index}.tool_calls": tool_calls, f"queries.{index}.timestamp": current_time, f"queries.{index}.attachments": attachment_ids, + f"queries.{index}.model_id": model_id, } }, ) @@ -120,6 +123,7 @@ class ConversationService: "tool_calls": tool_calls, "timestamp": current_time, "attachments": attachment_ids, + "model_id": model_id, } } }, @@ -146,9 +150,12 @@ class ConversationService: ] completion = llm.gen( - model=gpt_model, messages=messages_summary, max_tokens=30 + model=model_id, messages=messages_summary, max_tokens=500 ) + if not completion or not completion.strip(): + completion = question[:50] if question else "New Conversation" + conversation_data = { "user": user_id, "date": current_time, @@ -162,6 +169,7 @@ class ConversationService: "tool_calls": tool_calls, "timestamp": current_time, "attachments": attachment_ids, + "model_id": model_id, } ], } @@ -177,3 +185,103 @@ class ConversationService: conversation_data["api_key"] = agent["key"] result = self.conversations_collection.insert_one(conversation_data) return str(result.inserted_id) + + def update_compression_metadata( + self, conversation_id: str, compression_metadata: Dict[str, Any] + ) -> None: + """ + Update conversation with compression metadata. + + Uses $push with $slice to keep only the most recent compression points, + preventing unbounded array growth. Since each compression incorporates + previous compressions, older points become redundant. 
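+
+        Illustrative resulting shape (a sketch of the stored fields; values are
+        made up and the list length is capped by COMPRESSION_MAX_HISTORY_POINTS):
+
+            {
+                "compression_metadata": {
+                    "is_compressed": True,
+                    "last_compression_at": <timestamp>,
+                    "compression_points": [
+                        {"query_index": 5, "compressed_summary": "...", ...}
+                    ],
+                }
+            }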
+ + Args: + conversation_id: Conversation ID + compression_metadata: Compression point data + """ + try: + self.conversations_collection.update_one( + {"_id": ObjectId(conversation_id)}, + { + "$set": { + "compression_metadata.is_compressed": True, + "compression_metadata.last_compression_at": compression_metadata.get( + "timestamp" + ), + }, + "$push": { + "compression_metadata.compression_points": { + "$each": [compression_metadata], + "$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS, + } + }, + }, + ) + logger.info( + f"Updated compression metadata for conversation {conversation_id}" + ) + except Exception as e: + logger.error( + f"Error updating compression metadata: {str(e)}", exc_info=True + ) + raise + + def append_compression_message( + self, conversation_id: str, compression_metadata: Dict[str, Any] + ) -> None: + """ + Append a synthetic compression summary entry into the conversation history. + This makes the summary visible in the DB alongside normal queries. + """ + try: + summary = compression_metadata.get("compressed_summary", "") + if not summary: + return + timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc)) + + self.conversations_collection.update_one( + {"_id": ObjectId(conversation_id)}, + { + "$push": { + "queries": { + "prompt": "[Context Compression Summary]", + "response": summary, + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": timestamp, + "attachments": [], + "model_id": compression_metadata.get("model_used"), + } + } + }, + ) + logger.info(f"Appended compression summary to conversation {conversation_id}") + except Exception as e: + logger.error( + f"Error appending compression summary: {str(e)}", exc_info=True + ) + + def get_compression_metadata( + self, conversation_id: str + ) -> Optional[Dict[str, Any]]: + """ + Get compression metadata for a conversation. 
+ + Args: + conversation_id: Conversation ID + + Returns: + Compression metadata dict or None + """ + try: + conversation = self.conversations_collection.find_one( + {"_id": ObjectId(conversation_id)}, {"compression_metadata": 1} + ) + return conversation.get("compression_metadata") if conversation else None + except Exception as e: + logger.error( + f"Error getting compression metadata: {str(e)}", exc_info=True + ) + return None diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index bb890937..0cf47094 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -10,14 +10,21 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from application.agents.agent_creator import AgentCreator +from application.api.answer.services.compression import CompressionOrchestrator +from application.api.answer.services.compression.token_counter import TokenCounter from application.api.answer.services.conversation_service import ConversationService from application.api.answer.services.prompt_renderer import PromptRenderer +from application.core.model_utils import ( + get_api_key_for_provider, + get_default_model_id, + get_provider_from_model_id, + validate_model_id, +) from application.core.mongo_db import MongoDB from application.core.settings import settings from application.retriever.retriever_creator import RetrieverCreator from application.utils import ( calculate_doc_token_budget, - get_gpt_model, limit_chat_history, ) @@ -83,18 +90,24 @@ class StreamProcessor: self.retriever_config = {} self.is_shared_usage = False self.shared_token = None - self.gpt_model = get_gpt_model() + self.agent_id = self.data.get("agent_id") + self.model_id: Optional[str] = None self.conversation_service = ConversationService() + self.compression_orchestrator = CompressionOrchestrator( + self.conversation_service + ) self.prompt_renderer = PromptRenderer() self._prompt_content: Optional[str] = None self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None + self.compressed_summary: Optional[str] = None + self.compressed_summary_tokens: int = 0 def initialize(self): """Initialize all required components for processing""" self._configure_agent() + self._validate_and_set_model() self._configure_source() self._configure_retriever() - self._configure_agent() self._load_conversation_history() self._process_attachments() @@ -106,14 +119,69 @@ class StreamProcessor: ) if not conversation: raise ValueError("Conversation not found or unauthorized") + + # Check if compression is enabled and needed + if settings.ENABLE_CONVERSATION_COMPRESSION: + self._handle_compression(conversation) + else: + # Original behavior - load all history + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] + else: + self.history = limit_chat_history( + json.loads(self.data.get("history", "[]")), model_id=self.model_id + ) + + def _handle_compression(self, conversation: Dict[str, Any]): + """ + Handle conversation compression logic using orchestrator. 
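+
+        Illustrative outcomes (a sketch mirroring the branches below):
+
+            result = self.compression_orchestrator.compress_if_needed(...)
+            # result.success is False       -> fall back to the full history
+            # result.compression_performed  -> compressed summary + recent messages
+            # otherwise                     -> recent messages only, no summary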
+ + Args: + conversation: Full conversation document + """ + try: + # Use orchestrator to handle all compression logic + result = self.compression_orchestrator.compress_if_needed( + conversation_id=self.conversation_id, + user_id=self.initial_user_id, + model_id=self.model_id, + decoded_token=self.decoded_token, + ) + + if not result.success: + logger.error(f"Compression failed: {result.error}, using full history") + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] + return + + # Set compressed summary if compression was performed + if result.compression_performed and result.compressed_summary: + self.compressed_summary = result.compressed_summary + self.compressed_summary_tokens = TokenCounter.count_message_tokens( + [{"content": result.compressed_summary}] + ) + logger.info( + f"Using compressed summary ({self.compressed_summary_tokens} tokens) " + f"+ {len(result.recent_queries)} recent messages" + ) + + # Build history from recent queries + self.history = result.as_history() + + except Exception as e: + logger.error( + f"Error handling compression, falling back to standard history: {str(e)}", + exc_info=True, + ) + # Fallback to original behavior self.history = [ {"prompt": query["prompt"], "response": query["response"]} for query in conversation.get("queries", []) ] - else: - self.history = limit_chat_history( - json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model - ) def _process_attachments(self): """Process any attachments in the request""" @@ -143,6 +211,34 @@ class StreamProcessor: ) return attachments + def _validate_and_set_model(self): + """Validate and set model_id from request""" + from application.core.model_settings import ModelRegistry + + requested_model = self.data.get("model_id") + + if requested_model: + if not validate_model_id(requested_model): + registry = ModelRegistry.get_instance() + available_models = [m.id for m in registry.get_enabled_models()] + raise ValueError( + f"Invalid model_id '{requested_model}'. 
" + f"Available models: {', '.join(available_models[:5])}" + + ( + f" and {len(available_models) - 5} more" + if len(available_models) > 5 + else "" + ) + ) + self.model_id = requested_model + else: + # Check if agent has a default model configured + agent_default_model = self.agent_config.get("default_model_id", "") + if agent_default_model and validate_model_id(agent_default_model): + self.model_id = agent_default_model + else: + self.model_id = get_default_model_id() + def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple: """Get API key for agent with access control""" if not agent_id: @@ -214,6 +310,10 @@ class StreamProcessor: data["sources"] = sources_list else: data["sources"] = [] + + # Preserve model configuration from agent + data["default_model_id"] = data.get("default_model_id", "") + return data def _configure_source(self): @@ -250,28 +350,61 @@ class StreamProcessor: self.source = {} self.all_sources = [] + def _resolve_agent_id(self) -> Optional[str]: + """Resolve agent_id from request, then fall back to conversation context.""" + request_agent_id = self.data.get("agent_id") + if request_agent_id: + return str(request_agent_id) + + if not self.conversation_id or not self.initial_user_id: + return None + + try: + conversation = self.conversation_service.get_conversation( + self.conversation_id, self.initial_user_id + ) + except Exception: + return None + + if not conversation: + return None + + conversation_agent_id = conversation.get("agent_id") + if conversation_agent_id: + return str(conversation_agent_id) + + return None + def _configure_agent(self): """Configure the agent based on request data""" - agent_id = self.data.get("agent_id") + agent_id = self._resolve_agent_id() + self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key( agent_id, self.initial_user_id ) + self.agent_id = str(agent_id) if agent_id else None api_key = self.data.get("api_key") if api_key: data_key = self._get_data_from_api_key(api_key) + if data_key.get("_id"): + self.agent_id = str(data_key.get("_id")) self.agent_config.update( { "prompt_id": data_key.get("prompt_id", "default"), "agent_type": data_key.get("agent_type", settings.AGENT_NAME), "user_api_key": api_key, "json_schema": data_key.get("json_schema"), + "default_model_id": data_key.get("default_model_id", ""), } ) self.initial_user_id = data_key.get("user") self.decoded_token = {"sub": data_key.get("user")} if data_key.get("source"): self.source = {"active_docs": data_key["source"]} + if data_key.get("workflow"): + self.agent_config["workflow"] = data_key["workflow"] + self.agent_config["workflow_owner"] = data_key.get("user") if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: @@ -284,12 +417,15 @@ class StreamProcessor: self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) + if data_key.get("_id"): + self.agent_id = str(data_key.get("_id")) self.agent_config.update( { "prompt_id": data_key.get("prompt_id", "default"), "agent_type": data_key.get("agent_type", settings.AGENT_NAME), "user_api_key": self.agent_key, "json_schema": data_key.get("json_schema"), + "default_model_id": data_key.get("default_model_id", ""), } ) self.decoded_token = ( @@ -299,6 +435,9 @@ class StreamProcessor: ) if data_key.get("source"): self.source = {"active_docs": data_key["source"]} + if data_key.get("workflow"): + self.agent_config["workflow"] = data_key["workflow"] + 
self.agent_config["workflow_owner"] = data_key.get("user") if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: @@ -310,26 +449,32 @@ class StreamProcessor: ) self.retriever_config["chunks"] = 2 else: + agent_type = settings.AGENT_NAME + if self.data.get("workflow") and isinstance( + self.data.get("workflow"), dict + ): + agent_type = "workflow" + self.agent_config["workflow"] = self.data["workflow"] + if isinstance(self.decoded_token, dict): + self.agent_config["workflow_owner"] = self.decoded_token.get("sub") + self.agent_config.update( { "prompt_id": self.data.get("prompt_id", "default"), - "agent_type": settings.AGENT_NAME, + "agent_type": agent_type, "user_api_key": None, "json_schema": None, + "default_model_id": "", } ) def _configure_retriever(self): - history_token_limit = int(self.data.get("token_limit", 2000)) - doc_token_limit = calculate_doc_token_budget( - gpt_model=self.gpt_model, history_token_limit=history_token_limit - ) + doc_token_limit = calculate_doc_token_budget(model_id=self.model_id) self.retriever_config = { "retriever_name": self.data.get("retriever", "classic"), "chunks": int(self.data.get("chunks", 2)), "doc_token_limit": doc_token_limit, - "history_token_limit": history_token_limit, } api_key = self.data.get("api_key") or self.agent_key @@ -344,14 +489,15 @@ class StreamProcessor: prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection), chunks=self.retriever_config["chunks"], doc_token_limit=self.retriever_config.get("doc_token_limit", 50000), - gpt_model=self.gpt_model, + model_id=self.model_id, user_api_key=self.agent_config["user_api_key"], + agent_id=self.agent_id, decoded_token=self.decoded_token, ) def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]: """Pre-fetch documents for template rendering before agent creation""" - if self.data.get("isNoneDoc", False): + if self.data.get("isNoneDoc", False) and not self.agent_id: logger.info("Pre-fetch skipped: isNoneDoc=True") return None, None try: @@ -626,17 +772,46 @@ class StreamProcessor: tools_data=tools_data, ) - return AgentCreator.create_agent( - self.agent_config["agent_type"], - endpoint="stream", - llm_name=settings.LLM_PROVIDER, - gpt_model=self.gpt_model, - api_key=settings.API_KEY, - user_api_key=self.agent_config["user_api_key"], - prompt=rendered_prompt, - chat_history=self.history, - retrieved_docs=self.retrieved_docs, - decoded_token=self.decoded_token, - attachments=self.attachments, - json_schema=self.agent_config.get("json_schema"), + provider = ( + get_provider_from_model_id(self.model_id) + if self.model_id + else settings.LLM_PROVIDER ) + system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER) + + agent_type = self.agent_config["agent_type"] + + # Base agent kwargs + agent_kwargs = { + "endpoint": "stream", + "llm_name": provider or settings.LLM_PROVIDER, + "model_id": self.model_id, + "api_key": system_api_key, + "agent_id": self.agent_id, + "user_api_key": self.agent_config["user_api_key"], + "prompt": rendered_prompt, + "chat_history": self.history, + "retrieved_docs": self.retrieved_docs, + "decoded_token": self.decoded_token, + "attachments": self.attachments, + "json_schema": self.agent_config.get("json_schema"), + "compressed_summary": self.compressed_summary, + } + + # Workflow-specific kwargs for workflow agents + if agent_type == "workflow": + workflow_config = self.agent_config.get("workflow") + if isinstance(workflow_config, str): 
+ agent_kwargs["workflow_id"] = workflow_config + elif isinstance(workflow_config, dict): + agent_kwargs["workflow"] = workflow_config + workflow_owner = self.agent_config.get("workflow_owner") + if workflow_owner: + agent_kwargs["workflow_owner"] = workflow_owner + + agent = AgentCreator.create_agent(agent_type, **agent_kwargs) + + agent.conversation_id = self.conversation_id + agent.initial_user_id = self.initial_user_id + + return agent diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index d8efc1d3..913e5349 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -1,7 +1,9 @@ import base64 import datetime +import html import json import uuid +from urllib.parse import urlencode from bson.objectid import ObjectId @@ -35,6 +37,18 @@ connector = Blueprint("connector", __name__) connectors_ns = Namespace("connectors", description="Connector operations", path="/") api.add_namespace(connectors_ns) +# Fixed callback status path to prevent open redirect +CALLBACK_STATUS_PATH = "/api/connectors/callback-status" + + +def build_callback_redirect(params: dict) -> str: + """Build a safe redirect URL to the callback status page. + + Uses a fixed path and properly URL-encodes all parameters + to prevent URL injection and open redirect vulnerabilities. + """ + return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}" + @connectors_ns.route("/api/connectors/auth") @@ -75,8 +89,8 @@ class ConnectorAuth(Resource): "state": state }), 200) except Exception as e: - current_app.logger.error(f"Error generating connector auth URL: {e}") - return make_response(jsonify({"success": False, "error": str(e)}), 500) + current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True) + return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500) @connectors_ns.route("/api/connectors/callback") @@ -93,18 +107,37 @@ class ConnectorsCallback(Resource): error = request.args.get('error') state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode()) - provider = state_dict["provider"] - state_object_id = state_dict["object_id"] + provider = state_dict.get("provider") + state_object_id = state_dict.get("object_id") + + # Validate provider + if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider): + return redirect(build_callback_redirect({ + "status": "error", + "message": "Invalid provider" + })) if error: if error == "access_denied": - return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}") + return redirect(build_callback_redirect({ + "status": "cancelled", + "message": "Authentication was cancelled. You can try again if you'd like to connect your account.", + "provider": provider + })) else: current_app.logger.warning(f"OAuth error in callback: {error}") - return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") + return redirect(build_callback_redirect({ + "status": "error", + "message": "Authentication failed. 
Please try again and make sure to grant all requested permissions.", + "provider": provider + })) if not authorization_code: - return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") + return redirect(build_callback_redirect({ + "status": "error", + "message": "Authentication failed. Please try again and make sure to grant all requested permissions.", + "provider": provider + })) try: auth = ConnectorCreator.create_auth(provider) @@ -113,20 +146,19 @@ class ConnectorsCallback(Resource): session_token = str(uuid.uuid4()) try: - credentials = auth.create_credentials_from_token_info(token_info) - service = auth.build_drive_service(credentials) - user_info = service.about().get(fields="user").execute() - user_email = user_info.get('user', {}).get('emailAddress', 'Connected User') + if provider == "google_drive": + credentials = auth.create_credentials_from_token_info(token_info) + service = auth.build_drive_service(credentials) + user_info = service.about().get(fields="user").execute() + user_email = user_info.get('user', {}).get('emailAddress', 'Connected User') + else: + user_email = token_info.get('user_info', {}).get('email', 'Connected User') + except Exception as e: current_app.logger.warning(f"Could not get user info: {e}") user_email = 'Connected User' - sanitized_token_info = { - "access_token": token_info.get("access_token"), - "refresh_token": token_info.get("refresh_token"), - "token_uri": token_info.get("token_uri"), - "expiry": token_info.get("expiry") - } + sanitized_token_info = auth.sanitize_token_info(token_info) sessions_collection.find_one_and_update( {"_id": ObjectId(state_object_id), "provider": provider}, @@ -141,26 +173,39 @@ class ConnectorsCallback(Resource): ) # Redirect to success page with session token and user email - return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}") + return redirect(build_callback_redirect({ + "status": "success", + "message": "Authentication successful", + "provider": provider, + "session_token": session_token, + "user_email": user_email + })) except Exception as e: current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True) - return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}") + return redirect(build_callback_redirect({ + "status": "error", + "message": "Authentication failed. Please try again and make sure to grant all requested permissions.", + "provider": provider + })) except Exception as e: current_app.logger.error(f"Error handling connector callback: {e}") - return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.") + return redirect(build_callback_redirect({ + "status": "error", + "message": "Authentication failed. Please try again and make sure to grant all requested permissions." 
+ })) @connectors_ns.route("/api/connectors/files") class ConnectorFiles(Resource): @api.expect(api.model("ConnectorFilesModel", { - "provider": fields.String(required=True), - "session_token": fields.String(required=True), - "folder_id": fields.String(required=False), - "limit": fields.Integer(required=False), + "provider": fields.String(required=True), + "session_token": fields.String(required=True), + "folder_id": fields.String(required=False), + "limit": fields.Integer(required=False), "page_token": fields.String(required=False), - "search_query": fields.String(required=False) + "search_query": fields.String(required=False), })) @api.doc(description="List files from a connector provider (supports pagination and search)") def post(self): @@ -168,11 +213,8 @@ class ConnectorFiles(Resource): data = request.get_json() provider = data.get('provider') session_token = data.get('session_token') - folder_id = data.get('folder_id') limit = data.get('limit', 10) - page_token = data.get('page_token') - search_query = data.get('search_query') - + if not provider or not session_token: return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400) @@ -185,15 +227,12 @@ class ConnectorFiles(Resource): return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401) loader = ConnectorCreator.create_connector(provider, session_token) + + generic_keys = {'provider', 'session_token'} input_config = { - 'limit': limit, - 'list_only': True, - 'session_token': session_token, - 'folder_id': folder_id, - 'page_token': page_token + k: v for k, v in data.items() if k not in generic_keys } - if search_query: - input_config['search_query'] = search_query + input_config['list_only'] = True documents = loader.load_data(input_config) @@ -228,8 +267,8 @@ class ConnectorFiles(Resource): "has_more": has_more }), 200) except Exception as e: - current_app.logger.error(f"Error loading connector files: {e}") - return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500) + current_app.logger.error(f"Error loading connector files: {e}", exc_info=True) + return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500) @connectors_ns.route("/api/connectors/validate-session") @@ -260,12 +299,7 @@ class ConnectorValidateSession(Resource): if is_expired and token_info.get('refresh_token'): try: refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token')) - sanitized_token_info = { - "access_token": refreshed_token_info.get("access_token"), - "refresh_token": refreshed_token_info.get("refresh_token"), - "token_uri": refreshed_token_info.get("token_uri"), - "expiry": refreshed_token_info.get("expiry") - } + sanitized_token_info = auth.sanitize_token_info(refreshed_token_info) sessions_collection.update_one( {"session_token": session_token}, {"$set": {"token_info": sanitized_token_info}} @@ -282,15 +316,21 @@ class ConnectorValidateSession(Resource): "error": "Session token has expired. Please reconnect." 
}), 401) - return make_response(jsonify({ + _base_fields = {"access_token", "refresh_token", "token_uri", "expiry"} + provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields} + + response_data = { "success": True, "expired": False, "user_email": session.get('user_email', 'Connected User'), - "access_token": token_info.get('access_token') - }), 200) + "access_token": token_info.get('access_token'), + **provider_extras, + } + + return make_response(jsonify(response_data), 200) except Exception as e: - current_app.logger.error(f"Error validating connector session: {e}") - return make_response(jsonify({"success": False, "error": str(e)}), 500) + current_app.logger.error(f"Error validating connector session: {e}", exc_info=True) + return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500) @connectors_ns.route("/api/connectors/disconnect") @@ -311,8 +351,8 @@ class ConnectorDisconnect(Resource): return make_response(jsonify({"success": True}), 200) except Exception as e: - current_app.logger.error(f"Error disconnecting connector session: {e}") - return make_response(jsonify({"success": False, "error": str(e)}), 500) + current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True) + return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500) @connectors_ns.route("/api/connectors/sync") @@ -418,8 +458,8 @@ class ConnectorSync(Resource): return make_response( jsonify({ "success": False, - "error": str(err) - }), + "error": "Failed to sync connector source" + }), 400 ) @@ -430,17 +470,32 @@ class ConnectorCallbackStatus(Resource): def get(self): """Return HTML page with connector authentication status""" try: - status = request.args.get('status', 'error') - message = request.args.get('message', '') - provider = request.args.get('provider', 'connector') + # Validate and sanitize status to a known value + status_raw = request.args.get('status', 'error') + status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error' + + # Escape all user-controlled values for HTML context + message = html.escape(request.args.get('message', '')) + provider_raw = request.args.get('provider', 'connector') + provider = html.escape(provider_raw.replace('_', ' ').title()) session_token = request.args.get('session_token', '') - user_email = request.args.get('user_email', '') - + user_email = html.escape(request.args.get('user_email', '')) + + def safe_js_string(value: str) -> str: + """Safely encode a string for embedding in inline JavaScript.""" + js_encoded = json.dumps(value) + return js_encoded.replace(' - {provider.replace('_', ' ').title()} Authentication + {provider} Authentication