mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
33 Commits
chore/bump
...
coverage-3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
727495c553 | ||
|
|
a3b08a5b44 | ||
|
|
d5c0322e2a | ||
|
|
dc6db847ca | ||
|
|
ed0063aada | ||
|
|
3f6d6f15ea | ||
|
|
126fa01b14 | ||
|
|
e06debad5f | ||
|
|
6492852f7d | ||
|
|
00a621f33a | ||
|
|
e92ffc6fdc | ||
|
|
fe185e5b8d | ||
|
|
9f3d9ab860 | ||
|
|
1c0adde380 | ||
|
|
3c56bd0d0b | ||
|
|
86664ebda2 | ||
|
|
db18b743d1 | ||
|
|
9e85cc9065 | ||
|
|
aaaa6f002d | ||
|
|
47dcbcb74b | ||
|
|
ddbfd94193 | ||
|
|
8dec60ab8b | ||
|
|
84b2e4bab4 | ||
|
|
2afdd7f026 | ||
|
|
f364475f64 | ||
|
|
b254de6ed6 | ||
|
|
08dedcaf95 | ||
|
|
c726eb8ebd | ||
|
|
5f0d39e5f1 | ||
|
|
8c82fc5495 | ||
|
|
6d81a15e97 | ||
|
|
5478e4234c | ||
|
|
eaf39bb15b |
@@ -3,6 +3,14 @@ LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Provider-specific API keys (optional - use these to enable multiple providers)
|
||||
# OPENAI_API_KEY=<your-openai-api-key>
|
||||
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
|
||||
# GOOGLE_API_KEY=<your-google-api-key>
|
||||
# GROQ_API_KEY=<your-groq-api-key>
|
||||
# NOVITA_API_KEY=<your-novita-api-key>
|
||||
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
<strong>Key Features:</strong>
|
||||
|
||||
@@ -185,7 +185,10 @@ class ToolExecutor:
|
||||
target_dict[param] = value
|
||||
|
||||
# Load tool (with caching)
|
||||
tool = self._get_or_load_tool(tool_data, tool_id, action_name)
|
||||
tool = self._get_or_load_tool(
|
||||
tool_data, tool_id, action_name,
|
||||
headers=headers, query_params=query_params,
|
||||
)
|
||||
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
@@ -238,7 +241,10 @@ class ToolExecutor:
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_or_load_tool(self, tool_data: Dict, tool_id: str, action_name: str):
|
||||
def _get_or_load_tool(
|
||||
self, tool_data: Dict, tool_id: str, action_name: str,
|
||||
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
|
||||
):
|
||||
"""Load a tool, using cache when possible."""
|
||||
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
|
||||
if cache_key in self._loaded_tools:
|
||||
@@ -251,8 +257,8 @@ class ToolExecutor:
|
||||
tool_config = {
|
||||
"url": action_config["url"],
|
||||
"method": action_config["method"],
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"headers": headers or {},
|
||||
"query_params": query_params or {},
|
||||
}
|
||||
if "body_content_type" in action_config:
|
||||
tool_config["body_content_type"] = action_config.get(
|
||||
|
||||
@@ -27,6 +27,8 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
@@ -193,6 +195,46 @@ OPENROUTER_MODELS = [
|
||||
),
|
||||
]
|
||||
|
||||
NOVITA_MODELS = [
|
||||
AvailableModel(
|
||||
id="moonshotai/kimi-k2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="Kimi K2.5",
|
||||
description="MoE model with function calling, structured output, reasoning, and vision",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=NOVITA_ATTACHMENTS,
|
||||
context_window=262144,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="zai-org/glm-5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="GLM-5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=202800,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="minimax/minimax-m2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="MiniMax M2.5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=204800,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
|
||||
@@ -114,6 +114,10 @@ class ModelRegistry:
|
||||
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
|
||||
):
|
||||
self._add_openrouter_models(settings)
|
||||
if settings.NOVITA_API_KEY or (
|
||||
settings.LLM_PROVIDER == "novita" and settings.API_KEY
|
||||
):
|
||||
self._add_novita_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
@@ -245,6 +249,21 @@ class ModelRegistry:
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_novita_models(self, settings):
|
||||
from application.core.model_configs import NOVITA_MODELS
|
||||
|
||||
if settings.NOVITA_API_KEY:
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
|
||||
for model in NOVITA_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
|
||||
@@ -10,6 +10,7 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"openrouter": settings.OPEN_ROUTER_API_KEY,
|
||||
"novita": settings.NOVITA_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
|
||||
@@ -5,9 +5,7 @@ from typing import Optional
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -15,15 +13,11 @@ class Settings(BaseSettings):
|
||||
|
||||
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||
LLM_PROVIDER: str = "docsgpt"
|
||||
LLM_NAME: Optional[str] = (
|
||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
)
|
||||
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
|
||||
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
@@ -45,9 +39,7 @@ class Settings(BaseSettings):
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
VECTOR_STORE: str = (
|
||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
)
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
|
||||
@@ -55,12 +47,8 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client ID
|
||||
)
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client secret
|
||||
)
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
@@ -72,7 +60,7 @@ class Settings(BaseSettings):
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
@@ -90,16 +78,13 @@ class Settings(BaseSettings):
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
OPEN_ROUTER_API_KEY: Optional[str] = None
|
||||
NOVITA_API_KEY: Optional[str] = None
|
||||
|
||||
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
|
||||
None # azure deployment name for embeddings
|
||||
)
|
||||
OPENAI_BASE_URL: Optional[str] = (
|
||||
None # openai base url for open ai compatable models
|
||||
)
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
|
||||
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
|
||||
|
||||
# elasticsearch
|
||||
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
|
||||
@@ -141,9 +126,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# LanceDB vectorstore config
|
||||
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
|
||||
LANCEDB_TABLE_NAME: Optional[str] = (
|
||||
"docsgpts" # Name of the table to use for storing vectors
|
||||
)
|
||||
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
|
||||
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
STORAGE_TYPE: str = "local" # local or s3
|
||||
@@ -180,6 +163,7 @@ class Settings(BaseSettings):
|
||||
"GOOGLE_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"NOVITA_API_KEY",
|
||||
"EMBEDDINGS_KEY",
|
||||
"FALLBACK_LLM_API_KEY",
|
||||
"QDRANT_API_KEY",
|
||||
|
||||
@@ -7,6 +7,7 @@ class LLMHandlerCreator:
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler,
|
||||
"google": GoogleLLMHandler,
|
||||
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
|
||||
"default": OpenAILLMHandler,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
|
||||
NOVITA_BASE_URL = "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
class NovitaLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.API_KEY,
|
||||
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or NOVITA_BASE_URL,
|
||||
*args,
|
||||
|
||||
@@ -44,36 +44,40 @@ The main set of instructions or system [prompt](/Guides/Customising-prompts) tha
|
||||
|
||||
## Understanding Agent Types
|
||||
|
||||
DocsGPT allows for different "types" of agents, each with a distinct way of processing information and generating responses. The code for these agent types can be found in the `application/agents/` directory.
|
||||
DocsGPT supports several agent types, each with a distinct way of processing information. The code for these can be found in the `application/agents/` directory.
|
||||
|
||||
### 1. Classic Agent (`classic_agent.py`)
|
||||
### 1. Classic Agent
|
||||
|
||||
**How it works:** The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach.
|
||||
1. **Retrieve:** When a query is made, it first searches the selected Source documents for relevant information.
|
||||
2. **Augment:** This retrieved data is then added to the context, along with the main Prompt and the user's query.
|
||||
3. **Generate:** The LLM generates a response based on this augmented context. It can also utilize any configured tools if the LLM decides they are necessary.
|
||||
The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach: it retrieves relevant document chunks, augments the prompt context with them, and generates a response. It can also use configured tools if the LLM decides they are necessary.
|
||||
|
||||
**Best for:**
|
||||
* Direct question-answering over a specific set of documents.
|
||||
* Tasks where the primary goal is to extract and synthesize information from the provided sources.
|
||||
* Simpler tool integrations where the decision to use a tool is straightforward.
|
||||
**Best for:** Direct question-answering over a specific set of documents and straightforward tool use.
|
||||
|
||||
### 2. ReAct Agent (`react_agent.py`)
|
||||
### 2. Agentic Agent
|
||||
|
||||
**How it works:** The ReAct Agent employs a more sophisticated "Reason and Act" framework. This involves a multi-step process:
|
||||
1. **Plan (Thought):** Based on the query, its prompt, and available tools/sources, the LLM first generates a plan or a sequence of thoughts on how to approach the problem. You might see this output as a "thought" process during generation.
|
||||
2. **Act:** The agent then executes actions based on this plan. This might involve querying its sources, using a tool, or performing internal reasoning.
|
||||
3. **Observe:** It gathers observations from the results of its actions (e.g., data from a tool, snippets from documents).
|
||||
4. **Repeat (if necessary):** Steps 2 and 3 can be repeated as the agent refines its approach or gathers more information.
|
||||
5. **Conclude:** Finally, it generates the final answer based on the initial query and all accumulated observations.
|
||||
Unlike Classic which pre-fetches documents into the prompt, the Agentic Agent gives the LLM an `internal_search` tool so it can decide **when, what, and whether** to search. This means the LLM controls its own retrieval — it can search multiple times, refine queries, or skip retrieval entirely if the question doesn't need it.
|
||||
|
||||
**Best for:**
|
||||
* More complex tasks that require multi-step reasoning or problem-solving.
|
||||
* Scenarios where the agent needs to dynamically decide which tools to use and in what order, based on intermediate results.
|
||||
* Interactive tasks where the agent needs to "think" through a problem.
|
||||
**Best for:** Tasks where the agent needs to dynamically decide how to gather information, use multiple tools in sequence, or combine retrieval with external tool calls.
|
||||
|
||||
### 3. Research Agent
|
||||
|
||||
A multi-phase agent designed for in-depth research tasks:
|
||||
1. **Clarification** — Determines if the question needs clarification before proceeding.
|
||||
2. **Planning** — Decomposes the question into research steps with adaptive depth based on complexity.
|
||||
3. **Research** — Executes each step, calling tools and refining queries as needed.
|
||||
4. **Synthesis** — Compiles findings into a final cited report.
|
||||
|
||||
Includes budget controls for max steps, timeout, and token limits to keep research bounded.
|
||||
|
||||
**Best for:** Complex questions that require multi-step investigation, gathering information from multiple sources, and producing structured reports with citations.
|
||||
|
||||
### 4. Workflow Agent
|
||||
|
||||
Executes predefined workflows composed of connected nodes (AI Agent, Set State, Condition). See the [Workflow Nodes](/Agents/nodes) page for details on building workflows.
|
||||
|
||||
**Best for:** Structured, multi-step processes with branching logic and shared state between steps.
|
||||
|
||||
<Callout type="info">
|
||||
Developers looking to introduce new agent architectures can explore the `application/agents/` directory. `classic_agent.py` and `react_agent.py` serve as excellent starting points, demonstrating how to inherit from `BaseAgent` and structure agent logic.
|
||||
The legacy "ReAct" agent type is still accepted for backwards compatibility but maps to the Classic Agent internally. New agents should use Classic, Agentic, or Research instead.
|
||||
</Callout>
|
||||
|
||||
## Navigating and Managing Agents in DocsGPT
|
||||
|
||||
@@ -70,9 +70,9 @@ Inside the DocsGPT folder create a `.env` file and copy the contents of `.env_sa
|
||||
Make sure your `.env` file looks like this:
|
||||
|
||||
```
|
||||
OPENAI_API_KEY=(Your OpenAI API key)
|
||||
API_KEY=<Your LLM API key>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
SELF_HOSTED_MODEL=false
|
||||
```
|
||||
|
||||
To save the file, press CTRL+X, then Y, and then ENTER.
|
||||
|
||||
@@ -104,7 +104,7 @@ DocsGPT can transcribe audio in two places:
|
||||
- Voice input in the chat.
|
||||
- Audio file ingestion. Uploaded `.wav`, `.mp3`, `.m4a`, `.ogg`, and `.webm` files are transcribed first and then passed through the normal parser, chunking, embedding, and indexing pipeline.
|
||||
|
||||
For an end-to-end walkthrough, see the [Speech and Audio Guide](/Guides/speech-and-audio).
|
||||
The settings below control speech-to-text behaviour for both voice input and audio file ingestion.
|
||||
|
||||
| Setting | Purpose | Typical values |
|
||||
| --- | --- | --- |
|
||||
@@ -214,6 +214,31 @@ If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt
|
||||
}}
|
||||
/>
|
||||
|
||||
## S3 Storage Backend
|
||||
|
||||
By default DocsGPT stores files locally. Set `STORAGE_TYPE=s3` to use Amazon S3 instead.
|
||||
|
||||
| Setting | Description | Default |
|
||||
| --- | --- | --- |
|
||||
| `STORAGE_TYPE` | `local` or `s3` | `local` |
|
||||
| `S3_BUCKET_NAME` | S3 bucket name | `docsgpt-test-bucket` |
|
||||
| `SAGEMAKER_ACCESS_KEY` | AWS access key ID | — |
|
||||
| `SAGEMAKER_SECRET_KEY` | AWS secret access key | — |
|
||||
| `SAGEMAKER_REGION` | AWS region | — |
|
||||
| `URL_STRATEGY` | `backend` (proxy through API) or `s3` (direct S3 URLs) | `backend` |
|
||||
|
||||
The S3 credentials use `SAGEMAKER_*` variable names because they are shared with the SageMaker integration.
|
||||
|
||||
```env
|
||||
STORAGE_TYPE=s3
|
||||
S3_BUCKET_NAME=your-bucket-name
|
||||
SAGEMAKER_ACCESS_KEY=your-aws-access-key-id
|
||||
SAGEMAKER_SECRET_KEY=your-aws-secret-access-key
|
||||
SAGEMAKER_REGION=us-east-1
|
||||
```
|
||||
|
||||
Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.
|
||||
|
||||
## Exploring More Settings
|
||||
|
||||
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:
|
||||
|
||||
@@ -86,13 +86,9 @@ Make sure your `.env` file looks like this:
|
||||
|
||||
|
||||
```
|
||||
|
||||
OPENAI_API_KEY=(Your OpenAI API key)
|
||||
|
||||
API_KEY=<Your LLM API key>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
|
||||
SELF_HOSTED_MODEL=false
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -11,18 +11,18 @@ DocsGPT API keys are essential for developers and users who wish to integrate th
|
||||
|
||||
After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:
|
||||
|
||||
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
|
||||
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
|
||||
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the Agents option, and press 'Create New' to generate a new agent (which includes an API key).
|
||||
- **API Call:** Alternatively, you can use the `/api/create_agent` endpoint to create a new agent. An API key is automatically generated for each agent. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
|
||||
|
||||
## Understanding Key Variables
|
||||
|
||||
Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
|
||||
Upon creating your agent, you will encounter several key variables. Each serves a specific purpose:
|
||||
|
||||
- **Name:** Assign a name to your API key for easy identification.
|
||||
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
|
||||
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
|
||||
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
|
||||
- **Name:** Assign a name to your agent for easy identification.
|
||||
- **Source:** Indicates the source document(s) linked to your agent, which DocsGPT will use to generate responses.
|
||||
- **ID:** A unique identifier for your agent. You can view this by making a call to `/api/get_agents`.
|
||||
- **Key:** The API key for the agent, which will be used in your application to authenticate API requests.
|
||||
|
||||
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
|
||||
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the agent, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
|
||||
|
||||
Congratulations on taking the first step towards enhancing your applications with DocsGPT!
|
||||
|
||||
@@ -64,7 +64,7 @@ flowchart LR
|
||||
* **Technology:** Supports multiple vector databases.
|
||||
* **Responsibility:** Vector Stores are used to store and retrieve vector embeddings of document chunks. This enables semantic search and retrieval of relevant document snippets in response to user queries.
|
||||
* **Key Features:**
|
||||
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, and LanceDB.
|
||||
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, MongoDB Atlas Vector Search, and pgvector.
|
||||
* Provides storage and indexing of high-dimensional vector embeddings.
|
||||
* Enables editing and updating of vector indexes including specific chunks.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ Training on other documentation sources can greatly enhance the versatility and
|
||||
Make sure you have the document on which you want to train on ready with you on the device which you are using .You can also use links to the documentation to train on.
|
||||
|
||||
<Callout type="warning" emoji="⚠️">
|
||||
Note: The document should be either of the given file formats .pdf, .txt, .rst, .docx, .md, .zip and limited to 25mb.You can also train using the link of the documentation.
|
||||
Note: Supported file formats include .pdf, .txt, .rst, .docx, .md, .mdx, .csv, .epub, .html, .json, .xlsx, .pptx, .png, .jpg, .jpeg, and audio files (.wav, .mp3, .m4a, .ogg, .webm). You can also train using the link of the documentation.
|
||||
|
||||
</Callout>
|
||||
|
||||
|
||||
@@ -35,8 +35,34 @@ Choose the LLM of your choice.
|
||||
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
|
||||
### Step 2
|
||||
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
|
||||
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
|
||||
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
|
||||
</Steps>
|
||||
|
||||
## Fallback LLM
|
||||
|
||||
DocsGPT can automatically switch to a fallback LLM when the primary model fails, including mid-stream. This works with both streaming and non-streaming requests.
|
||||
|
||||
**Fallback order:**
|
||||
1. Per-agent backup models (other models configured on the same agent)
|
||||
2. Global fallback (`FALLBACK_LLM_*` env vars below)
|
||||
3. Error returned if all fail
|
||||
|
||||
| Setting | Description | Default |
|
||||
| --- | --- | --- |
|
||||
| `FALLBACK_LLM_PROVIDER` | Provider name (e.g., `openai`, `anthropic`, `google`) | — |
|
||||
| `FALLBACK_LLM_NAME` | Model name (e.g., `gpt-4o`, `claude-sonnet-4-20250514`) | — |
|
||||
| `FALLBACK_LLM_API_KEY` | API key for the fallback provider | Falls back to `API_KEY` |
|
||||
|
||||
All three (`FALLBACK_LLM_PROVIDER`, `FALLBACK_LLM_NAME`, and an API key) must resolve for the global fallback to activate.
|
||||
|
||||
```env
|
||||
FALLBACK_LLM_PROVIDER=anthropic
|
||||
FALLBACK_LLM_NAME=claude-sonnet-4-20250514
|
||||
FALLBACK_LLM_API_KEY=sk-ant-your-anthropic-key
|
||||
```
|
||||
|
||||
<Callout type="info">
|
||||
For maximum resilience, use a fallback provider from a different cloud than your primary. Each agent can also have multiple models configured — the other models are tried first before the global fallback.
|
||||
</Callout>
|
||||
|
||||
|
||||
|
||||
@@ -2,5 +2,13 @@ export default {
|
||||
"google-drive-connector": {
|
||||
"title": "🔗 Google Drive",
|
||||
"href": "/Guides/Integrations/google-drive-connector"
|
||||
},
|
||||
"sharepoint-connector": {
|
||||
"title": "🔗 SharePoint / OneDrive",
|
||||
"href": "/Guides/Integrations/sharepoint-connector"
|
||||
},
|
||||
"mcp-tool-integration": {
|
||||
"title": "🔗 MCP Tools",
|
||||
"href": "/Guides/Integrations/mcp-tool-integration"
|
||||
}
|
||||
}
|
||||
|
||||
66
docs/content/Guides/Integrations/mcp-tool-integration.mdx
Normal file
66
docs/content/Guides/Integrations/mcp-tool-integration.mdx
Normal file
@@ -0,0 +1,66 @@
|
||||
---
|
||||
title: MCP Tool Integration
|
||||
description: Connect external tools to DocsGPT agents using the Model Context Protocol (MCP) standard.
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
import { Steps } from 'nextra/components'
|
||||
|
||||
# MCP Tool Integration
|
||||
|
||||
The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) integration lets you connect external tool servers to DocsGPT. Your agents can then discover and call tools provided by those servers during conversations — for example, querying a CRM, running code, or accessing a database.
|
||||
|
||||
## Setup
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Configure Environment Variables (Optional)
|
||||
|
||||
Only needed if your MCP servers use OAuth authentication:
|
||||
|
||||
```env
|
||||
MCP_OAUTH_REDIRECT_URI=https://yourdomain.com/api/mcp_server/callback
|
||||
```
|
||||
|
||||
If not set, falls back to `API_URL/api/mcp_server/callback`.
|
||||
|
||||
### Step 2: Add an MCP Server
|
||||
|
||||
Go to **Settings** > **Tools** > **Add Tool** > **MCP Server**. Enter the server URL, select an auth type, and click **Test Connection** to verify, then **Save**.
|
||||
|
||||
### Step 3: Enable for Your Agent
|
||||
|
||||
In your agent configuration, enable the MCP tools you want the agent to use.
|
||||
|
||||
</Steps>
|
||||
|
||||
## Authentication Types
|
||||
|
||||
| Auth Type | Config Fields |
|
||||
|-----------|---------------|
|
||||
| **None** | — |
|
||||
| **Bearer** | `bearer_token` |
|
||||
| **API Key** | `api_key`, `api_key_header` (default: `X-API-Key`) |
|
||||
| **Basic** | `username`, `password` |
|
||||
| **OAuth** | `oauth_scopes` (optional) |
|
||||
|
||||
<Callout type="warning">
|
||||
For OAuth in production, `MCP_OAUTH_REDIRECT_URI` must be a publicly accessible URL pointing to your DocsGPT backend.
|
||||
</Callout>
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/api/mcp_server/test` | POST | Test a connection without saving |
|
||||
| `/api/mcp_server/save` | POST | Save or update a server configuration |
|
||||
| `/api/mcp_server/callback` | GET | OAuth callback handler |
|
||||
| `/api/mcp_server/oauth_status/<task_id>` | GET | Poll OAuth flow status |
|
||||
| `/api/mcp_server/auth_status` | GET | Batch check auth status for all MCP tools |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **Connection refused** — Verify the URL and that the server is reachable from your backend.
|
||||
- **403 Forbidden** — Check credentials and permissions.
|
||||
- **Timed out** — Default is 30s; increase timeout in tool config (max 300s).
|
||||
- **OAuth "needs_auth" persists** — Verify `MCP_OAUTH_REDIRECT_URI` is correct and Redis is running.
|
||||
63
docs/content/Guides/Integrations/sharepoint-connector.mdx
Normal file
63
docs/content/Guides/Integrations/sharepoint-connector.mdx
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
title: SharePoint / OneDrive Connector
|
||||
description: Connect your Microsoft SharePoint or OneDrive as an external knowledge base to upload and process files directly.
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
import { Steps } from 'nextra/components'
|
||||
|
||||
# SharePoint / OneDrive Connector
|
||||
|
||||
Connect your SharePoint or OneDrive account to upload and process files directly as an external knowledge base. Supports Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Microsoft Entra ID (Azure AD) with automatic token refresh.
|
||||
|
||||
## Setup
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Create an App Registration in Azure
|
||||
|
||||
1. Go to the [Azure Portal](https://portal.azure.com/) > **Microsoft Entra ID** > **App registrations** > **New registration**
|
||||
2. Set **Redirect URI** (Web) to:
|
||||
- Local: `http://localhost:7091/api/connectors/callback?provider=share_point`
|
||||
- Production: `https://yourdomain.com/api/connectors/callback?provider=share_point`
|
||||
|
||||
### Step 2: Configure API Permissions
|
||||
|
||||
In your App Registration, go to **API permissions** > **Add a permission** > **Microsoft Graph** > **Delegated permissions** and add: `Files.Read`, `Files.Read.All`, `Sites.Read.All`. Grant admin consent if possible.
|
||||
|
||||
### Step 3: Create a Client Secret
|
||||
|
||||
Go to **Certificates & secrets** > **New client secret**. Copy the secret value immediately (it won't be shown again).
|
||||
|
||||
### Step 4: Configure Environment Variables
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```env
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
```
|
||||
|
||||
| Variable | Description | Required | Default |
|
||||
|----------|-------------|----------|---------|
|
||||
| `MICROSOFT_CLIENT_ID` | Application (client) ID from App Registration overview | Yes | — |
|
||||
| `MICROSOFT_CLIENT_SECRET` | Client secret value | Yes | — |
|
||||
| `MICROSOFT_TENANT_ID` | Directory (tenant) ID | No | `common` |
|
||||
| `MICROSOFT_AUTHORITY` | Login endpoint override | No | Auto-constructed |
|
||||
|
||||
<Callout type="warning">
|
||||
`MICROSOFT_TENANT_ID=common` (the default) allows any Microsoft account to authenticate. Set this to your specific tenant ID in production.
|
||||
</Callout>
|
||||
|
||||
### Step 5: Restart and Use
|
||||
|
||||
Restart your application, then go to the upload section in DocsGPT and select **SharePoint / OneDrive** as the source. You'll be redirected to Microsoft to sign in, then can browse and select files to process.
|
||||
|
||||
</Steps>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **Option not appearing** — Verify `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` are set, then restart.
|
||||
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=share_point`.
|
||||
- **Permission denied** — Ensure admin consent is granted and the user has access to the target files.
|
||||
@@ -7,20 +7,10 @@ description:
|
||||
|
||||
If your AI uses external knowledge and is not explicit enough, it is ok, because we try to make DocsGPT friendly.
|
||||
|
||||
But if you want to adjust it, here is a simple way:-
|
||||
|
||||
- Got to `application/prompts/chat_combine_prompt.txt`
|
||||
|
||||
- And change it to
|
||||
But if you want to adjust it, prompts are now managed through the UI and API using a template-based system. See the [Customising Prompts](/Guides/Customising-prompts) guide for details.
|
||||
|
||||
To make the AI stricter about staying on-topic, edit your active prompt template (via **Sidebar → Settings → Active Prompt**) to include instructions like:
|
||||
|
||||
```
|
||||
|
||||
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples, if possible.
|
||||
Write an answer for the question below based on the provided context.
|
||||
If the context provides insufficient information, reply "I cannot answer".
|
||||
You have access to chat history and can use it to help answer the question.
|
||||
----------------
|
||||
{summaries}
|
||||
|
||||
```
|
||||
|
||||
@@ -29,7 +29,7 @@ export default {
|
||||
"title": "OCR",
|
||||
"href": "/Guides/ocr"
|
||||
},
|
||||
"Integrations": {
|
||||
"Integrations": {
|
||||
"title": "🔗 Integrations"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
|
||||
To stop DocsGPT, simply open a new terminal in the `DocsGPT` directory and run:
|
||||
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml down
|
||||
docker compose -f deployment/docker-compose-hub.yaml down
|
||||
```
|
||||
(or the specific `docker compose` command shown at the end of the `setup.sh` execution, which may include optional compose files depending on your choices).
|
||||
|
||||
|
||||
4
extensions/react-widget/package-lock.json
generated
4
extensions/react-widget/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "docsgpt",
|
||||
"version": "0.5.1",
|
||||
"version": "0.6.3",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "docsgpt",
|
||||
"version": "0.5.1",
|
||||
"version": "0.6.3",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@babel/plugin-transform-flow-strip-types": "^7.23.3",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "docsgpt",
|
||||
"version": "0.6.1",
|
||||
"version": "0.6.3",
|
||||
"private": false,
|
||||
"description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.",
|
||||
"source": "./src/index.html",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
aiohttp>=3,<4
|
||||
certifi==2024.7.4
|
||||
h11==0.16.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.5
|
||||
httpx==0.27.0
|
||||
idna==3.7
|
||||
|
||||
@@ -40,9 +40,12 @@ import {
|
||||
} from '@/components/ui/select';
|
||||
import { Sheet, SheetContent } from '@/components/ui/sheet';
|
||||
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import modelService from '../api/services/modelService';
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import { WorkflowNode } from './types/workflow';
|
||||
import {
|
||||
AgentNode,
|
||||
@@ -77,6 +80,7 @@ interface UserTool {
|
||||
|
||||
function WorkflowBuilderInner() {
|
||||
const navigate = useNavigate();
|
||||
const token = useSelector(selectToken);
|
||||
const { agentId } = useParams<{ agentId?: string }>();
|
||||
const [searchParams] = useSearchParams();
|
||||
const folderId = searchParams.get('folder_id');
|
||||
@@ -304,7 +308,7 @@ function WorkflowBuilderInner() {
|
||||
setAvailableModels(modelService.transformModels(modelsData.models));
|
||||
}
|
||||
|
||||
const toolsResponse = await userService.getUserTools(null);
|
||||
const toolsResponse = await userService.getUserTools(token);
|
||||
if (toolsResponse.ok) {
|
||||
const toolsData = await toolsResponse.json();
|
||||
setAvailableTools(toolsData.tools);
|
||||
|
||||
@@ -54,7 +54,10 @@ import { FileUpload } from '../../components/FileUpload';
|
||||
import AgentDetailsModal from '../../modals/AgentDetailsModal';
|
||||
import ConfirmationModal from '../../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../../models/misc';
|
||||
import { selectToken } from '../../preferences/preferenceSlice';
|
||||
import {
|
||||
selectSourceDocs,
|
||||
selectToken,
|
||||
} from '../../preferences/preferenceSlice';
|
||||
import { getToolDisplayName } from '../../utils/toolUtils';
|
||||
import { Agent } from '../types';
|
||||
import { ConditionCase, WorkflowNode } from '../types/workflow';
|
||||
@@ -300,6 +303,7 @@ function createWorkflowPayload(
|
||||
function WorkflowBuilderInner() {
|
||||
const navigate = useNavigate();
|
||||
const token = useSelector(selectToken);
|
||||
const sourceDocs = useSelector(selectSourceDocs);
|
||||
const { agentId } = useParams<{ agentId?: string }>();
|
||||
const [searchParams] = useSearchParams();
|
||||
const folderId = searchParams.get('folder_id');
|
||||
@@ -341,6 +345,14 @@ function WorkflowBuilderInner() {
|
||||
const [availableModels, setAvailableModels] = useState<Model[]>([]);
|
||||
const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
|
||||
const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
|
||||
const sourceOptions = useMemo(
|
||||
() =>
|
||||
(sourceDocs ?? []).map((doc) => ({
|
||||
value: doc.id ?? 'default',
|
||||
label: doc.name,
|
||||
})),
|
||||
[sourceDocs],
|
||||
);
|
||||
const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
|
||||
Record<string, string>
|
||||
>({});
|
||||
@@ -387,31 +399,39 @@ function WorkflowBuilderInner() {
|
||||
[],
|
||||
);
|
||||
|
||||
const onConnect = useCallback((params: Connection) => {
|
||||
setEdges((eds) => {
|
||||
const exists = eds.some(
|
||||
(e) =>
|
||||
e.source === params.source &&
|
||||
e.sourceHandle === params.sourceHandle &&
|
||||
e.target === params.target &&
|
||||
e.targetHandle === params.targetHandle,
|
||||
);
|
||||
if (exists) return eds;
|
||||
|
||||
const filtered = eds.filter(
|
||||
(e) =>
|
||||
!(
|
||||
const onConnect = useCallback(
|
||||
(params: Connection) => {
|
||||
setEdges((eds) => {
|
||||
const exists = eds.some(
|
||||
(e) =>
|
||||
e.source === params.source &&
|
||||
e.sourceHandle === (params.sourceHandle ?? null)
|
||||
) &&
|
||||
!(
|
||||
e.sourceHandle === params.sourceHandle &&
|
||||
e.target === params.target &&
|
||||
e.targetHandle === (params.targetHandle ?? null)
|
||||
),
|
||||
);
|
||||
return addEdge(params, filtered);
|
||||
});
|
||||
}, []);
|
||||
e.targetHandle === params.targetHandle,
|
||||
);
|
||||
if (exists) return eds;
|
||||
|
||||
const targetNode = nodes.find((n) => n.id === params.target);
|
||||
const isEndNode = targetNode?.type === 'end';
|
||||
|
||||
const filtered = eds.filter(
|
||||
(e) =>
|
||||
!(
|
||||
e.source === params.source &&
|
||||
e.sourceHandle === (params.sourceHandle ?? null)
|
||||
) &&
|
||||
// End nodes accept multiple incoming edges
|
||||
(isEndNode ||
|
||||
!(
|
||||
e.target === params.target &&
|
||||
e.targetHandle === (params.targetHandle ?? null)
|
||||
)),
|
||||
);
|
||||
return addEdge(params, filtered);
|
||||
});
|
||||
},
|
||||
[nodes],
|
||||
);
|
||||
|
||||
const onEdgeClick = useCallback((_event: React.MouseEvent, edge: Edge) => {
|
||||
setEdges((eds) => eds.filter((e) => e.id !== edge.id));
|
||||
@@ -701,7 +721,7 @@ function WorkflowBuilderInner() {
|
||||
setDefaultAgentModelId(preferredDefaultModel);
|
||||
}
|
||||
|
||||
const toolsResponse = await userService.getUserTools(null);
|
||||
const toolsResponse = await userService.getUserTools(token);
|
||||
if (toolsResponse.ok) {
|
||||
const toolsData = await toolsResponse.json();
|
||||
setAvailableTools(toolsData.tools);
|
||||
@@ -1271,8 +1291,8 @@ function WorkflowBuilderInner() {
|
||||
|
||||
const handlePrimaryAction = useCallback(() => {
|
||||
if (isPrimaryActionDisabled) return;
|
||||
void persistWorkflow(!canManageAgent);
|
||||
}, [isPrimaryActionDisabled, persistWorkflow, canManageAgent]);
|
||||
void persistWorkflow(false);
|
||||
}, [isPrimaryActionDisabled, persistWorkflow]);
|
||||
|
||||
const agentForDetails = useMemo<Agent>(
|
||||
() => ({
|
||||
@@ -1910,6 +1930,28 @@ function WorkflowBuilderInner() {
|
||||
emptyText="No tools available"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Sources
|
||||
</label>
|
||||
<MultiSelect
|
||||
options={sourceOptions}
|
||||
selected={
|
||||
selectedNode.data.config?.sources || []
|
||||
}
|
||||
onChange={(newSources) =>
|
||||
handleUpdateNodeData({
|
||||
config: {
|
||||
...(selectedNode.data.config || {}),
|
||||
sources: newSources,
|
||||
},
|
||||
})
|
||||
}
|
||||
placeholder="Select sources..."
|
||||
searchPlaceholder="Search sources..."
|
||||
emptyText="No sources available"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Structured Output (JSON Schema)
|
||||
|
||||
@@ -70,12 +70,12 @@ export function MultiSelect({
|
||||
role="combobox"
|
||||
aria-expanded={open}
|
||||
className={cn(
|
||||
'w-full justify-between border-[#E5E5E5] bg-white hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
|
||||
'h-auto min-h-[2.5rem] w-full justify-between border-[#E5E5E5] bg-white py-1.5 hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
|
||||
!selected.length && 'text-gray-500 dark:text-gray-400',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<div className="flex min-w-0 flex-wrap gap-1">
|
||||
{selected.length === 0 ? (
|
||||
placeholder
|
||||
) : (
|
||||
@@ -85,9 +85,9 @@ export function MultiSelect({
|
||||
return (
|
||||
<span
|
||||
key={option?.value || label}
|
||||
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
|
||||
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex max-w-[calc(100%-1rem)] items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
|
||||
>
|
||||
{label}
|
||||
<span className="truncate">{label}</span>
|
||||
<span
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
|
||||
@@ -3,7 +3,7 @@ testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
@@ -11,6 +11,7 @@ addopts =
|
||||
--cov-report=html
|
||||
--cov-report=term-missing
|
||||
--cov-report=xml
|
||||
--ignore=tests/integration
|
||||
markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
|
||||
@@ -977,8 +977,8 @@ function Connect-CloudAPIProvider {
|
||||
}
|
||||
"7" { # Novita
|
||||
$script:provider_name = "Novita"
|
||||
$script:llm_name = "novita"
|
||||
$script:model_name = "deepseek/deepseek-r1"
|
||||
$script:llm_provider = "novita"
|
||||
$script:model_name = "moonshotai/kimi-k2.5"
|
||||
Get-APIKey
|
||||
break
|
||||
}
|
||||
|
||||
2
setup.sh
2
setup.sh
@@ -704,7 +704,7 @@ connect_cloud_api_provider() {
|
||||
7) # Novita
|
||||
provider_name="Novita"
|
||||
llm_provider="novita"
|
||||
model_name="deepseek/deepseek-r1"
|
||||
model_name="moonshotai/kimi-k2.5"
|
||||
get_api_key
|
||||
break ;;
|
||||
b|B) clear; return 1 ;; # Clear screen and Back to Main Menu
|
||||
|
||||
202
tests/agents/test_api_body_serializer.py
Normal file
202
tests/agents/test_api_body_serializer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tests for application/agents/tools/api_body_serializer.py"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContentTypeEnum:
|
||||
def test_json_value(self):
|
||||
assert ContentType.JSON == "application/json"
|
||||
|
||||
def test_form_urlencoded_value(self):
|
||||
assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"
|
||||
|
||||
def test_multipart_value(self):
|
||||
assert ContentType.MULTIPART_FORM_DATA == "multipart/form-data"
|
||||
|
||||
def test_text_plain_value(self):
|
||||
assert ContentType.TEXT_PLAIN == "text/plain"
|
||||
|
||||
def test_xml_value(self):
|
||||
assert ContentType.XML == "application/xml"
|
||||
|
||||
def test_octet_stream_value(self):
|
||||
assert ContentType.OCTET_STREAM == "application/octet-stream"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeJson:
|
||||
def test_basic_json(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"key": "value"}, ContentType.JSON
|
||||
)
|
||||
assert json.loads(body) == {"key": "value"}
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
def test_nested_json(self):
|
||||
data = {"user": {"name": "Alice", "age": 30}}
|
||||
body, headers = RequestBodySerializer.serialize(data, ContentType.JSON)
|
||||
assert json.loads(body) == data
|
||||
|
||||
def test_empty_body_returns_none(self):
|
||||
body, headers = RequestBodySerializer.serialize({}, ContentType.JSON)
|
||||
assert body is None
|
||||
assert headers == {}
|
||||
|
||||
def test_none_body(self):
|
||||
body, headers = RequestBodySerializer.serialize(None, ContentType.JSON)
|
||||
assert body is None
|
||||
|
||||
def test_unknown_content_type_falls_back_to_json(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"k": "v"}, "application/vnd.custom+json"
|
||||
)
|
||||
assert json.loads(body) == {"k": "v"}
|
||||
|
||||
def test_content_type_with_charset_suffix(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"k": "v"}, "application/json; charset=utf-8"
|
||||
)
|
||||
assert json.loads(body) == {"k": "v"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeFormUrlencoded:
|
||||
def test_basic_form(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
|
||||
)
|
||||
assert "name=Alice" in body
|
||||
assert "age=30" in body
|
||||
assert headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
|
||||
def test_none_values_skipped(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
|
||||
)
|
||||
assert "name=Alice" in body
|
||||
assert "skip" not in body
|
||||
|
||||
def test_list_explode_true(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"tags": ["a", "b"]},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"tags": {"style": "form", "explode": True}},
|
||||
)
|
||||
assert "tags=a" in body
|
||||
assert "tags=b" in body
|
||||
|
||||
def test_list_explode_false(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"tags": ["a", "b"]},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"tags": {"style": "form", "explode": False}},
|
||||
)
|
||||
# Value is percent-encoded by _serialize_form_value then urlencoded again
|
||||
assert "tags=" in body
|
||||
assert "a" in body and "b" in body
|
||||
|
||||
def test_dict_value_json_content_type(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"metadata": {"key": "val"}},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"metadata": {"contentType": "application/json"}},
|
||||
)
|
||||
assert "metadata" in body
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeTextPlain:
|
||||
def test_single_value(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"message": "hello"}, ContentType.TEXT_PLAIN
|
||||
)
|
||||
assert body == "hello"
|
||||
assert headers["Content-Type"] == "text/plain"
|
||||
|
||||
def test_multiple_values(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
|
||||
)
|
||||
assert "name: Alice" in body
|
||||
assert "age: 30" in body
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeXml:
|
||||
def test_basic_xml(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice"}, ContentType.XML
|
||||
)
|
||||
assert '<?xml version="1.0"' in body
|
||||
assert "<name>Alice</name>" in body
|
||||
assert headers["Content-Type"] == "application/xml"
|
||||
|
||||
def test_nested_xml(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"user": {"name": "Alice"}}, ContentType.XML
|
||||
)
|
||||
assert "<user>" in body
|
||||
assert "<name>Alice</name>" in body
|
||||
|
||||
def test_xml_escapes_special_chars(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"data": "<script>alert('xss')</script>"}, ContentType.XML
|
||||
)
|
||||
assert "<script>" in body
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeOctetStream:
|
||||
def test_dict_body(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"key": "val"}, ContentType.OCTET_STREAM
|
||||
)
|
||||
assert isinstance(body, bytes)
|
||||
assert headers["Content-Type"] == "application/octet-stream"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeMultipartFormData:
|
||||
def test_basic_multipart(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"field": "value"}, ContentType.MULTIPART_FORM_DATA
|
||||
)
|
||||
assert isinstance(body, bytes)
|
||||
assert "multipart/form-data" in headers["Content-Type"]
|
||||
assert "boundary=" in headers["Content-Type"]
|
||||
|
||||
def test_none_values_skipped(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "field" in body_str
|
||||
assert "empty" not in body_str
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHelpers:
|
||||
def test_percent_encode(self):
|
||||
assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
|
||||
assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
|
||||
assert RequestBodySerializer._percent_encode("safe", safe_chars="/") == "safe"
|
||||
|
||||
def test_escape_xml(self):
|
||||
assert "&" in RequestBodySerializer._escape_xml("&")
|
||||
assert "<" in RequestBodySerializer._escape_xml("<")
|
||||
assert ">" in RequestBodySerializer._escape_xml(">")
|
||||
assert """ in RequestBodySerializer._escape_xml('"')
|
||||
assert "'" in RequestBodySerializer._escape_xml("'")
|
||||
|
||||
def test_dict_to_xml_list(self):
|
||||
xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
|
||||
assert "<item>1</item>" in xml
|
||||
assert "<item>2</item>" in xml
|
||||
282
tests/agents/test_api_tool.py
Normal file
282
tests/agents/test_api_tool.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Tests for application/agents/tools/api_tool.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_tool import APITool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/data",
|
||||
"method": "GET",
|
||||
"headers": {"Accept": "application/json"},
|
||||
"query_params": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def post_tool():
|
||||
return APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "POST",
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAPIToolInit:
|
||||
def test_default_values(self):
|
||||
tool = APITool(config={})
|
||||
assert tool.url == ""
|
||||
assert tool.method == "GET"
|
||||
assert tool.headers == {}
|
||||
assert tool.query_params == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any_action")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
|
||||
assert result["status_code"] == 201
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked(self, mock_validate, tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error(self, mock_get, mock_validate, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("any")
|
||||
|
||||
assert result["status_code"] == 404
|
||||
assert "HTTP Error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
tool = APITool(
|
||||
config={"url": "https://example.com", "method": "CUSTOM"}
|
||||
)
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Unsupported" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
tool = APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "42", "post_id": "7"},
|
||||
}
|
||||
)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseResponse:
|
||||
def test_json_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"key": "val"}
|
||||
mock_resp.content = b'{"key":"val"}'
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result == {"key": "val"}
|
||||
|
||||
def test_text_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.text = "plain text"
|
||||
mock_resp.content = b"plain text"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result == "plain text"
|
||||
|
||||
def test_xml_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/xml"}
|
||||
mock_resp.text = "<root><item>1</item></root>"
|
||||
mock_resp.content = b"<root><item>1</item></root>"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert "<root>" in result
|
||||
|
||||
def test_empty_content(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.content = b""
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result is None
|
||||
|
||||
def test_html_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.text = "<html><body>Hi</body></html>"
|
||||
mock_resp.content = b"<html><body>Hi</body></html>"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert "<html>" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAPIToolMetadata:
|
||||
def test_actions_metadata_empty(self, tool):
|
||||
assert tool.get_actions_metadata() == []
|
||||
|
||||
def test_config_requirements_empty(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
@@ -56,6 +56,73 @@ class TestBaseAgentInitialization:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.user == "user123"
|
||||
|
||||
def test_dependency_injection_llm(self, agent_base_params, mock_llm_handler_creator):
|
||||
"""When llm is provided, LLMCreator.create_llm is NOT called."""
|
||||
injected_llm = Mock()
|
||||
agent_base_params["llm"] = injected_llm
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.llm is injected_llm
|
||||
|
||||
def test_dependency_injection_llm_handler(self, agent_base_params, mock_llm_creator):
|
||||
"""When llm_handler is provided, LLMHandlerCreator is NOT called."""
|
||||
injected_handler = Mock()
|
||||
agent_base_params["llm_handler"] = injected_handler
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.llm_handler is injected_handler
|
||||
|
||||
def test_dependency_injection_tool_executor(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
"""When tool_executor is provided, a new one is NOT created."""
|
||||
injected_executor = Mock()
|
||||
injected_executor.tool_calls = []
|
||||
agent_base_params["tool_executor"] = injected_executor
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.tool_executor is injected_executor
|
||||
|
||||
def test_json_schema_normalized(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema == {"type": "object"}
|
||||
|
||||
def test_json_schema_wrapped(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"schema": {"type": "string"}}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema == {"type": "string"}
|
||||
|
||||
def test_json_schema_invalid_ignored(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"bad": "no type"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema is None
|
||||
|
||||
def test_retrieved_docs_defaults_to_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.retrieved_docs == []
|
||||
|
||||
def test_attachments_defaults_to_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["attachments"] = None
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.attachments == []
|
||||
|
||||
def test_limited_token_mode_defaults(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.limited_token_mode is False
|
||||
assert agent.limited_request_mode is False
|
||||
assert agent.current_token_count == 0
|
||||
assert agent.context_limit_reached is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentBuildMessages:
|
||||
@@ -602,3 +669,656 @@ class TestBaseAgentHandleResponse:
|
||||
assert len(results) == 2
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[1]["answer"] == "Final answer"
|
||||
|
||||
def test_handle_response_dict_event_passthrough(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Dict events with 'type' key pass through without wrapping."""
|
||||
|
||||
def mock_process(*args):
|
||||
yield {"type": "info", "data": {"message": "processing"}}
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results == [{"type": "info", "data": {"message": "processing"}}]
|
||||
|
||||
def test_handle_response_message_object_from_handler(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Response objects with .message.content from handler are unwrapped."""
|
||||
event = Mock()
|
||||
event.message = Mock()
|
||||
event.message.content = "from handler"
|
||||
|
||||
def mock_process(*args):
|
||||
yield event
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["answer"] == "from handler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# gen() — the @log_activity decorated entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentGen:
|
||||
|
||||
def test_gen_delegates_to_gen_inner(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
# ClassicAgent._gen_inner is abstract — we patch it
|
||||
with patch.object(agent, "_gen_inner") as mock_inner:
|
||||
mock_inner.return_value = iter([{"answer": "ok"}])
|
||||
results = list(agent.gen("hello"))
|
||||
|
||||
assert any(r.get("answer") == "ok" for r in results)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tool_calls property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentToolCallsProperty:
|
||||
|
||||
def test_getter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_executor.tool_calls = ["a", "b"]
|
||||
assert agent.tool_calls == ["a", "b"]
|
||||
|
||||
def test_setter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_calls = ["x"]
|
||||
assert agent.tool_executor.tool_calls == ["x"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _calculate_current_context_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCalculateContextTokens:
|
||||
|
||||
def test_delegates_to_token_counter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.token_counter.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_message_tokens.return_value = 42
|
||||
result = agent._calculate_current_context_tokens(messages)
|
||||
assert result == 42
|
||||
MockTC.count_message_tokens.assert_called_once_with(messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_context_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckContextLimit:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ClassicAgent(**agent_base_params)
|
||||
|
||||
def test_below_threshold_returns_false(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
result = agent._check_context_limit(messages)
|
||||
assert result is False
|
||||
assert agent.current_token_count == 100
|
||||
|
||||
def test_at_threshold_returns_true(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
|
||||
# threshold = 10000 * 0.8 = 8000; tokens = 8001 → True
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=8001):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
result = agent._check_context_limit(messages)
|
||||
assert result is True
|
||||
|
||||
def test_error_returns_false(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
with patch.object(
|
||||
agent,
|
||||
"_calculate_current_context_tokens",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
result = agent._check_context_limit([])
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_context_size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateContextSize:
|
||||
|
||||
def test_at_limit_logs_warning(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=10000):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
# Should not raise
|
||||
agent._validate_context_size([{"role": "user", "content": "x"}])
|
||||
assert agent.current_token_count == 10000
|
||||
|
||||
def test_below_threshold_no_warning(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
agent._validate_context_size([])
|
||||
assert agent.current_token_count == 100
|
||||
|
||||
def test_approaching_threshold(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
# 8500 / 10000 = 85% → above 80% threshold but below 100%
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=8500):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
agent._validate_context_size([])
|
||||
assert agent.current_token_count == 8500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_text_middle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTruncateTextMiddle:
|
||||
|
||||
def test_short_text_unchanged(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_text_middle("short", max_tokens=100)
|
||||
assert result == "short"
|
||||
|
||||
def test_long_text_truncated(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
long_text = "A" * 1000
|
||||
|
||||
def fake_tokens(text):
|
||||
return len(text) // 4
|
||||
|
||||
with patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
|
||||
result = agent._truncate_text_middle(long_text, max_tokens=50)
|
||||
assert "[... content truncated to fit context limit ...]" in result
|
||||
assert len(result) < len(long_text)
|
||||
|
||||
def test_zero_target_returns_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch("application.utils.num_tokens_from_string", return_value=100):
|
||||
result = agent._truncate_text_middle("some text", max_tokens=0)
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_history_to_fit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTruncateHistoryToFit:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ClassicAgent(**agent_base_params)
|
||||
|
||||
def test_empty_history(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
assert agent._truncate_history_to_fit([], 100) == []
|
||||
|
||||
def test_zero_budget(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [{"prompt": "a", "response": "b"}]
|
||||
assert agent._truncate_history_to_fit(history, 0) == []
|
||||
|
||||
def test_fits_all(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{"prompt": "q1", "response": "a1"},
|
||||
{"prompt": "q2", "response": "a2"},
|
||||
]
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_history_to_fit(history, 10000)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_partial_fit_keeps_most_recent(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{"prompt": "old", "response": "old_ans"},
|
||||
{"prompt": "new", "response": "new_ans"},
|
||||
]
|
||||
# Each message = 10 tokens (prompt + response), budget = 15 → only 1 fits
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_history_to_fit(history, 15)
|
||||
assert len(result) == 1
|
||||
assert result[0]["prompt"] == "new"
|
||||
|
||||
def test_history_with_tool_calls(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "a",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "t",
|
||||
"action_name": "act",
|
||||
"arguments": "{}",
|
||||
"result": "ok",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
with patch("application.utils.num_tokens_from_string", return_value=3):
|
||||
result = agent._truncate_history_to_fit(history, 100)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_messages — compressed_summary and query truncation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildMessagesAdvanced:
|
||||
|
||||
def test_compressed_summary_appended(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["compressed_summary"] = "Previous conversation summary"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=100000
|
||||
), patch("application.utils.num_tokens_from_string", return_value=10):
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
system_content = messages[0]["content"]
|
||||
assert "Previous conversation summary" in system_content
|
||||
|
||||
def test_query_truncated_when_too_large(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def fake_tokens(text):
|
||||
call_count["n"] += 1
|
||||
return len(text)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=200
|
||||
), patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
|
||||
with patch.object(agent, "_truncate_text_middle", return_value="truncated"):
|
||||
with patch.object(agent, "_truncate_history_to_fit", return_value=[]):
|
||||
messages = agent._build_messages("sys", "A" * 500)
|
||||
|
||||
# The method should have been called for truncation
|
||||
assert messages[-1]["role"] == "user"
|
||||
|
||||
def test_build_messages_with_tool_call_missing_call_id(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
"""Tool calls without call_id get a generated UUID."""
|
||||
history = [
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"action_name": "search",
|
||||
"arguments": "{}",
|
||||
"result": "found",
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
agent_base_params["chat_history"] = history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=100000
|
||||
), patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
messages = agent._build_messages("sys", "q")
|
||||
|
||||
tool_msgs = [m for m in messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _llm_gen — edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMGenAdvanced:
|
||||
|
||||
def test_llm_gen_with_attachments(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent_base_params["attachments"] = [{"id": "att1", "mime_type": "image/png"}]
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "_usage_attachments" in call_kwargs
|
||||
|
||||
def test_llm_gen_without_log_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
|
||||
# Should not raise even without log_context
|
||||
agent._llm_gen(messages, log_context=None)
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
|
||||
def test_llm_gen_google_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
mock_llm.prepare_structured_output_format = Mock(
|
||||
return_value={"schema": "test"}
|
||||
)
|
||||
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent_base_params["llm_name"] = "google"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages, log_context)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_schema" in call_kwargs
|
||||
|
||||
def test_llm_gen_no_tools_when_unsupported(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_tools = False
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tools = [{"type": "function", "function": {"name": "test"}}]
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "tools" not in call_kwargs
|
||||
|
||||
def test_llm_gen_no_structured_output_when_unsupported(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_format" not in call_kwargs
|
||||
assert "response_schema" not in call_kwargs
|
||||
|
||||
def test_llm_gen_no_format_when_prepare_returns_none(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
mock_llm.prepare_structured_output_format = Mock(return_value=None)
|
||||
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent_base_params["llm_name"] = "openai"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_format" not in call_kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _llm_handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMHandlerMethod:
|
||||
|
||||
def test_delegates_to_handler(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm_handler.process_message_flow = Mock(return_value="result")
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
resp = Mock()
|
||||
result = agent._llm_handler(resp, {}, [], log_context)
|
||||
|
||||
mock_llm_handler.process_message_flow.assert_called_once()
|
||||
assert result == "result"
|
||||
assert len(log_context.stacks) == 1
|
||||
assert log_context.stacks[0]["component"] == "llm_handler"
|
||||
|
||||
def test_without_log_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm_handler.process_message_flow = Mock(return_value="r")
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
result = agent._llm_handler(Mock(), {}, [], log_context=None)
|
||||
assert result == "r"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_response — structured output on all code paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleResponseStructuredAllPaths:
|
||||
|
||||
def test_message_response_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on the message.content early-return path."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = Mock()
|
||||
response.message.content = "structured msg"
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "object"}
|
||||
assert results[0]["answer"] == "structured msg"
|
||||
|
||||
def test_handler_string_event_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on string events from the handler."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "array"}
|
||||
|
||||
def mock_process(*args):
|
||||
yield "handler string"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "array"}
|
||||
|
||||
def test_handler_message_event_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on message-object events from the handler."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "number"}
|
||||
|
||||
event = Mock()
|
||||
event.message = Mock()
|
||||
event.message.content = "from handler msg"
|
||||
|
||||
def mock_process(*args):
|
||||
yield event
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "number"}
|
||||
assert results[0]["answer"] == "from handler msg"
|
||||
|
||||
132
tests/agents/test_brave_tool.py
Normal file
132
tests/agents/test_brave_tool.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for application/agents/tools/brave.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.brave import BraveSearchTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return BraveSearchTool(config={"token": "test_api_key"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBraveExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_web_search_success(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"web": {"results": [{"title": "Result"}]}}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("brave_web_search", query="python")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "results" in result
|
||||
assert "successfully" in result["message"]
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
assert call_kwargs[1]["headers"]["X-Subscription-Token"] == "test_api_key"
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_web_search_failure(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 429
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("brave_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 429
|
||||
assert "failed" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_image_search_success(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"results": [{"url": "https://img.com/1.jpg"}]}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("brave_image_search", query="cats")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "results" in result
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_image_search_failure(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("brave_image_search", query="cats")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_count_capped_at_20(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("brave_web_search", query="test", count=100)
|
||||
|
||||
params = mock_get.call_args[1]["params"]
|
||||
assert params["count"] == 20
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_image_count_capped_at_100(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("brave_image_search", query="test", count=500)
|
||||
|
||||
params = mock_get.call_args[1]["params"]
|
||||
assert params["count"] == 100
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_freshness_param(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("brave_web_search", query="news", freshness="pd")
|
||||
|
||||
params = mock_get.call_args[1]["params"]
|
||||
assert params["freshness"] == "pd"
|
||||
|
||||
@patch("application.agents.tools.brave.requests.get")
|
||||
def test_offset_capped(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("brave_web_search", query="test", offset=100)
|
||||
|
||||
params = mock_get.call_args[1]["params"]
|
||||
assert params["offset"] == 9
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBraveMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 2
|
||||
names = {a["name"] for a in meta}
|
||||
assert "brave_web_search" in names
|
||||
assert "brave_image_search" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
assert reqs["token"]["required"] is True
|
||||
169
tests/agents/test_cel_evaluator.py
Normal file
169
tests/agents/test_cel_evaluator.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Tests for application/agents/workflows/cel_evaluator.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.cel_evaluator import (
|
||||
CelEvaluationError,
|
||||
_convert_value,
|
||||
build_activation,
|
||||
cel_to_python,
|
||||
evaluate_cel,
|
||||
)
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class TestConvertValue:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool_true(self):
|
||||
result = _convert_value(True)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool_false(self):
|
||||
result = _convert_value(False)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_int(self):
|
||||
result = _convert_value(42)
|
||||
assert isinstance(result, celpy.celtypes.IntType)
|
||||
assert int(result) == 42
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_float(self):
|
||||
result = _convert_value(3.14)
|
||||
assert isinstance(result, celpy.celtypes.DoubleType)
|
||||
assert float(result) == pytest.approx(3.14)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string(self):
|
||||
result = _convert_value("hello")
|
||||
assert isinstance(result, celpy.celtypes.StringType)
|
||||
assert str(result) == "hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list(self):
|
||||
result = _convert_value([1, "two", 3.0])
|
||||
assert isinstance(result, celpy.celtypes.ListType)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_dict(self):
|
||||
result = _convert_value({"key": "value"})
|
||||
assert isinstance(result, celpy.celtypes.MapType)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none(self):
|
||||
result = _convert_value(None)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_other_type_converts_to_string(self):
|
||||
result = _convert_value(object())
|
||||
assert isinstance(result, celpy.celtypes.StringType)
|
||||
|
||||
|
||||
class TestBuildActivation:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_converts_dict_values(self):
|
||||
state = {"name": "Alice", "age": 30, "active": True}
|
||||
result = build_activation(state)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "active" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_state(self):
|
||||
assert build_activation({}) == {}
|
||||
|
||||
|
||||
class TestEvaluateCel:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_simple_comparison(self):
|
||||
assert evaluate_cel("x > 5", {"x": 10}) is True
|
||||
assert evaluate_cel("x > 5", {"x": 3}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_comparison(self):
|
||||
assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
|
||||
assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_arithmetic(self):
|
||||
assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_boolean_logic(self):
|
||||
assert evaluate_cel("a && b", {"a": True, "b": True}) is True
|
||||
assert evaluate_cel("a && b", {"a": True, "b": False}) is False
|
||||
assert evaluate_cel("a || b", {"a": False, "b": True}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError, match="Empty expression"):
|
||||
evaluate_cel("", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_whitespace_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError, match="Empty expression"):
|
||||
evaluate_cel(" ", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError):
|
||||
evaluate_cel("invalid!!!", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_variable_raises(self):
|
||||
with pytest.raises(CelEvaluationError):
|
||||
evaluate_cel("undefined_var > 5", {})
|
||||
|
||||
|
||||
class TestCelToPython:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool(self):
|
||||
result = cel_to_python(celpy.celtypes.BoolType(True))
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_int(self):
|
||||
result = cel_to_python(celpy.celtypes.IntType(42))
|
||||
assert result == 42
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_double(self):
|
||||
result = cel_to_python(celpy.celtypes.DoubleType(3.14))
|
||||
assert result == pytest.approx(3.14)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string(self):
|
||||
result = cel_to_python(celpy.celtypes.StringType("hello"))
|
||||
assert result == "hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list(self):
|
||||
cel_list = celpy.celtypes.ListType([
|
||||
celpy.celtypes.IntType(1),
|
||||
celpy.celtypes.IntType(2),
|
||||
])
|
||||
result = cel_to_python(cel_list)
|
||||
assert result == [1, 2]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_map(self):
|
||||
cel_map = celpy.celtypes.MapType({
|
||||
celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
|
||||
})
|
||||
result = cel_to_python(cel_map)
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_type_passthrough(self):
|
||||
result = cel_to_python("raw_value")
|
||||
assert result == "raw_value"
|
||||
85
tests/agents/test_cryptoprice_tool.py
Normal file
85
tests/agents/test_cryptoprice_tool.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for application/agents/tools/cryptoprice.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.cryptoprice import CryptoPriceTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return CryptoPriceTool(config={})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCryptoPriceExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid_action")
|
||||
|
||||
@patch("application.agents.tools.cryptoprice.requests.get")
|
||||
def test_successful_price_fetch(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"USD": 65000}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["price"] == 65000
|
||||
assert "successfully" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.cryptoprice.requests.get")
|
||||
def test_currency_not_found(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"EUR": 60000}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "Couldn't find" in result["message"]
|
||||
assert "price" not in result
|
||||
|
||||
@patch("application.agents.tools.cryptoprice.requests.get")
|
||||
def test_api_failure(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Failed" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.cryptoprice.requests.get")
|
||||
def test_symbol_case_insensitive(self, mock_get, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"USD": 100}
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("cryptoprice_get", symbol="btc", currency="usd")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "fsym=BTC" in called_url
|
||||
assert "tsyms=USD" in called_url
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCryptoPriceMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "cryptoprice_get"
|
||||
params = meta[0]["parameters"]
|
||||
assert "symbol" in params["properties"]
|
||||
assert "currency" in params["properties"]
|
||||
assert "symbol" in params["required"]
|
||||
assert "currency" in params["required"]
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
145
tests/agents/test_duckduckgo_tool.py
Normal file
145
tests/agents/test_duckduckgo_tool.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Tests for application/agents/tools/duckduckgo.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.duckduckgo import DuckDuckGoSearchTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return DuckDuckGoSearchTool(config={})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDuckDuckGoExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_web_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = [
|
||||
{"title": "Result 1", "href": "https://example.com", "body": "snippet"}
|
||||
]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="python")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
assert "successfully" in result["message"]
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_image_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.return_value = [{"image": "https://img.com/1.jpg"}]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_image_search", query="cats")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_news_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.news.return_value = [{"title": "News"}]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_news_search", query="tech")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_search_error_returns_500(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.side_effect = Exception("Network error")
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "failed" in result["message"].lower()
|
||||
assert result["results"] == []
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_max_results_capped_at_20(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
tool.execute_action("ddg_web_search", query="test", max_results=100)
|
||||
|
||||
call_kwargs = mock_client.text.call_args[1]
|
||||
assert call_kwargs["max_results"] == 20
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_image_max_results_capped_at_50(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
tool.execute_action("ddg_image_search", query="test", max_results=200)
|
||||
|
||||
call_kwargs = mock_client.images.call_args[1]
|
||||
assert call_kwargs["max_results"] == 50
|
||||
|
||||
@patch("application.agents.tools.duckduckgo.time.sleep")
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_rate_limit_retries(self, mock_client_factory, mock_sleep, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.side_effect = [
|
||||
Exception("RateLimit exceeded"),
|
||||
[{"title": "Result"}],
|
||||
]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_empty_results(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="obscure query")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["results"] == []
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_none_results(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = None
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["results"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDuckDuckGoMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 3
|
||||
names = {a["name"] for a in meta}
|
||||
assert "ddg_web_search" in names
|
||||
assert "ddg_image_search" in names
|
||||
assert "ddg_news_search" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
|
||||
def test_custom_timeout(self):
|
||||
tool = DuckDuckGoSearchTool(config={"timeout": 30})
|
||||
assert tool.timeout == 30
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for InternalSearchTool and its helper functions."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.tools.internal_search import (
|
||||
@@ -248,3 +248,452 @@ class TestBuildHelpers:
|
||||
tools_dict = {}
|
||||
add_internal_search_tool(tools_dict, {})
|
||||
assert INTERNAL_TOOL_ID not in tools_dict
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolGetRetriever:
|
||||
"""Cover line 32: _get_retriever creates retriever lazily."""
|
||||
|
||||
def test_get_retriever_creates_retriever(self):
|
||||
tool = InternalSearchTool({
|
||||
"source": {},
|
||||
"retriever_name": "classic",
|
||||
"chunks": 2,
|
||||
})
|
||||
assert tool._retriever is None
|
||||
|
||||
mock_retriever = Mock()
|
||||
with patch(
|
||||
"application.agents.tools.internal_search.RetrieverCreator"
|
||||
) as mock_rc:
|
||||
mock_rc.create_retriever.return_value = mock_retriever
|
||||
result = tool._get_retriever()
|
||||
|
||||
assert result is mock_retriever
|
||||
assert tool._retriever is mock_retriever
|
||||
|
||||
def test_get_retriever_cached(self):
|
||||
"""Cover line 32: second call returns cached retriever."""
|
||||
tool = InternalSearchTool({"source": {}, "retriever_name": "classic"})
|
||||
mock_retriever = Mock()
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool._get_retriever()
|
||||
assert result is mock_retriever
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetDirectoryStructure:
|
||||
"""Cover lines 61: _get_directory_structure loads from MongoDB."""
|
||||
|
||||
def test_no_active_docs_returns_none(self):
|
||||
"""Cover line 56-57: no active_docs returns None."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool._get_directory_structure()
|
||||
assert result is None
|
||||
assert tool._dir_structure_loaded is True
|
||||
|
||||
def test_loads_structure_from_mongo(self):
|
||||
"""Cover line 61+: loads directory structure from MongoDB."""
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": [doc_id]},
|
||||
})
|
||||
|
||||
mock_source_doc = {
|
||||
"_id": ObjectId(doc_id),
|
||||
"name": "test_source",
|
||||
"directory_structure": {"root": {"file.txt": {"type": "text"}}},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB"
|
||||
) as mock_mongo:
|
||||
mock_db = Mock()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = mock_source_doc
|
||||
mock_db.__getitem__ = Mock(return_value=mock_collection)
|
||||
mock_mongo.get_client.return_value = Mock(
|
||||
__getitem__=Mock(return_value=mock_db)
|
||||
)
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
|
||||
assert result == {"root": {"file.txt": {"type": "text"}}}
|
||||
|
||||
def test_loads_string_structure_from_mongo(self):
|
||||
"""Cover line 80-81: directory_structure stored as JSON string."""
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": [doc_id]},
|
||||
})
|
||||
|
||||
mock_source_doc = {
|
||||
"_id": ObjectId(doc_id),
|
||||
"name": "test_source",
|
||||
"directory_structure": '{"root": {"file.txt": {}}}',
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB"
|
||||
) as mock_mongo:
|
||||
mock_db = Mock()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = mock_source_doc
|
||||
mock_db.__getitem__ = Mock(return_value=mock_collection)
|
||||
mock_mongo.get_client.return_value = Mock(
|
||||
__getitem__=Mock(return_value=mock_db)
|
||||
)
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
|
||||
assert result == {"root": {"file.txt": {}}}
|
||||
|
||||
def test_multiple_active_docs_merged(self):
|
||||
"""Cover line 83-84: multiple docs merge under source names."""
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
doc_id1 = str(ObjectId())
|
||||
doc_id2 = str(ObjectId())
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": [doc_id1, doc_id2]},
|
||||
})
|
||||
|
||||
docs = {
|
||||
doc_id1: {
|
||||
"_id": ObjectId(doc_id1),
|
||||
"name": "source1",
|
||||
"directory_structure": {"file1.txt": {}},
|
||||
},
|
||||
doc_id2: {
|
||||
"_id": ObjectId(doc_id2),
|
||||
"name": "source2",
|
||||
"directory_structure": {"file2.txt": {}},
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB"
|
||||
) as mock_mongo:
|
||||
mock_db = Mock()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.side_effect = lambda q: docs.get(
|
||||
str(q["_id"])
|
||||
)
|
||||
mock_db.__getitem__ = Mock(return_value=mock_collection)
|
||||
mock_mongo.get_client.return_value = Mock(
|
||||
__getitem__=Mock(return_value=mock_db)
|
||||
)
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
|
||||
assert "source1" in result
|
||||
assert "source2" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFormatStructureAdditional:
|
||||
"""Cover lines 186, 193, 200, 221: format structure branches."""
|
||||
|
||||
def test_format_structure_non_dict_node(self):
|
||||
"""Cover line 173: non-dict node returns file message."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool._format_structure("a string node", "/path")
|
||||
assert "is a file" in result
|
||||
|
||||
def test_format_structure_file_with_type_metadata(self):
|
||||
"""Cover lines 186-193: file with type and token_count metadata."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"readme.md": {"type": "markdown", "token_count": 500},
|
||||
"data.json": {"size_bytes": 1024},
|
||||
}
|
||||
result = tool._format_structure(node, "/root")
|
||||
assert "readme.md" in result
|
||||
assert "500 tokens" in result
|
||||
|
||||
def test_format_structure_empty_directory(self):
|
||||
"""Cover lines 206-208: empty directory."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool._format_structure({}, "/empty")
|
||||
assert "(empty)" in result
|
||||
|
||||
def test_format_structure_plain_file_entry(self):
|
||||
"""Cover line 198: plain file entry (non-dict value)."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {"file.txt": "some_value"}
|
||||
result = tool._format_structure(node, "/root")
|
||||
assert "file.txt" in result
|
||||
|
||||
def test_count_files_nested(self):
|
||||
"""Cover line 221: _count_files counts nested files."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"sub": {"file1.txt": {"type": "text"}},
|
||||
"file2.txt": "plain",
|
||||
}
|
||||
count = tool._count_files(node)
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSourcesHaveDirectoryStructure:
|
||||
"""Cover line 240, 254, 298: sources_have_directory_structure helper."""
|
||||
|
||||
def test_no_active_docs_returns_false(self):
|
||||
from application.agents.tools.internal_search import (
|
||||
sources_have_directory_structure,
|
||||
)
|
||||
|
||||
assert sources_have_directory_structure({}) is False
|
||||
assert sources_have_directory_structure({"active_docs": []}) is False
|
||||
|
||||
def test_with_structure_returns_true(self):
|
||||
from bson.objectid import ObjectId
|
||||
from application.agents.tools.internal_search import (
|
||||
sources_have_directory_structure,
|
||||
)
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_source_doc = {
|
||||
"_id": ObjectId(doc_id),
|
||||
"directory_structure": {"root": {}},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB"
|
||||
) as mock_mongo:
|
||||
mock_db = Mock()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = mock_source_doc
|
||||
mock_db.__getitem__ = Mock(return_value=mock_collection)
|
||||
mock_mongo.get_client.return_value = Mock(
|
||||
__getitem__=Mock(return_value=mock_db)
|
||||
)
|
||||
|
||||
result = sources_have_directory_structure({"active_docs": [doc_id]})
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_string_active_docs_converted_to_list(self):
|
||||
"""Cover line 298: active_docs as string is converted to list."""
|
||||
from bson.objectid import ObjectId
|
||||
from application.agents.tools.internal_search import (
|
||||
sources_have_directory_structure,
|
||||
)
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_source_doc = {
|
||||
"_id": ObjectId(doc_id),
|
||||
"directory_structure": {"root": {}},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB"
|
||||
) as mock_mongo:
|
||||
mock_db = Mock()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = mock_source_doc
|
||||
mock_db.__getitem__ = Mock(return_value=mock_collection)
|
||||
mock_mongo.get_client.return_value = Mock(
|
||||
__getitem__=Mock(return_value=mock_db)
|
||||
)
|
||||
|
||||
result = sources_have_directory_structure({"active_docs": doc_id})
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_get_config_requirements(self):
|
||||
"""Cover line 280: get_config_requirements."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
assert tool.get_config_requirements() == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 77, 135, 186, 200, 221, 240, 254, 298
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolAdditionalCoverage:
|
||||
|
||||
def test_get_directory_structure_returns_cached(self):
|
||||
"""Cover line 77: source_doc not found in DB returns None."""
|
||||
tool = InternalSearchTool({"source": {"active_docs": ["nonexistent"]}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {"cached": True}
|
||||
result = tool._get_directory_structure()
|
||||
assert result == {"cached": True}
|
||||
|
||||
def test_execute_search_appends_to_retrieved_docs(self):
|
||||
"""Cover line 135: doc appended to retrieved_docs."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"title": "Doc1", "text": "content", "source": "src"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
tool._execute_search(query="test")
|
||||
assert len(tool.retrieved_docs) == 1
|
||||
|
||||
def test_format_structure_file_metadata(self):
|
||||
"""Cover line 186: file with metadata (type, token_count)."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"readme.md": {"type": "markdown", "token_count": 100},
|
||||
"subfolder": {"nested_file.py": {}},
|
||||
}
|
||||
result = tool._format_structure(node, "/")
|
||||
assert "readme.md" in result
|
||||
assert "markdown" in result
|
||||
assert "100 tokens" in result
|
||||
|
||||
def test_format_structure_folders_and_files(self):
|
||||
"""Cover line 200: folders and files sections in output."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"src": {"main.py": {}},
|
||||
"README.md": "file",
|
||||
}
|
||||
result = tool._format_structure(node, "/")
|
||||
assert "Folders:" in result
|
||||
assert "Files:" in result
|
||||
|
||||
def test_count_files_recursive(self):
|
||||
"""Cover line 221: _count_files counts nested files."""
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"a.py": "file",
|
||||
"subdir": {
|
||||
"b.py": {"type": "python", "token_count": 50},
|
||||
},
|
||||
}
|
||||
count = tool._count_files(node)
|
||||
assert count == 2
|
||||
|
||||
def test_get_actions_metadata_with_directory_structure(self):
|
||||
"""Cover line 240+: actions include path_filter and list_files."""
|
||||
tool = InternalSearchTool({"source": {}, "has_directory_structure": True})
|
||||
actions = tool.get_actions_metadata()
|
||||
action_names = [a["name"] for a in actions]
|
||||
assert "search" in action_names
|
||||
assert "list_files" in action_names
|
||||
# Check path_filter is in search params
|
||||
search_action = next(a for a in actions if a["name"] == "search")
|
||||
assert "path_filter" in search_action["parameters"]["properties"]
|
||||
|
||||
def test_get_actions_metadata_without_directory_structure(self):
|
||||
"""Cover line 254: actions without directory structure."""
|
||||
tool = InternalSearchTool({"source": {}, "has_directory_structure": False})
|
||||
actions = tool.get_actions_metadata()
|
||||
action_names = [a["name"] for a in actions]
|
||||
assert "search" in action_names
|
||||
assert "list_files" not in action_names
|
||||
|
||||
def test_build_internal_tool_entry_with_directory_structure(self):
|
||||
"""Cover line 298: build_internal_tool_entry with has_directory_structure."""
|
||||
entry = build_internal_tool_entry(has_directory_structure=True)
|
||||
action_names = [a["name"] for a in entry["actions"]]
|
||||
assert "list_files" in action_names
|
||||
search_action = next(a for a in entry["actions"] if a["name"] == "search")
|
||||
assert "path_filter" in search_action["parameters"]["properties"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for internal_search.py
|
||||
# Lines: 101 (unknown action), 108 (empty query), 114-115 (search exception),
|
||||
# 117-118 (no docs), 130-131 (path filter no match),
|
||||
# 154-155 (no dir structure), 165-166 (path not found)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchUnknownAction:
|
||||
"""Cover line 101: unknown action returns error string."""
|
||||
|
||||
def test_unknown_action(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool.execute_action("unknown_action")
|
||||
assert "Unknown action" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchEmptyQuery:
|
||||
"""Cover line 108: empty query returns error."""
|
||||
|
||||
def test_empty_query(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool.execute_action("search", query="")
|
||||
assert "required" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchException:
|
||||
"""Cover lines 114-115: search exception returns error."""
|
||||
|
||||
def test_search_raises(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.search.side_effect = RuntimeError("DB down")
|
||||
tool._get_retriever = MagicMock(return_value=mock_retriever)
|
||||
result = tool.execute_action("search", query="hello")
|
||||
assert "internal error" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchNoDocs:
|
||||
"""Cover lines 117-118: no docs found."""
|
||||
|
||||
def test_no_docs(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.search.return_value = []
|
||||
tool._get_retriever = MagicMock(return_value=mock_retriever)
|
||||
result = tool.execute_action("search", query="hello")
|
||||
assert "No documents found" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchPathFilterNoMatch:
|
||||
"""Cover lines 130-131: path filter with no matching docs."""
|
||||
|
||||
def test_path_filter_no_match(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"source": "other.txt", "text": "data", "title": "Other"}
|
||||
]
|
||||
tool._get_retriever = MagicMock(return_value=mock_retriever)
|
||||
result = tool.execute_action(
|
||||
"search", query="hello", path_filter="nonexistent"
|
||||
)
|
||||
assert "No documents found" in result
|
||||
assert "nonexistent" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchListFilesNoDirStructure:
|
||||
"""Cover lines 154-155: no directory structure."""
|
||||
|
||||
def test_no_dir_structure(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._get_directory_structure = MagicMock(return_value=None)
|
||||
result = tool.execute_action("list_files")
|
||||
assert "No file structure" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchListFilesPathNotFound:
|
||||
"""Cover lines 165-166: path not found."""
|
||||
|
||||
def test_path_not_found(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._get_directory_structure = MagicMock(
|
||||
return_value={"folder": {"file.txt": {}}}
|
||||
)
|
||||
result = tool.execute_action("list_files", path="missing_dir")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
519
tests/agents/test_mcp_tool.py
Normal file
519
tests/agents/test_mcp_tool.py
Normal file
@@ -0,0 +1,519 @@
|
||||
"""Tests for application/agents/tools/mcp_tool.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# mcp_tool has a circular import at module level (mcp_tool -> tasks -> user -> mcp.py -> mcp_tool).
|
||||
# We must patch the dependencies BEFORE the module is first imported.
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_mcp_globals(monkeypatch):
|
||||
"""Patch module-level MongoDB and cache to avoid real connections."""
|
||||
import sys
|
||||
|
||||
# If the module is already loaded, just patch attributes directly
|
||||
if "application.agents.tools.mcp_tool" in sys.modules:
|
||||
mcp_mod = sys.modules["application.agents.tools.mcp_tool"]
|
||||
else:
|
||||
# Break the circular import by pre-populating the tasks import
|
||||
# with a mock before mcp_tool tries to import it
|
||||
mock_tasks = MagicMock()
|
||||
monkeypatch.setitem(sys.modules, "application.api.user.tasks", mock_tasks)
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
mock_mongo = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=MagicMock())
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_config():
|
||||
return {
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"transport_type": "http",
|
||||
"auth_type": "none",
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_config():
|
||||
return {
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"transport_type": "http",
|
||||
"auth_type": "bearer",
|
||||
"auth_credentials": {"bearer_token": "tok_123"},
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMCPToolInit:
|
||||
def test_basic_init(self, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(mcp_config)
|
||||
|
||||
assert tool.server_url == "https://mcp.example.com/api"
|
||||
assert tool.transport_type == "http"
|
||||
assert tool.auth_type == "none"
|
||||
assert tool.timeout == 10
|
||||
|
||||
def test_bearer_auth_credentials(self, bearer_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(bearer_config)
|
||||
assert tool.auth_credentials["bearer_token"] == "tok_123"
|
||||
|
||||
def test_no_server_url_skips_setup(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client") as mock_setup:
|
||||
MCPTool({"server_url": "", "auth_type": "none"})
|
||||
mock_setup.assert_not_called()
|
||||
|
||||
def test_oauth_skips_setup(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client") as mock_setup:
|
||||
MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "oauth",
|
||||
}
|
||||
)
|
||||
mock_setup.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateCacheKey:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_none_auth(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
assert "none" in tool._cache_key
|
||||
assert "mcp.example.com" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_bearer_auth(self, mock_setup, bearer_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(bearer_config)
|
||||
assert "bearer:" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_api_key_auth(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "api_key",
|
||||
"auth_credentials": {"api_key": "sk-test12345678"},
|
||||
}
|
||||
)
|
||||
assert "apikey:" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_basic_auth(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "basic",
|
||||
"auth_credentials": {"username": "user1", "password": "pass"},
|
||||
}
|
||||
)
|
||||
assert "basic:user1" in tool._cache_key
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCreateTransport:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_http_transport(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_sse_transport(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"transport_type": "sse",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_auto_detects_sse(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"transport_type": "auto",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_stdio_transport_disabled(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "stdio",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
with pytest.raises(ValueError, match="STDIO transport is disabled"):
|
||||
tool._create_transport()
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_api_key_header_injected(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "http",
|
||||
"auth_type": "api_key",
|
||||
"auth_credentials": {
|
||||
"api_key": "sk-test",
|
||||
"api_key_header": "X-Custom-Key",
|
||||
},
|
||||
}
|
||||
)
|
||||
# _create_transport will be called; verify it doesn't raise
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_basic_auth_header_injected(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "http",
|
||||
"auth_type": "basic",
|
||||
"auth_credentials": {"username": "user", "password": "pass"},
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFormatTools:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_list_of_dicts(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
result = tool._format_tools([{"name": "tool1", "description": "desc"}])
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "tool1"
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_tools_with_name_attribute(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "my_tool"
|
||||
mock_tool.description = "A tool"
|
||||
mock_tool.inputSchema = {"type": "object", "properties": {}}
|
||||
|
||||
result = tool._format_tools([mock_tool])
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "my_tool"
|
||||
assert result[0]["inputSchema"] == {"type": "object", "properties": {}}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_tools_response_object(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
resp = MagicMock()
|
||||
resp.tools = [{"name": "t1", "description": "d1"}]
|
||||
|
||||
result = tool._format_tools(resp)
|
||||
assert len(result) == 1
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_empty(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
assert tool._format_tools([]) == []
|
||||
assert tool._format_tools("unexpected") == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFormatResult:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_result_with_content(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
mock_result = MagicMock()
|
||||
text_item = MagicMock()
|
||||
text_item.text = "Hello"
|
||||
del text_item.data
|
||||
mock_result.content = [text_item]
|
||||
mock_result.isError = False
|
||||
|
||||
result = tool._format_result(mock_result)
|
||||
assert result["content"][0]["type"] == "text"
|
||||
assert result["content"][0]["text"] == "Hello"
|
||||
assert result["isError"] is False
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_raw_result(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
raw = {"key": "value"}
|
||||
assert tool._format_result(raw) == raw
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteAction:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_no_server_raises(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool({"server_url": "", "auth_type": "none"})
|
||||
with pytest.raises(Exception, match="No MCP server configured"):
|
||||
tool.execute_action("test_action")
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
|
||||
def test_successful_execute(self, mock_run, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool._client = MagicMock()
|
||||
mock_run.return_value = {"key": "value"}
|
||||
|
||||
result = tool.execute_action("test_action", param1="val1")
|
||||
|
||||
mock_run.assert_called_once_with("call_tool", "test_action", param1="val1")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
|
||||
def test_empty_kwargs_cleaned(self, mock_run, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool._client = MagicMock()
|
||||
mock_run.return_value = {}
|
||||
|
||||
tool.execute_action("test", param1="", param2=None, param3="real")
|
||||
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
assert "param1" not in call_kwargs
|
||||
assert "param2" not in call_kwargs
|
||||
assert call_kwargs["param3"] == "real"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTestConnection:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_no_server_url(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool({"server_url": "", "auth_type": "none"})
|
||||
result = tool.test_connection()
|
||||
assert result["success"] is False
|
||||
assert "No server URL" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_invalid_scheme(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{"server_url": "ftp://bad.com", "auth_type": "none"}
|
||||
)
|
||||
result = tool.test_connection()
|
||||
assert result["success"] is False
|
||||
assert "Invalid URL scheme" in result["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMapError:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_timeout_error(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
import concurrent.futures
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", concurrent.futures.TimeoutError())
|
||||
assert "Timed out" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_connection_refused(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", ConnectionRefusedError())
|
||||
assert "Connection refused" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_403_forbidden(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", Exception("403 Forbidden"))
|
||||
assert "Access denied" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_ssl_error(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", Exception("SSL certificate verify failed"))
|
||||
assert "SSL" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_unknown_error_passthrough(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
original = RuntimeError("something weird")
|
||||
err = tool._map_error("test", original)
|
||||
assert err is original
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetActionsMetadata:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_empty_tools(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = []
|
||||
assert tool.get_actions_metadata() == []
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_tools_with_input_schema(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search things",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
]
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "search"
|
||||
assert "query" in meta[0]["parameters"]["properties"]
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_tools_without_schema(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = [{"name": "ping", "description": "Ping"}]
|
||||
meta = tool.get_actions_metadata()
|
||||
assert meta[0]["parameters"]["properties"] == {}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_config_requirements(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "server_url" in reqs
|
||||
assert "auth_type" in reqs
|
||||
assert reqs["server_url"]["required"] is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMCPOAuthManager:
|
||||
def test_handle_callback_success(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
|
||||
result = manager.handle_oauth_callback(state="abc123", code="auth_code")
|
||||
|
||||
assert result is True
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_handle_callback_no_redis(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(None)
|
||||
result = manager.handle_oauth_callback(state="abc", code="code")
|
||||
assert result is False
|
||||
|
||||
def test_handle_callback_error(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
|
||||
result = manager.handle_oauth_callback(
|
||||
state="abc", code="", error="access_denied"
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_get_oauth_status_no_task(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("")
|
||||
assert result["status"] == "not_started"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDBTokenStorage:
|
||||
def test_get_base_url(self):
|
||||
from application.agents.tools.mcp_tool import DBTokenStorage
|
||||
|
||||
assert (
|
||||
DBTokenStorage.get_base_url("https://mcp.example.com/api/v1")
|
||||
== "https://mcp.example.com"
|
||||
)
|
||||
|
||||
def test_get_db_key(self):
|
||||
from application.agents.tools.mcp_tool import DBTokenStorage
|
||||
|
||||
mock_db = MagicMock()
|
||||
storage = DBTokenStorage(
|
||||
server_url="https://mcp.example.com/api",
|
||||
user_id="user1",
|
||||
db_client=mock_db,
|
||||
)
|
||||
key = storage.get_db_key()
|
||||
assert key["server_url"] == "https://mcp.example.com"
|
||||
assert key["user_id"] == "user1"
|
||||
140
tests/agents/test_node_agent.py
Normal file
140
tests/agents/test_node_agent.py
Normal file
@@ -0,0 +1,140 @@
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolFilterMixin:
|
||||
|
||||
def test_get_user_tools_filters_by_allowed_ids(self):
|
||||
from application.agents.workflows.node_agent import ToolFilterMixin
|
||||
|
||||
class FakeBase:
|
||||
def _get_user_tools(self, user="local"):
|
||||
return {
|
||||
"t1": {"_id": "id1", "name": "tool1"},
|
||||
"t2": {"_id": "id2", "name": "tool2"},
|
||||
"t3": {"_id": "id3", "name": "tool3"},
|
||||
}
|
||||
|
||||
class TestClass(ToolFilterMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestClass()
|
||||
obj._allowed_tool_ids = ["id1", "id3"]
|
||||
result = obj._get_user_tools("user1")
|
||||
assert "t1" in result
|
||||
assert "t3" in result
|
||||
assert "t2" not in result
|
||||
|
||||
def test_get_user_tools_returns_empty_when_no_allowed(self):
|
||||
from application.agents.workflows.node_agent import ToolFilterMixin
|
||||
|
||||
class FakeBase:
|
||||
def _get_user_tools(self, user="local"):
|
||||
return {"t1": {"_id": "id1"}}
|
||||
|
||||
class TestClass(ToolFilterMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestClass()
|
||||
obj._allowed_tool_ids = []
|
||||
result = obj._get_user_tools()
|
||||
assert result == {}
|
||||
|
||||
def test_get_tools_filters_by_allowed_ids(self):
|
||||
from application.agents.workflows.node_agent import ToolFilterMixin
|
||||
|
||||
class FakeBase:
|
||||
def _get_tools(self, api_key=None):
|
||||
return {
|
||||
"t1": {"_id": "id1"},
|
||||
"t2": {"_id": "id2"},
|
||||
}
|
||||
|
||||
class TestClass(ToolFilterMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestClass()
|
||||
obj._allowed_tool_ids = ["id2"]
|
||||
result = obj._get_tools("key")
|
||||
assert "t2" in result
|
||||
assert "t1" not in result
|
||||
|
||||
def test_get_tools_returns_empty_when_no_allowed(self):
|
||||
from application.agents.workflows.node_agent import ToolFilterMixin
|
||||
|
||||
class FakeBase:
|
||||
def _get_tools(self, api_key=None):
|
||||
return {"t1": {"_id": "id1"}}
|
||||
|
||||
class TestClass(ToolFilterMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestClass()
|
||||
obj._allowed_tool_ids = []
|
||||
result = obj._get_tools()
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowNodeAgentFactory:
|
||||
|
||||
def test_raises_on_unsupported_type(self):
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported agent type"):
|
||||
WorkflowNodeAgentFactory.create(
|
||||
agent_type="nonexistent",
|
||||
endpoint="http://example.com",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="key",
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 52-59: _WorkflowNodeMixin.__init__)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowNodeMixinInit:
|
||||
|
||||
def test_mixin_init_sets_allowed_tool_ids(self):
|
||||
"""Cover lines 52-59: _WorkflowNodeMixin.__init__ stores tool_ids."""
|
||||
from application.agents.workflows.node_agent import _WorkflowNodeMixin
|
||||
|
||||
class FakeBase:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
class TestMixin(_WorkflowNodeMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestMixin(
|
||||
endpoint="http://example.com",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="key",
|
||||
tool_ids=["tool1", "tool2"],
|
||||
)
|
||||
assert obj._allowed_tool_ids == ["tool1", "tool2"]
|
||||
|
||||
def test_mixin_init_defaults_empty_tool_ids(self):
|
||||
"""Cover: _WorkflowNodeMixin defaults to empty list."""
|
||||
from application.agents.workflows.node_agent import _WorkflowNodeMixin
|
||||
|
||||
class FakeBase:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
class TestMixin(_WorkflowNodeMixin, FakeBase):
|
||||
pass
|
||||
|
||||
obj = TestMixin(
|
||||
endpoint="http://example.com",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="key",
|
||||
)
|
||||
assert obj._allowed_tool_ids == []
|
||||
146
tests/agents/test_ntfy_tool.py
Normal file
146
tests/agents/test_ntfy_tool.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for application/agents/tools/ntfy.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.ntfy import NtfyTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return NtfyTool(config={"token": "test_token"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_no_token():
|
||||
return NtfyTool(config={})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNtfyExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("bad_action")
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_send_message_basic(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hello",
|
||||
topic="test",
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Message sent"
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[0][0] == "https://ntfy.sh/test"
|
||||
assert call_args[1]["data"] == b"Hello"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_send_with_title_and_priority(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Alert",
|
||||
topic="urgent",
|
||||
title="Warning",
|
||||
priority=5,
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert headers["X-Title"] == "Warning"
|
||||
assert headers["X-Priority"] == "5"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_auth_header_with_token(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Basic test_token"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_no_auth_without_token(self, mock_post, tool_no_token):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool_no_token.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert "Authorization" not in headers
|
||||
|
||||
def test_invalid_priority_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="between 1 and 5"):
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
priority=10,
|
||||
)
|
||||
|
||||
def test_non_numeric_priority_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="convertible to an integer"):
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
priority="abc",
|
||||
)
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_trailing_slash_stripped(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh/",
|
||||
message="Hi",
|
||||
topic="test",
|
||||
)
|
||||
|
||||
assert mock_post.call_args[0][0] == "https://ntfy.sh/test"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNtfyMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "ntfy_send_message"
|
||||
assert "server_url" in meta[0]["parameters"]["properties"]
|
||||
assert "message" in meta[0]["parameters"]["properties"]
|
||||
assert "topic" in meta[0]["parameters"]["properties"]
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
146
tests/agents/test_postgres_tool.py
Normal file
146
tests/agents/test_postgres_tool.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for application/agents/tools/postgres.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.postgres import PostgresTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return PostgresTool(config={"token": "postgresql://user:pass@localhost/testdb"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPostgresExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_select_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.description = [("id",), ("name",)]
|
||||
mock_cur.fetchall.return_value = [(1, "Alice"), (2, "Bob")]
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT id, name FROM users"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["response_data"]["column_names"] == ["id", "name"]
|
||||
assert len(result["response_data"]["data"]) == 2
|
||||
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_insert_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.rowcount = 1
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql",
|
||||
sql_query="INSERT INTO users (name) VALUES ('Alice')",
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "1 rows affected" in result["response_data"]["message"]
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("connection refused")
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_get_schema(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = [
|
||||
("users", "id", "integer", "nextval(...)", "NO"),
|
||||
("users", "name", "varchar", None, "YES"),
|
||||
("posts", "id", "integer", "nextval(...)", "NO"),
|
||||
]
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "users" in result["schema"]
|
||||
assert "posts" in result["schema"]
|
||||
assert len(result["schema"]["users"]) == 2
|
||||
assert result["schema"]["users"][0]["column_name"] == "id"
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_get_schema_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("auth failed")
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_connection_closed_on_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
tool.execute_action("postgres_execute_sql", sql_query="BAD SQL")
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_select_with_no_description(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.description = None
|
||||
mock_cur.fetchall.return_value = []
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1 WHERE false"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["response_data"]["column_names"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPostgresMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 2
|
||||
names = {a["name"] for a in meta}
|
||||
assert "postgres_execute_sql" in names
|
||||
assert "postgres_get_schema" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
86
tests/agents/test_read_webpage_tool.py
Normal file
86
tests/agents/test_read_webpage_tool.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for application/agents/tools/read_webpage.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from application.agents.tools.read_webpage import ReadWebpageTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return ReadWebpageTool()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReadWebpageExecuteAction:
|
||||
def test_unknown_action(self, tool):
|
||||
result = tool.execute_action("unknown_action")
|
||||
assert "Error" in result
|
||||
assert "Unknown action" in result
|
||||
|
||||
def test_missing_url(self, tool):
|
||||
result = tool.execute_action("read_webpage")
|
||||
assert "Error" in result
|
||||
assert "URL parameter is missing" in result
|
||||
|
||||
@patch("application.agents.tools.read_webpage.validate_url")
|
||||
@patch("application.agents.tools.read_webpage.requests.get")
|
||||
def test_successful_fetch(self, mock_get, mock_validate, tool):
|
||||
mock_validate.return_value = "https://example.com"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = "<html><body><h1>Title</h1><p>Content</p></body></html>"
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("read_webpage", url="https://example.com")
|
||||
|
||||
assert "Title" in result
|
||||
assert "Content" in result
|
||||
|
||||
@patch("application.agents.tools.read_webpage.validate_url")
|
||||
@patch("application.agents.tools.read_webpage.requests.get")
|
||||
def test_request_error(self, mock_get, mock_validate, tool):
|
||||
mock_validate.return_value = "https://example.com"
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
|
||||
result = tool.execute_action("read_webpage", url="https://example.com")
|
||||
|
||||
assert "Error fetching URL" in result
|
||||
|
||||
@patch("application.agents.tools.read_webpage.validate_url")
|
||||
def test_ssrf_blocked(self, mock_validate, tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
|
||||
result = tool.execute_action("read_webpage", url="http://169.254.169.254/")
|
||||
|
||||
assert "Error" in result
|
||||
assert "validation failed" in result
|
||||
|
||||
@patch("application.agents.tools.read_webpage.validate_url")
|
||||
@patch("application.agents.tools.read_webpage.requests.get")
|
||||
def test_http_error(self, mock_get, mock_validate, tool):
|
||||
mock_validate.return_value = "https://example.com/404"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError("404")
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("read_webpage", url="https://example.com/404")
|
||||
|
||||
assert "Error fetching URL" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestReadWebpageMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "read_webpage"
|
||||
assert "url" in meta[0]["parameters"]["properties"]
|
||||
assert "url" in meta[0]["parameters"]["required"]
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
File diff suppressed because it is too large
Load Diff
692
tests/agents/test_spec_parser.py
Normal file
692
tests/agents/test_spec_parser.py
Normal file
@@ -0,0 +1,692 @@
|
||||
"""Tests for application/agents/tools/spec_parser.py"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.spec_parser import (
|
||||
_extract_metadata,
|
||||
_generate_action_name,
|
||||
_get_base_url,
|
||||
_load_spec,
|
||||
_param_to_property,
|
||||
_resolve_ref,
|
||||
_validate_spec,
|
||||
parse_spec,
|
||||
)
|
||||
|
||||
|
||||
MINIMAL_OPENAPI = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test API", "version": "1.0.0"},
|
||||
"servers": [{"url": "https://api.example.com"}],
|
||||
"paths": {
|
||||
"/users": {
|
||||
"get": {
|
||||
"operationId": "listUsers",
|
||||
"summary": "List all users",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"schema": {"type": "integer"},
|
||||
}
|
||||
],
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
MINIMAL_SWAGGER = json.dumps(
|
||||
{
|
||||
"swagger": "2.0",
|
||||
"info": {"title": "Swagger API", "version": "2.0.0"},
|
||||
"host": "api.example.com",
|
||||
"basePath": "/v1",
|
||||
"schemes": ["https"],
|
||||
"paths": {
|
||||
"/pets": {
|
||||
"get": {
|
||||
"operationId": "listPets",
|
||||
"summary": "List pets",
|
||||
"parameters": [],
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLoadSpec:
|
||||
def test_load_json(self):
|
||||
spec = _load_spec('{"openapi": "3.0.0"}')
|
||||
assert spec["openapi"] == "3.0.0"
|
||||
|
||||
def test_load_yaml(self):
|
||||
spec = _load_spec("openapi: '3.0.0'\ninfo:\n title: Test")
|
||||
assert spec["openapi"] == "3.0.0"
|
||||
|
||||
def test_empty_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty"):
|
||||
_load_spec("")
|
||||
|
||||
def test_whitespace_only_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty"):
|
||||
_load_spec(" \n ")
|
||||
|
||||
def test_invalid_json_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid"):
|
||||
_load_spec("{invalid json}")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateSpec:
|
||||
def test_valid_openapi(self):
|
||||
spec = json.loads(MINIMAL_OPENAPI)
|
||||
_validate_spec(spec) # should not raise
|
||||
|
||||
def test_valid_swagger(self):
|
||||
spec = json.loads(MINIMAL_SWAGGER)
|
||||
_validate_spec(spec)
|
||||
|
||||
def test_missing_version_raises(self):
|
||||
with pytest.raises(ValueError, match="Unsupported"):
|
||||
_validate_spec({"paths": {"/a": {}}})
|
||||
|
||||
def test_no_paths_raises(self):
|
||||
with pytest.raises(ValueError, match="No API paths"):
|
||||
_validate_spec({"openapi": "3.0.0"})
|
||||
|
||||
def test_empty_paths_raises(self):
|
||||
with pytest.raises(ValueError, match="No API paths"):
|
||||
_validate_spec({"openapi": "3.0.0", "paths": {}})
|
||||
|
||||
def test_non_dict_raises(self):
|
||||
with pytest.raises(ValueError, match="valid object"):
|
||||
_validate_spec("not a dict")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractMetadata:
|
||||
def test_openapi_metadata(self):
|
||||
spec = json.loads(MINIMAL_OPENAPI)
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert meta["title"] == "Test API"
|
||||
assert meta["version"] == "1.0.0"
|
||||
assert meta["base_url"] == "https://api.example.com"
|
||||
|
||||
def test_swagger_metadata(self):
|
||||
spec = json.loads(MINIMAL_SWAGGER)
|
||||
meta = _extract_metadata(spec, is_swagger=True)
|
||||
assert meta["title"] == "Swagger API"
|
||||
assert meta["base_url"] == "https://api.example.com/v1"
|
||||
|
||||
def test_missing_info(self):
|
||||
spec = {"openapi": "3.0.0", "paths": {"/a": {}}}
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert meta["title"] == "Untitled API"
|
||||
|
||||
def test_description_truncated(self):
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "description": "x" * 1000},
|
||||
"paths": {"/a": {}},
|
||||
}
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert len(meta["description"]) <= 500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetBaseUrl:
|
||||
def test_openapi_servers(self):
|
||||
spec = {"servers": [{"url": "https://api.example.com/v2/"}]}
|
||||
assert _get_base_url(spec, is_swagger=False) == "https://api.example.com/v2"
|
||||
|
||||
def test_openapi_no_servers(self):
|
||||
assert _get_base_url({}, is_swagger=False) == ""
|
||||
|
||||
def test_swagger_with_host(self):
|
||||
spec = {"host": "api.test.com", "basePath": "/v1", "schemes": ["https"]}
|
||||
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com/v1"
|
||||
|
||||
def test_swagger_no_host(self):
|
||||
assert _get_base_url({}, is_swagger=True) == ""
|
||||
|
||||
def test_swagger_default_scheme(self):
|
||||
spec = {"host": "api.test.com", "basePath": ""}
|
||||
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateActionName:
|
||||
def test_uses_operation_id(self):
|
||||
assert _generate_action_name({"operationId": "getUser"}, "get", "/users") == "getUser"
|
||||
|
||||
def test_fallback_to_method_path(self):
|
||||
name = _generate_action_name({}, "get", "/users/{id}")
|
||||
assert name.startswith("get_")
|
||||
assert "users" in name
|
||||
|
||||
def test_truncates_long_names(self):
|
||||
name = _generate_action_name({}, "get", "/" + "a" * 200)
|
||||
assert len(name) <= 64
|
||||
|
||||
def test_sanitizes_special_chars(self):
|
||||
name = _generate_action_name({"operationId": "get.user@v2"}, "get", "/")
|
||||
assert "." not in name
|
||||
assert "@" not in name
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParamToProperty:
|
||||
def test_string_param(self):
|
||||
param = {"name": "q", "in": "query", "schema": {"type": "string"}}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "string"
|
||||
assert prop["required"] is False
|
||||
|
||||
def test_integer_param(self):
|
||||
param = {
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": True,
|
||||
"schema": {"type": "integer"},
|
||||
}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
assert prop["required"] is True
|
||||
assert prop["filled_by_llm"] is True
|
||||
|
||||
def test_number_maps_to_integer(self):
|
||||
param = {"name": "score", "in": "query", "schema": {"type": "number"}}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveRef:
|
||||
def test_no_ref(self):
|
||||
obj = {"type": "string"}
|
||||
assert _resolve_ref(obj, {}, {}) == obj
|
||||
|
||||
def test_components_ref(self):
|
||||
components = {"schemas": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}}
|
||||
obj = {"$ref": "#/components/schemas/User"}
|
||||
result = _resolve_ref(obj, components, {})
|
||||
assert result["type"] == "object"
|
||||
|
||||
def test_definitions_ref(self):
|
||||
definitions = {"Pet": {"type": "object"}}
|
||||
obj = {"$ref": "#/definitions/Pet"}
|
||||
result = _resolve_ref(obj, {}, definitions)
|
||||
assert result["type"] == "object"
|
||||
|
||||
def test_unsupported_ref(self):
|
||||
obj = {"$ref": "#/external/something"}
|
||||
assert _resolve_ref(obj, {}, {}) is None
|
||||
|
||||
def test_non_dict_returns_none(self):
|
||||
assert _resolve_ref("string", {}, {}) is None
|
||||
assert _resolve_ref(42, {}, {}) is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseSpec:
|
||||
def test_openapi_full_parse(self):
|
||||
metadata, actions = parse_spec(MINIMAL_OPENAPI)
|
||||
assert metadata["title"] == "Test API"
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "listUsers"
|
||||
assert actions[0]["method"] == "GET"
|
||||
assert actions[0]["url"] == "https://api.example.com/users"
|
||||
assert "limit" in actions[0]["query_params"]["properties"]
|
||||
|
||||
def test_swagger_full_parse(self):
|
||||
metadata, actions = parse_spec(MINIMAL_SWAGGER)
|
||||
assert metadata["title"] == "Swagger API"
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "listPets"
|
||||
assert actions[0]["method"] == "GET"
|
||||
|
||||
def test_multiple_methods(self):
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {"operationId": "listItems", "responses": {}},
|
||||
"post": {
|
||||
"operationId": "createItem",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"}
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
metadata, actions = parse_spec(spec)
|
||||
assert len(actions) == 2
|
||||
names = {a["name"] for a in actions}
|
||||
assert "listItems" in names
|
||||
assert "createItem" in names
|
||||
|
||||
create = next(a for a in actions if a["name"] == "createItem")
|
||||
assert "name" in create["body"]["properties"]
|
||||
|
||||
def test_header_params(self):
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/data": {
|
||||
"get": {
|
||||
"operationId": "getData",
|
||||
"parameters": [
|
||||
{"name": "X-API-Key", "in": "header", "schema": {"type": "string"}}
|
||||
],
|
||||
"responses": {},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert "X-API-Key" in actions[0]["headers"]["properties"]
|
||||
|
||||
def test_invalid_spec_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_spec("")
|
||||
|
||||
def test_yaml_spec(self):
|
||||
yaml_spec = """
|
||||
openapi: "3.0.0"
|
||||
info:
|
||||
title: YAML API
|
||||
version: "1.0"
|
||||
paths:
|
||||
/health:
|
||||
get:
|
||||
operationId: healthCheck
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
"""
|
||||
metadata, actions = parse_spec(yaml_spec)
|
||||
assert metadata["title"] == "YAML API"
|
||||
assert actions[0]["name"] == "healthCheck"
|
||||
|
||||
def test_non_dict_path_item_skipped(self):
|
||||
"""Cover line 117: non-dict path item is skipped."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/valid": {
|
||||
"get": {
|
||||
"operationId": "validOp",
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
},
|
||||
"/invalid": "not_a_dict",
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "validOp"
|
||||
|
||||
def test_non_dict_operation_skipped(self):
|
||||
"""Cover line 122: non-dict operation for a method is skipped."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": "not_a_dict",
|
||||
"post": {
|
||||
"operationId": "createItem",
|
||||
"responses": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "createItem"
|
||||
|
||||
def test_operation_parse_failure_logged(self):
|
||||
"""Cover lines 137: exception parsing operation is caught."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {
|
||||
"operationId": "getItems",
|
||||
"responses": {},
|
||||
},
|
||||
"post": {
|
||||
"operationId": "createItem",
|
||||
"requestBody": {
|
||||
"$ref": "#/components/schemas/Missing"
|
||||
},
|
||||
"responses": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
# At least the GET should succeed
|
||||
assert any(a["name"] == "getItems" for a in actions)
|
||||
|
||||
def test_path_level_params_merged(self):
|
||||
"""Cover lines 129-130, 148, 159: path-level parameters merged."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/items/{id}": {
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
}
|
||||
],
|
||||
"get": {
|
||||
"operationId": "getItem",
|
||||
"responses": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert "id" in actions[0]["query_params"]["properties"]
|
||||
|
||||
def test_swagger_body_param_extraction(self):
|
||||
"""Cover lines 145, 148, 152-153: Swagger 2.0 body parameter extraction."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"swagger": "2.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"host": "api.test.com",
|
||||
"basePath": "/v1",
|
||||
"schemes": ["https"],
|
||||
"paths": {
|
||||
"/items": {
|
||||
"post": {
|
||||
"operationId": "createItem",
|
||||
"consumes": ["application/json"],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "body",
|
||||
"in": "body",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Item name",
|
||||
}
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
}
|
||||
],
|
||||
"responses": {"201": {"description": "Created"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert len(actions) == 1
|
||||
assert "name" in actions[0]["body"]["properties"]
|
||||
assert actions[0]["body_content_type"] == "application/json"
|
||||
|
||||
def test_traverse_path_key_error(self):
|
||||
"""Cover lines 173-176: _traverse_path returns None on KeyError."""
|
||||
from application.agents.tools.spec_parser import _traverse_path
|
||||
|
||||
result = _traverse_path({"a": {"b": 1}}, ["a", "c"])
|
||||
assert result is None
|
||||
|
||||
def test_traverse_path_non_dict_result(self):
|
||||
"""Cover line 175-176: _traverse_path returns None for non-dict result."""
|
||||
from application.agents.tools.spec_parser import _traverse_path
|
||||
|
||||
result = _traverse_path({"a": "string_value"}, ["a"])
|
||||
assert result is None
|
||||
|
||||
def test_openapi_request_body_form_urlencoded(self):
|
||||
"""Cover lines 152-153: OpenAPI 3.x request body with form-urlencoded."""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/login": {
|
||||
"post": {
|
||||
"operationId": "login",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/x-www-form-urlencoded": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"username": {"type": "string"},
|
||||
"password": {"type": "string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert actions[0]["body_content_type"] == "application/x-www-form-urlencoded"
|
||||
assert "username" in actions[0]["body"]["properties"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 205, 209, 213, 216-217, 222, 228
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSpecParserAdditionalCoverage:
|
||||
|
||||
def test_categorize_params_query_and_header(self):
|
||||
"""Cover lines 205, 209, 213: parameters categorized into query and header."""
|
||||
from application.agents.tools.spec_parser import _categorize_parameters
|
||||
|
||||
parameters = [
|
||||
{"name": "q", "in": "query", "required": True, "description": "Query param"},
|
||||
{"name": "X-Auth", "in": "header", "required": False, "description": "Auth header"},
|
||||
{"name": "id", "in": "path", "required": True, "description": "Path param"},
|
||||
]
|
||||
query_params, headers = _categorize_parameters(parameters, {}, {})
|
||||
assert "q" in query_params
|
||||
assert "X-Auth" in headers
|
||||
assert "id" in query_params # path params go to query_params
|
||||
|
||||
def test_categorize_params_skips_no_name(self):
|
||||
"""Cover line 205: parameters without name are skipped."""
|
||||
from application.agents.tools.spec_parser import _categorize_parameters
|
||||
|
||||
parameters = [
|
||||
{"in": "query"}, # no name
|
||||
]
|
||||
query_params, headers = _categorize_parameters(parameters, {}, {})
|
||||
assert len(query_params) == 0
|
||||
assert len(headers) == 0
|
||||
|
||||
def test_param_to_property_integer_type(self):
|
||||
"""Cover lines 216-217, 222, 228: _param_to_property with integer type."""
|
||||
from application.agents.tools.spec_parser import _param_to_property
|
||||
|
||||
param = {
|
||||
"name": "count",
|
||||
"schema": {"type": "integer"},
|
||||
"description": "Count of items",
|
||||
"required": True,
|
||||
}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
assert prop["required"] is True
|
||||
assert prop["filled_by_llm"] is True
|
||||
|
||||
def test_param_to_property_number_type(self):
|
||||
"""Cover line 222: number type mapped to integer."""
|
||||
from application.agents.tools.spec_parser import _param_to_property
|
||||
|
||||
param = {
|
||||
"schema": {"type": "number"},
|
||||
"description": "A number",
|
||||
"required": False,
|
||||
}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
|
||||
def test_param_to_property_string_default(self):
|
||||
"""Cover line 222: unknown type defaults to string."""
|
||||
from application.agents.tools.spec_parser import _param_to_property
|
||||
|
||||
param = {"description": "Desc", "required": False}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "string"
|
||||
|
||||
def test_param_to_property_description_truncated(self):
|
||||
"""Cover line 228: description truncated to 200 chars."""
|
||||
from application.agents.tools.spec_parser import _param_to_property
|
||||
|
||||
param = {"description": "x" * 300, "required": False}
|
||||
prop = _param_to_property(param)
|
||||
assert len(prop["description"]) == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for spec_parser.py
|
||||
# Lines: 57-59 (YAML error), 116-117 (non-dict path_item), 136-140
|
||||
# (action parse exception), 156 (full_url with no base_url),
|
||||
# 184-190 (generate_action_name from path), 99 (swagger base URL)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from application.agents.tools.spec_parser import _extract_actions # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLoadSpecYAMLError:
|
||||
"""Cover lines 58-59: YAML parse error."""
|
||||
|
||||
def test_invalid_yaml_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||
_load_spec(" \ttabs: [invalid: yaml: {{")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractActionsNonDictPathItem:
|
||||
"""Cover lines 116-117: non-dict path_item is skipped."""
|
||||
|
||||
def test_non_dict_path_skipped(self):
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"paths": {
|
||||
"/valid": {"get": {"operationId": "getValid"}},
|
||||
"/invalid": "not-a-dict",
|
||||
},
|
||||
}
|
||||
actions = _extract_actions(spec, False)
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "getValid"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractActionsParseException:
|
||||
"""Cover lines 136-140: exception in _build_action is caught."""
|
||||
|
||||
def test_bad_operation_skipped(self):
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"operationId": "good",
|
||||
},
|
||||
"post": {
|
||||
"operationId": "bad",
|
||||
"parameters": [{"$ref": "#/invalid/ref"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
# Should not raise, bad operation is skipped
|
||||
actions = _extract_actions(spec, False)
|
||||
assert len(actions) >= 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateActionNameFromPath:
|
||||
"""Cover lines 184-190: operationId missing, generate from path."""
|
||||
|
||||
def test_name_from_path(self):
|
||||
name = _generate_action_name({}, "get", "/users/{id}/posts")
|
||||
assert name.startswith("get_")
|
||||
assert "users" in name
|
||||
assert "{" not in name
|
||||
|
||||
def test_name_truncated(self):
|
||||
long_path = "/a" * 100
|
||||
name = _generate_action_name({}, "post", long_path)
|
||||
assert len(name) <= 64
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetBaseUrlSwagger:
|
||||
"""Cover line 99: swagger base URL with host and scheme."""
|
||||
|
||||
def test_swagger_base_url(self):
|
||||
spec = {
|
||||
"swagger": "2.0",
|
||||
"host": "api.example.com",
|
||||
"basePath": "/v2",
|
||||
"schemes": ["https"],
|
||||
}
|
||||
url = _get_base_url(spec, True)
|
||||
assert url == "https://api.example.com/v2"
|
||||
|
||||
def test_swagger_no_host(self):
|
||||
spec = {"swagger": "2.0"}
|
||||
url = _get_base_url(spec, True)
|
||||
assert url == ""
|
||||
79
tests/agents/test_telegram_tool.py
Normal file
79
tests/agents/test_telegram_tool.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for application/agents/tools/telegram.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.telegram import TelegramTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return TelegramTool(config={"token": "bot123:ABC"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTelegramExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_send_message(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_message", text="Hello", chat_id="12345"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Message sent"
|
||||
call_args = mock_post.call_args
|
||||
assert "bot123:ABC/sendMessage" in call_args[0][0]
|
||||
assert call_args[1]["data"]["text"] == "Hello"
|
||||
assert call_args[1]["data"]["chat_id"] == "12345"
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_send_image(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_image", image_url="https://img.com/cat.jpg", chat_id="12345"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Image sent"
|
||||
call_args = mock_post.call_args
|
||||
assert "bot123:ABC/sendPhoto" in call_args[0][0]
|
||||
assert call_args[1]["data"]["photo"] == "https://img.com/cat.jpg"
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_api_error_status(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 403
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_message", text="Hi", chat_id="999"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 403
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTelegramMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 2
|
||||
names = {a["name"] for a in meta}
|
||||
assert "telegram_send_message" in names
|
||||
assert "telegram_send_image" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
@@ -277,3 +277,279 @@ class TestToolExecutorExecute:
|
||||
|
||||
# load_tool called only once due to cache
|
||||
assert mock_tool_manager.load_tool.call_count == 1
|
||||
|
||||
def test_execute_api_tool(self, mock_tool_manager, monkeypatch):
|
||||
"""Cover lines 199-202, 256-267: api_tool execution path."""
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
parse_args=Mock(return_value=("t1", "get_users", {"body_param": "val"}))
|
||||
),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "api_tool",
|
||||
"config": {
|
||||
"actions": {
|
||||
"get_users": {
|
||||
"name": "get_users",
|
||||
"description": "Get users",
|
||||
"url": "https://api.example.com/users",
|
||||
"method": "GET",
|
||||
"query_params": {"properties": {}},
|
||||
"headers": {"properties": {}},
|
||||
"body": {"properties": {}},
|
||||
"active": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
call = self._make_call(name="get_users_t1", call_id="c2")
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
|
||||
events = []
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
result = e.value
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
statuses = [e["data"]["status"] for e in events]
|
||||
assert "pending" in statuses
|
||||
|
||||
def test_execute_with_prefilled_param_values(self, mock_tool_manager, monkeypatch):
|
||||
"""Cover line 179: params not in call_args use default value."""
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
parse_args=Mock(return_value=("t1", "act", {}))
|
||||
),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{
|
||||
"name": "act",
|
||||
"description": "Test",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"hidden_param": {
|
||||
"type": "string",
|
||||
"value": "default_val",
|
||||
"filled_by_llm": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
call = self._make_call(name="act_t1")
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
|
||||
while True:
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration as e:
|
||||
result = e.value
|
||||
break
|
||||
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
def test_execute_tool_with_artifact_id(self, mock_tool_manager, monkeypatch):
|
||||
"""Cover lines 217-218: tool with get_artifact_id."""
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
parse_args=Mock(return_value=("t1", "act", {"q": "v"}))
|
||||
),
|
||||
)
|
||||
|
||||
mock_tool = mock_tool_manager.load_tool.return_value
|
||||
mock_tool.get_artifact_id = Mock(return_value="artifact-123")
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{
|
||||
"name": "act",
|
||||
"description": "Test",
|
||||
"parameters": {"properties": {}},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
call = self._make_call(name="act_t1")
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
|
||||
events = []
|
||||
while True:
|
||||
try:
|
||||
events.append(next(gen))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
completed_events = [
|
||||
e for e in events if e["data"].get("status") == "completed"
|
||||
]
|
||||
assert any(
|
||||
"artifact_id" in e.get("data", {}) for e in completed_events
|
||||
)
|
||||
|
||||
def test_get_or_load_tool_encrypted_credentials(self, monkeypatch):
|
||||
"""Cover lines 273-278: encrypted credentials path."""
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
mock_tm = Mock()
|
||||
mock_tool = Mock()
|
||||
mock_tm.load_tool.return_value = mock_tool
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.decrypt_credentials",
|
||||
lambda creds, user: {"api_key": "decrypted_key"},
|
||||
)
|
||||
|
||||
tool_data = {
|
||||
"name": "custom_tool",
|
||||
"config": {"encrypted_credentials": "encrypted_blob"},
|
||||
}
|
||||
|
||||
result = executor._get_or_load_tool(tool_data, "t1", "act")
|
||||
assert result is mock_tool
|
||||
call_kwargs = mock_tm.load_tool.call_args
|
||||
tool_config = call_kwargs[1]["tool_config"] if "tool_config" in call_kwargs[1] else call_kwargs[0][1]
|
||||
assert "api_key" in tool_config.get("auth_credentials", tool_config)
|
||||
|
||||
def test_get_or_load_tool_mcp_tool(self, monkeypatch):
|
||||
"""Cover lines 281-283: mcp_tool path sets query_mode."""
|
||||
executor = ToolExecutor(user="test_user")
|
||||
executor.conversation_id = "conv-123"
|
||||
|
||||
mock_tm = Mock()
|
||||
mock_tool = Mock()
|
||||
mock_tm.load_tool.return_value = mock_tool
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
|
||||
)
|
||||
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"config": {},
|
||||
}
|
||||
|
||||
result = executor._get_or_load_tool(tool_data, "t1", "act")
|
||||
assert result is mock_tool
|
||||
call_kwargs = mock_tm.load_tool.call_args
|
||||
tool_config = call_kwargs[1].get("tool_config", call_kwargs[0][1] if len(call_kwargs[0]) > 1 else {})
|
||||
assert tool_config.get("query_mode") is True
|
||||
assert tool_config.get("conversation_id") == "conv-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 217-218, 256-267
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorAdditionalCoverage:
|
||||
|
||||
def test_get_artifact_id_exception_handled(self, monkeypatch):
|
||||
"""Cover lines 217-218: get_artifact_id raises exception."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
executor = ToolExecutor(user="user1")
|
||||
|
||||
mock_tool = Mock()
|
||||
mock_tool.execute_action.return_value = "result"
|
||||
mock_tool.get_artifact_id.side_effect = RuntimeError("artifact error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolManager",
|
||||
lambda config: Mock(load_tool=Mock(return_value=mock_tool)),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "custom_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{
|
||||
"name": "action1",
|
||||
"active": True,
|
||||
"parameters": {"properties": {}},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
# Create a fake call object matching what ToolActionParser expects
|
||||
call = SimpleNamespace(
|
||||
id="c1",
|
||||
function=SimpleNamespace(
|
||||
name="action1_t1",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
events = list(executor.execute(tools_dict, call, "OpenAILLM"))
|
||||
# Should complete without raising; artifact_id error is logged but not raised
|
||||
assert any(
|
||||
isinstance(e, dict) and e.get("type") == "tool_call"
|
||||
for e in events
|
||||
)
|
||||
|
||||
def test_get_or_load_api_tool_with_body_content_type(self, monkeypatch):
|
||||
"""Cover lines 256-267: api_tool with body_content_type."""
|
||||
executor = ToolExecutor(user="user1")
|
||||
|
||||
mock_tm = Mock()
|
||||
mock_tool = Mock()
|
||||
mock_tm.load_tool.return_value = mock_tool
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
|
||||
)
|
||||
|
||||
tool_data = {
|
||||
"name": "api_tool",
|
||||
"config": {
|
||||
"actions": {
|
||||
"create": {
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "POST",
|
||||
"body_content_type": "application/json",
|
||||
"body_encoding_rules": {"encode_as": "json"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = executor._get_or_load_tool(
|
||||
tool_data, "t1", "create",
|
||||
headers={"Authorization": "Bearer tok"},
|
||||
query_params={"page": "1"},
|
||||
)
|
||||
assert result is mock_tool
|
||||
# Verify config was built with body_content_type
|
||||
call_args = mock_tm.load_tool.call_args
|
||||
tool_config = call_args[1].get("tool_config", call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert tool_config.get("body_content_type") == "application/json"
|
||||
assert tool_config.get("body_encoding_rules") == {"encode_as": "json"}
|
||||
|
||||
433
tests/agents/test_workflow_agent_coverage.py
Normal file
433
tests/agents/test_workflow_agent_coverage.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
|
||||
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
WorkflowGraph,
|
||||
)
|
||||
|
||||
|
||||
def _make_agent(**overrides):
|
||||
"""Create a WorkflowAgent with mocked base class dependencies."""
|
||||
defaults = {
|
||||
"endpoint": "https://api.example.com",
|
||||
"llm_name": "openai",
|
||||
"model_id": "gpt-4",
|
||||
"api_key": "test_key",
|
||||
"user_api_key": None,
|
||||
"prompt": "You are helpful.",
|
||||
"chat_history": [],
|
||||
"decoded_token": {"sub": "user1"},
|
||||
"attachments": [],
|
||||
"json_schema": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
|
||||
with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
agent = WorkflowAgent(**defaults)
|
||||
return agent
|
||||
|
||||
|
||||
class TestWorkflowAgentInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_sets_attributes(self):
|
||||
agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
|
||||
assert agent.workflow_id == "wf1"
|
||||
assert agent.workflow_owner == "owner1"
|
||||
assert agent._engine is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedded_workflow(self):
|
||||
wf_data = {"nodes": [], "edges": [], "name": "Test"}
|
||||
agent = _make_agent(workflow=wf_data)
|
||||
assert agent._workflow_data == wf_data
|
||||
|
||||
|
||||
class TestParseEmbeddedWorkflow:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parses_valid_workflow(self):
|
||||
wf_data = {
|
||||
"name": "Test Workflow",
|
||||
"description": "A test",
|
||||
"nodes": [
|
||||
{"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
|
||||
{"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
|
||||
],
|
||||
}
|
||||
agent = _make_agent(workflow=wf_data, workflow_id="wf1")
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is not None
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
assert graph.workflow.name == "Test Workflow"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_edge_source_id_alias(self):
|
||||
wf_data = {
|
||||
"nodes": [{"id": "n1", "type": "start", "data": {}}],
|
||||
"edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
|
||||
}
|
||||
agent = _make_agent(workflow=wf_data)
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is not None
|
||||
assert graph.edges[0].source_id == "n1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_data_returns_none(self):
|
||||
agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is None
|
||||
|
||||
|
||||
class TestLoadWorkflowGraph:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_embedded_when_available(self):
|
||||
agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
|
||||
agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
|
||||
result = agent._load_workflow_graph()
|
||||
assert result == "parsed_graph"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_database_when_workflow_id(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
agent._load_from_database = MagicMock(return_value="db_graph")
|
||||
result = agent._load_workflow_graph()
|
||||
assert result == "db_graph"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_none_when_nothing(self):
|
||||
agent = _make_agent()
|
||||
result = agent._load_workflow_graph()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadFromDatabase:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_workflow_id_returns_none(self):
|
||||
agent = _make_agent(workflow_id="invalid!")
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_owner_returns_none(self):
|
||||
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
|
||||
agent.workflow_owner = None
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_decoded_token_sub(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
agent.workflow_owner = None
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = None
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is None # workflow_doc not found
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_load(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "Test WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.return_value = [
|
||||
{"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
|
||||
"title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
]
|
||||
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.return_value = []
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
|
||||
assert result is not None
|
||||
assert len(result.nodes) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_graph_version(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": "bad",
|
||||
}
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.return_value = []
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.return_value = []
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is not None # Defaults to version 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_fallback_nodes_without_version(self):
|
||||
"""When graph_version=1 finds no nodes, falls back to nodes without version field."""
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
call_count = [0]
|
||||
def nodes_find(query):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return [] # No versioned nodes
|
||||
return [{"id": "n1", "workflow_id": "wf", "type": "start",
|
||||
"title": "S", "position": {"x": 0, "y": 0}, "config": {}}]
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.side_effect = nodes_find
|
||||
|
||||
edge_call = [0]
|
||||
def edges_find(query):
|
||||
edge_call[0] += 1
|
||||
if edge_call[0] == 1:
|
||||
return []
|
||||
return []
|
||||
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.side_effect = edges_find
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is not None
|
||||
assert len(result.nodes) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_returns_none(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
|
||||
MockMongo.get_client.side_effect = Exception("db error")
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenInner:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_graph_yields_error(self):
|
||||
agent = _make_agent()
|
||||
agent._load_workflow_graph = MagicMock(return_value=None)
|
||||
events = list(agent._gen_inner("query", None))
|
||||
assert any(e.get("type") == "error" for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_execution(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_graph = MagicMock(spec=WorkflowGraph)
|
||||
agent._load_workflow_graph = MagicMock(return_value=mock_graph)
|
||||
agent._save_workflow_run = MagicMock()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute.return_value = iter([{"answer": "result"}])
|
||||
|
||||
with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine):
|
||||
events = list(agent._gen_inner("query", None))
|
||||
assert len(events) == 1
|
||||
agent._save_workflow_run.assert_called_once_with("query")
|
||||
|
||||
|
||||
class TestSaveWorkflowRun:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_engine_returns_early(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = None
|
||||
agent._save_workflow_run("query") # Should not raise
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_saves_to_mongo(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.state = {"query": "test"}
|
||||
mock_engine.execution_log = []
|
||||
mock_engine.get_execution_summary.return_value = []
|
||||
agent._engine = mock_engine
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
agent._save_workflow_run("query")
|
||||
|
||||
mock_collection.insert_one.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_does_not_propagate(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.state = {}
|
||||
mock_engine.execution_log = []
|
||||
mock_engine.get_execution_summary.return_value = []
|
||||
agent._engine = mock_engine
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
|
||||
MockMongo.get_client.side_effect = Exception("db fail")
|
||||
agent._save_workflow_run("query") # Should not raise
|
||||
|
||||
|
||||
class TestDetermineRunStatus:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_engine_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = None
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_log_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = []
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_failed_log_returns_failed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = [
|
||||
{"status": "completed"},
|
||||
{"status": "failed"},
|
||||
]
|
||||
assert agent._determine_run_status() == ExecutionStatus.FAILED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_completed_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = [
|
||||
{"status": "completed"},
|
||||
{"status": "completed"},
|
||||
]
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
|
||||
class TestSerializeState:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_primitives(self):
|
||||
agent = _make_agent()
|
||||
state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
|
||||
result = agent._serialize_state(state)
|
||||
assert result == state
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_nested_dict(self):
|
||||
agent = _make_agent()
|
||||
state = {"nested": {"key": "value"}}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["nested"]["key"] == "value"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_list(self):
|
||||
agent = _make_agent()
|
||||
state = {"items": [1, 2, "three"]}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["items"] == [1, 2, "three"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_tuple(self):
|
||||
agent = _make_agent()
|
||||
state = {"tup": (1, 2)}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["tup"] == [1, 2]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_datetime(self):
|
||||
agent = _make_agent()
|
||||
dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
state = {"time": dt}
|
||||
result = agent._serialize_state(state)
|
||||
assert "2025-01-01" in result["time"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_unknown_to_str(self):
|
||||
agent = _make_agent()
|
||||
state = {"obj": object()}
|
||||
result = agent._serialize_state(state)
|
||||
assert isinstance(result["obj"], str)
|
||||
@@ -330,3 +330,180 @@ def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeyp
|
||||
match="Structured output was expected but response was not valid JSON",
|
||||
):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 204, 213-215, 223, 283-284, 289,
|
||||
# 355, 375
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowEngineAdditionalCoverage:
|
||||
|
||||
def test_agent_node_prompt_template_empty_uses_query(self, monkeypatch):
|
||||
"""Cover line 204: prompt_template is empty, uses state query."""
|
||||
engine = create_engine()
|
||||
engine.state["query"] = "What is the answer?"
|
||||
node = create_agent_node(node_id="n1")
|
||||
node.config["prompt_template"] = ""
|
||||
|
||||
node_events = [{"answer": "42"}]
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
assert engine.state["node_n1_output"] == "42"
|
||||
|
||||
def test_agent_node_model_config_override(self, monkeypatch):
|
||||
"""Cover lines 213-215: node_config with model_id and llm_name."""
|
||||
engine = create_engine()
|
||||
engine.state["query"] = "test"
|
||||
node = create_agent_node(node_id="n2")
|
||||
node.config["model_id"] = "gpt-4o"
|
||||
node.config["llm_name"] = "openai"
|
||||
|
||||
node_events = [{"answer": "result"}]
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _: "key",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
lambda _: "openai",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
assert engine.state["node_n2_output"] == "result"
|
||||
|
||||
def test_agent_node_unsupported_structured_output_raises(self, monkeypatch):
|
||||
"""Cover line 223: model does not support structured output raises."""
|
||||
engine = create_engine()
|
||||
engine.state["query"] = "test"
|
||||
node = create_agent_node(
|
||||
node_id="n3",
|
||||
json_schema={"type": "object", "properties": {"a": {"type": "string"}}},
|
||||
)
|
||||
node.config["model_id"] = "model-no-struct"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _: "key",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
lambda _: "openai",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _: {"supports_structured_output": False},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support structured output"):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
def test_structured_output_with_structured_response(self, monkeypatch):
|
||||
"""Cover lines 283-284: structured response parsed and validated."""
|
||||
engine = create_engine()
|
||||
engine.state["query"] = "test"
|
||||
node = create_agent_node(
|
||||
node_id="n4",
|
||||
output_variable="result",
|
||||
json_schema={"type": "object", "properties": {"key": {"type": "string"}}},
|
||||
)
|
||||
|
||||
node_events = [
|
||||
{"answer": '{"key": "val"}', "structured": True},
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
assert engine.state["result"] == {"key": "val"}
|
||||
|
||||
def test_json_schema_no_structured_flag_parses_response(self, monkeypatch):
|
||||
"""Cover line 289: json_schema set but no structured flag; non-JSON response raises."""
|
||||
engine = create_engine()
|
||||
engine.state["query"] = "test"
|
||||
node = create_agent_node(
|
||||
node_id="n5",
|
||||
json_schema={"type": "object", "properties": {"x": {"type": "string"}}},
|
||||
)
|
||||
|
||||
node_events = [{"answer": "not valid json"}]
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
lambda _: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Structured output was expected but response was not valid JSON",
|
||||
):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
def test_parse_structured_output_empty_string(self):
|
||||
"""Cover line 355: _parse_structured_output with empty string."""
|
||||
engine = create_engine()
|
||||
success, result = engine._parse_structured_output("")
|
||||
assert success is False
|
||||
assert result is None
|
||||
|
||||
def test_normalize_node_json_schema_invalid(self):
|
||||
"""Cover line 375: _normalize_node_json_schema with invalid schema raises."""
|
||||
engine = create_engine()
|
||||
# A non-dict schema triggers JsonSchemaValidationError
|
||||
with pytest.raises(ValueError, match="Invalid JSON schema"):
|
||||
engine._normalize_node_json_schema("not_a_dict", "TestNode")
|
||||
|
||||
833
tests/agents/test_workflow_engine_coverage.py
Normal file
833
tests/agents/test_workflow_engine_coverage.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
|
||||
template context, source data, structured output parsing, get_execution_summary."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
NodeType,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
Workflow,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
def _make_graph(nodes, edges):
|
||||
wf = Workflow(name="Test", description="test workflow")
|
||||
return WorkflowGraph(workflow=wf, nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
def _make_node(id, type, title="Node", config=None, position=None):
|
||||
return WorkflowNode(
|
||||
id=id,
|
||||
workflow_id="wf1",
|
||||
type=type,
|
||||
title=title,
|
||||
position=position or {"x": 0, "y": 0},
|
||||
config=config or {},
|
||||
)
|
||||
|
||||
|
||||
def _make_edge(id, source, target, source_handle=None, target_handle=None):
|
||||
return WorkflowEdge(
|
||||
id=id,
|
||||
workflow_id="wf1",
|
||||
source=source,
|
||||
target=target,
|
||||
sourceHandle=source_handle,
|
||||
targetHandle=target_handle,
|
||||
)
|
||||
|
||||
|
||||
def _make_agent():
|
||||
agent = MagicMock()
|
||||
agent.chat_history = []
|
||||
agent.endpoint = "https://api.example.com"
|
||||
agent.llm_name = "openai"
|
||||
agent.model_id = "gpt-4"
|
||||
agent.api_key = "key"
|
||||
agent.decoded_token = {"sub": "user1"}
|
||||
agent.retrieved_docs = None
|
||||
return agent
|
||||
|
||||
|
||||
class TestExecuteLoop:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_start_node_yields_error(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "query"))
|
||||
assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_start_to_end(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START, "Start"),
|
||||
_make_node("n2", NodeType.END, "End", config={"config": {}}),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "hello"))
|
||||
step_events = [e for e in events if e.get("type") == "workflow_step"]
|
||||
assert len(step_events) >= 2 # At least start + end
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_node_not_found_yields_error(self):
|
||||
nodes = [_make_node("n1", NodeType.START)]
|
||||
edges = [_make_edge("e1", "n1", "nonexistent")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
assert any("not found" in e.get("error", "") for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_node_execution_error_yields_error(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
failed_events = [e for e in events if e.get("status") == "failed"]
|
||||
assert len(failed_events) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_max_steps_limit(self):
|
||||
# Create a cycle: start -> state -> state (loop)
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.NOTE, "Note"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.MAX_EXECUTION_STEPS = 5
|
||||
events = list(engine.execute({}, "q"))
|
||||
# Should not run forever
|
||||
assert len(events) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_branch_ends_without_end_node(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.NOTE, "Note"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing edges
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
assert len(events) > 0
|
||||
|
||||
|
||||
class TestInitializeState:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_sets_query_and_history(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.chat_history = [{"prompt": "hi", "response": "hey"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine._initialize_state({"custom": "value"}, "test query")
|
||||
assert engine.state["query"] == "test query"
|
||||
assert "custom" in engine.state
|
||||
assert engine.state["chat_history"] is not None
|
||||
|
||||
|
||||
class TestGetNextNodeId:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_edges_returns_none(self):
|
||||
nodes = [_make_node("n1", NodeType.START)]
|
||||
graph = _make_graph(nodes, [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._get_next_node_id("n1") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_first_edge_target(self):
|
||||
nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._get_next_node_id("n1") == "n2"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_condition_uses_matched_handle(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.CONDITION),
|
||||
_make_node("n2", NodeType.END, "Yes End"),
|
||||
_make_node("n3", NodeType.END, "No End"),
|
||||
]
|
||||
edges = [
|
||||
_make_edge("e1", "n1", "n2", source_handle="yes"),
|
||||
_make_edge("e2", "n1", "n3", source_handle="no"),
|
||||
]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._condition_result = "no"
|
||||
assert engine._get_next_node_id("n1") == "n3"
|
||||
assert engine._condition_result is None # Cleared after use
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_condition_no_matching_handle_returns_none(self):
|
||||
nodes = [_make_node("n1", NodeType.CONDITION)]
|
||||
edges = [_make_edge("e1", "n1", "n2", source_handle="yes")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._condition_result = "nonexistent"
|
||||
assert engine._get_next_node_id("n1") is None
|
||||
|
||||
|
||||
class TestExecuteStateNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_evaluates_operations(self):
|
||||
node = _make_node("n1", NodeType.STATE, config={
|
||||
"config": {
|
||||
"operations": [
|
||||
{"expression": "x + 1", "target_variable": "result"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 5}
|
||||
list(engine._execute_state_node(node))
|
||||
assert engine.state["result"] == 6
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_empty_expression(self):
|
||||
node = _make_node("n1", NodeType.STATE, config={
|
||||
"config": {
|
||||
"operations": [
|
||||
{"expression": "", "target_variable": "result"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_state_node(node))
|
||||
assert "result" not in engine.state
|
||||
|
||||
|
||||
class TestExecuteConditionNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_matches_first_true_case(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "x > 10", "source_handle": "high"},
|
||||
{"expression": "x > 5", "source_handle": "medium"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 7}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "medium"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_falls_through_to_else(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "x > 100", "source_handle": "high"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 1}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "else"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_empty_expression(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": " ", "source_handle": "a"},
|
||||
{"expression": "true", "source_handle": "b"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "b"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cel_error_continues(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "bad!!!", "source_handle": "a"},
|
||||
{"expression": "true", "source_handle": "b"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "b"
|
||||
|
||||
|
||||
class TestExecuteEndNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_output_template(self):
|
||||
node = _make_node("n1", NodeType.END, config={
|
||||
"config": {"output_template": "Result: {{ query }}"}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "hello"}
|
||||
engine._format_template = MagicMock(return_value="Result: hello")
|
||||
events = list(engine._execute_end_node(node))
|
||||
assert len(events) == 1
|
||||
assert events[0]["answer"] == "Result: hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_without_output_template(self):
|
||||
node = _make_node("n1", NodeType.END, config={"config": {}})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine._execute_end_node(node))
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
class TestParseStructuredOutput:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_json(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output('{"key": "value"}')
|
||||
assert success is True
|
||||
assert data == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_json(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output("not json")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_string(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output("")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_whitespace_only(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output(" ")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
|
||||
class TestNormalizeNodeJsonSchema:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._normalize_node_json_schema(None, "Node") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_schema(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
result = engine._normalize_node_json_schema(schema, "Node")
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_schema_raises(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm:
|
||||
from application.core.json_schema_utils import JsonSchemaValidationError
|
||||
mock_norm.side_effect = JsonSchemaValidationError("bad schema")
|
||||
with pytest.raises(ValueError, match="Invalid JSON schema"):
|
||||
engine._normalize_node_json_schema({"bad": True}, "TestNode")
|
||||
|
||||
|
||||
class TestValidateStructuredOutput:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_output_passes(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
engine._validate_structured_output(schema, {"name": "Alice"}) # Should not raise
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_output_raises(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
|
||||
with pytest.raises(ValueError, match="did not match schema"):
|
||||
engine._validate_structured_output(schema, {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_jsonschema_module(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
with patch("application.agents.workflows.workflow_engine.jsonschema", None):
|
||||
engine._validate_structured_output({"type": "object"}, {}) # Should not raise
|
||||
|
||||
|
||||
class TestFormatTemplate:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_renders_template(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "hello"}
|
||||
engine._build_template_context = MagicMock(return_value={"query": "hello"})
|
||||
engine._template_engine = MagicMock()
|
||||
engine._template_engine.render.return_value = "hello world"
|
||||
result = engine._format_template("{{ query }} world")
|
||||
assert result == "hello world"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_render_error_returns_raw(self):
|
||||
from application.templates.template_engine import TemplateRenderError
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._build_template_context = MagicMock(return_value={})
|
||||
engine._template_engine = MagicMock()
|
||||
engine._template_engine.render.side_effect = TemplateRenderError("fail")
|
||||
result = engine._format_template("{{ bad }}")
|
||||
assert result == "{{ bad }}"
|
||||
|
||||
|
||||
class TestBuildTemplateContext:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_includes_state_variables(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"query": "hello", "custom_var": "value"}
|
||||
context = engine._build_template_context()
|
||||
assert context["agent"]["query"] == "hello"
|
||||
assert "custom_var" in context
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_reserved_namespace_gets_prefixed(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"source": "my_source_val"}
|
||||
context = engine._build_template_context()
|
||||
assert context.get("agent_source") == "my_source_val"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_passthrough_data(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"passthrough": {"key": "val"}}
|
||||
context = engine._build_template_context()
|
||||
assert "passthrough" in context or "agent_passthrough" in context
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tools_data(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"tools": {"tool1": "result"}}
|
||||
context = engine._build_template_context()
|
||||
assert "agent" in context
|
||||
|
||||
|
||||
class TestGetSourceTemplateData:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_docs_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is None
|
||||
assert together is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_docs_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = []
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docs_with_filename(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is not None
|
||||
assert "doc.txt" in together
|
||||
assert "content" in together
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docs_without_filename(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content only"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together == "content only"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_non_dict_docs(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = ["not a dict", {"text": "ok"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together == "ok"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_non_string_text(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": 123}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_doc_with_title_fallback(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "title": "doc_title"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert "doc_title" in together
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_doc_with_source_fallback(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "source": "src"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert "src" in together
|
||||
|
||||
|
||||
class TestGetExecutionSummary:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_log_entries(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
now = datetime.now(timezone.utc)
|
||||
engine.execution_log = [
|
||||
{
|
||||
"node_id": "n1",
|
||||
"node_type": "start",
|
||||
"status": "completed",
|
||||
"started_at": now,
|
||||
"completed_at": now,
|
||||
"error": None,
|
||||
"state_snapshot": {},
|
||||
}
|
||||
]
|
||||
summary = engine.get_execution_summary()
|
||||
assert len(summary) == 1
|
||||
assert summary[0].node_id == "n1"
|
||||
assert summary[0].status == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_log(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine.get_execution_summary() == []
|
||||
|
||||
|
||||
class TestAgentNodeExecution:
|
||||
"""Cover lines 204, 213-215, 223, 232-233, 283-284, 289, 355, 375."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_node_without_prompt_template(self):
|
||||
"""Cover line 204/206: agent node without prompt_template uses query."""
|
||||
node = _make_node("n1", NodeType.AGENT, "Agent", config={
|
||||
"config": {
|
||||
"agent_type": "classic",
|
||||
"stream_to_user": False,
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "test question"}
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = [{"answer": "response"}]
|
||||
|
||||
with patch(
|
||||
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
|
||||
) as mock_factory, \
|
||||
patch(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
return_value="openai",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
return_value="key",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
return_value=None,
|
||||
):
|
||||
mock_factory.create.return_value = mock_agent
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
output_key = f"node_{node.id}_output"
|
||||
assert output_key in engine.state
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_node_with_structured_output(self):
|
||||
"""Cover lines 283-284, 289: structured output parsing."""
|
||||
node = _make_node("n1", NodeType.AGENT, "Agent", config={
|
||||
"config": {
|
||||
"agent_type": "classic",
|
||||
"stream_to_user": False,
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "test"}
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = [
|
||||
{"answer": '{"name": "Alice"}', "structured": True}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
|
||||
) as mock_factory, \
|
||||
patch(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
return_value="openai",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
return_value="key",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
return_value={"supports_structured_output": True},
|
||||
):
|
||||
mock_factory.create.return_value = mock_agent
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
output_key = f"node_{node.id}_output"
|
||||
assert engine.state[output_key] == {"name": "Alice"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_node_model_no_structured_support_raises(self):
|
||||
"""Cover lines 223: model without structured output raises ValueError."""
|
||||
node = _make_node("n1", NodeType.AGENT, "Agent", config={
|
||||
"config": {
|
||||
"agent_type": "classic",
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "string"}},
|
||||
},
|
||||
"model_id": "test-model",
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "test"}
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
return_value="openai",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
return_value="key",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
return_value={"supports_structured_output": False},
|
||||
):
|
||||
with pytest.raises(ValueError, match="does not support structured output"):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_node_output_variable(self):
|
||||
"""Cover line 300: output_variable stores result."""
|
||||
node = _make_node("n1", NodeType.AGENT, "Agent", config={
|
||||
"config": {
|
||||
"agent_type": "classic",
|
||||
"stream_to_user": False,
|
||||
"output_variable": "my_result",
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "test"}
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = [{"answer": "output text"}]
|
||||
|
||||
with patch(
|
||||
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
|
||||
) as mock_factory, \
|
||||
patch(
|
||||
"application.core.model_utils.get_provider_from_model_id",
|
||||
return_value="openai",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
return_value="key",
|
||||
), \
|
||||
patch(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
return_value=None,
|
||||
):
|
||||
mock_factory.create.return_value = mock_agent
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
assert engine.state["my_result"] == "output text"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_structured_output_schema_error(self):
|
||||
"""Cover line 375/382-383: invalid schema raises ValueError."""
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
import jsonschema as js
|
||||
|
||||
with patch(
|
||||
"application.agents.workflows.workflow_engine.normalize_json_schema_payload",
|
||||
return_value={"type": "invalid_schema_type"},
|
||||
), \
|
||||
patch(
|
||||
"application.agents.workflows.workflow_engine.jsonschema"
|
||||
) as mock_js:
|
||||
mock_js.validate.side_effect = js.exceptions.SchemaError("bad schema")
|
||||
mock_js.exceptions = js.exceptions
|
||||
with pytest.raises(ValueError, match="Invalid JSON schema"):
|
||||
engine._validate_structured_output(
|
||||
{"type": "object"}, {"name": "test"}
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_structured_output_invalid_json(self):
|
||||
"""Cover lines 349-352: invalid JSON returns False."""
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output("not json {")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for workflow_engine.py
|
||||
# Lines: 96-114 (exception in node execution), 122-130 (branch/max steps)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowNodeExecutionException:
|
||||
"""Cover lines 96-114: exception during _execute_node yields error events."""
|
||||
|
||||
def test_node_raises_exception_yields_error(self):
|
||||
"""Force _execute_node to raise, covering lines 96-114."""
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.AGENT, "Agent"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
|
||||
# Patch _execute_node to raise on agent node
|
||||
original_execute = engine._execute_node
|
||||
|
||||
def patched_execute(node):
|
||||
if node.type == NodeType.AGENT:
|
||||
raise RuntimeError("Agent exploded")
|
||||
yield from original_execute(node)
|
||||
|
||||
engine._execute_node = patched_execute
|
||||
events = list(engine.execute({}, "test query"))
|
||||
|
||||
error_events = [e for e in events if e.get("type") == "error"]
|
||||
assert len(error_events) >= 1
|
||||
failed_steps = [e for e in events if e.get("status") == "failed"]
|
||||
assert len(failed_steps) >= 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowMaxStepsReached:
|
||||
"""Cover lines 127-130: max steps limit warning."""
|
||||
|
||||
def test_max_steps_exactly_reached(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.NOTE, "Note"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.MAX_EXECUTION_STEPS = 3
|
||||
events = list(engine.execute({}, "q"))
|
||||
# The while loop runs 3 times then exits, steps >= MAX
|
||||
assert len(events) >= 3
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowBranchEndsNonEndNode:
|
||||
"""Cover lines 122-125: branch ends at non-end node without outgoing edges."""
|
||||
|
||||
def test_branch_ends_at_state_node(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node(
|
||||
"n2",
|
||||
NodeType.STATE,
|
||||
"State",
|
||||
config={"config": {"operations": []}},
|
||||
),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
# Should complete without crash, branch ended warning logged
|
||||
assert len(events) > 0
|
||||
556
tests/agents/test_workflow_schemas.py
Normal file
556
tests/agents/test_workflow_schemas.py
Normal file
@@ -0,0 +1,556 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from pydantic import ValidationError
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionCase,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
Position,
|
||||
StateOperation,
|
||||
Workflow,
|
||||
WorkflowCreate,
|
||||
WorkflowEdge,
|
||||
WorkflowEdgeCreate,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowNodeCreate,
|
||||
WorkflowRun,
|
||||
WorkflowRunCreate,
|
||||
)
|
||||
|
||||
|
||||
# ── Enum tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNodeType:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert NodeType.START == "start"
|
||||
assert NodeType.END == "end"
|
||||
assert NodeType.AGENT == "agent"
|
||||
assert NodeType.NOTE == "note"
|
||||
assert NodeType.STATE == "state"
|
||||
assert NodeType.CONDITION == "condition"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_members(self):
|
||||
assert set(NodeType) == {
|
||||
NodeType.START,
|
||||
NodeType.END,
|
||||
NodeType.AGENT,
|
||||
NodeType.NOTE,
|
||||
NodeType.STATE,
|
||||
NodeType.CONDITION,
|
||||
}
|
||||
|
||||
|
||||
class TestAgentType:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert AgentType.CLASSIC == "classic"
|
||||
assert AgentType.REACT == "react"
|
||||
assert AgentType.AGENTIC == "agentic"
|
||||
assert AgentType.RESEARCH == "research"
|
||||
|
||||
|
||||
class TestExecutionStatus:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert ExecutionStatus.PENDING == "pending"
|
||||
assert ExecutionStatus.RUNNING == "running"
|
||||
assert ExecutionStatus.COMPLETED == "completed"
|
||||
assert ExecutionStatus.FAILED == "failed"
|
||||
|
||||
|
||||
# ── Position ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPosition:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
p = Position()
|
||||
assert p.x == 0.0
|
||||
assert p.y == 0.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
p = Position(x=10.5, y=-3.2)
|
||||
assert p.x == 10.5
|
||||
assert p.y == -3.2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_fields_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
Position(x=0, y=0, z=1)
|
||||
|
||||
|
||||
# ── AgentNodeConfig ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentNodeConfig:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = AgentNodeConfig()
|
||||
assert c.agent_type == AgentType.CLASSIC
|
||||
assert c.llm_name is None
|
||||
assert c.system_prompt == "You are a helpful assistant."
|
||||
assert c.prompt_template == ""
|
||||
assert c.output_variable is None
|
||||
assert c.stream_to_user is True
|
||||
assert c.tools == []
|
||||
assert c.sources == []
|
||||
assert c.chunks == "2"
|
||||
assert c.retriever == ""
|
||||
assert c.model_id is None
|
||||
assert c.json_schema is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
c = AgentNodeConfig(
|
||||
agent_type=AgentType.REACT,
|
||||
llm_name="gpt-4",
|
||||
tools=["search"],
|
||||
sources=["src1"],
|
||||
chunks="5",
|
||||
model_id="m1",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
assert c.agent_type == AgentType.REACT
|
||||
assert c.llm_name == "gpt-4"
|
||||
assert c.tools == ["search"]
|
||||
assert c.sources == ["src1"]
|
||||
assert c.chunks == "5"
|
||||
assert c.model_id == "m1"
|
||||
assert c.json_schema == {"type": "object"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_fields_allowed(self):
|
||||
c = AgentNodeConfig(custom_field="value")
|
||||
assert c.custom_field == "value"
|
||||
|
||||
|
||||
# ── ConditionCase / ConditionNodeConfig ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestConditionCase:
|
||||
@pytest.mark.unit
|
||||
def test_alias(self):
|
||||
c = ConditionCase(expression="x > 1", sourceHandle="handle-1")
|
||||
assert c.source_handle == "handle-1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = ConditionCase(sourceHandle="h")
|
||||
assert c.name is None
|
||||
assert c.expression == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ConditionCase(sourceHandle="h", extra="nope")
|
||||
|
||||
|
||||
class TestConditionNodeConfig:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = ConditionNodeConfig()
|
||||
assert c.mode == "simple"
|
||||
assert c.cases == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_cases(self):
|
||||
c = ConditionNodeConfig(
|
||||
mode="advanced",
|
||||
cases=[{"expression": "x > 1", "sourceHandle": "h1"}],
|
||||
)
|
||||
assert c.mode == "advanced"
|
||||
assert len(c.cases) == 1
|
||||
assert c.cases[0].source_handle == "h1"
|
||||
|
||||
|
||||
# ── StateOperation ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStateOperation:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
s = StateOperation()
|
||||
assert s.expression == ""
|
||||
assert s.target_variable == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
StateOperation(expression="a", target_variable="b", extra="no")
|
||||
|
||||
|
||||
# ── WorkflowEdgeCreate / WorkflowEdge ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowEdgeCreate:
|
||||
@pytest.mark.unit
|
||||
def test_aliases(self):
|
||||
e = WorkflowEdgeCreate(
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
sourceHandle="sh",
|
||||
targetHandle="th",
|
||||
)
|
||||
assert e.source_id == "n1"
|
||||
assert e.target_id == "n2"
|
||||
assert e.source_handle == "sh"
|
||||
assert e.target_handle == "th"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_optional_handles_default_none(self):
|
||||
e = WorkflowEdgeCreate(id="e1", workflow_id="w1", source="n1", target="n2")
|
||||
assert e.source_handle is None
|
||||
assert e.target_handle is None
|
||||
|
||||
|
||||
class TestWorkflowEdge:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
e = WorkflowEdge(
|
||||
_id=oid,
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
)
|
||||
assert e.mongo_id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_id_passthrough(self):
|
||||
e = WorkflowEdge(
|
||||
_id="string-id",
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
)
|
||||
assert e.mongo_id == "string-id"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_id(self):
|
||||
e = WorkflowEdge(id="e1", workflow_id="w1", source="n1", target="n2")
|
||||
assert e.mongo_id is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
e = WorkflowEdge(
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
sourceHandle="sh",
|
||||
targetHandle="th",
|
||||
)
|
||||
doc = e.to_mongo_doc()
|
||||
assert doc == {
|
||||
"id": "e1",
|
||||
"workflow_id": "w1",
|
||||
"source_id": "n1",
|
||||
"target_id": "n2",
|
||||
"source_handle": "sh",
|
||||
"target_handle": "th",
|
||||
}
|
||||
|
||||
|
||||
# ── WorkflowNodeCreate / WorkflowNode ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowNodeCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
n = WorkflowNodeCreate(id="n1", workflow_id="w1", type=NodeType.AGENT)
|
||||
assert n.title == "Node"
|
||||
assert n.description is None
|
||||
assert n.position.x == 0.0
|
||||
assert n.position.y == 0.0
|
||||
assert n.config == {}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_position_from_dict(self):
|
||||
n = WorkflowNodeCreate(
|
||||
id="n1",
|
||||
workflow_id="w1",
|
||||
type=NodeType.START,
|
||||
position={"x": 100, "y": 200},
|
||||
)
|
||||
assert isinstance(n.position, Position)
|
||||
assert n.position.x == 100
|
||||
assert n.position.y == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_position_from_position_object(self):
|
||||
pos = Position(x=5, y=10)
|
||||
n = WorkflowNodeCreate(
|
||||
id="n1", workflow_id="w1", type=NodeType.END, position=pos
|
||||
)
|
||||
assert n.position is pos
|
||||
|
||||
|
||||
class TestWorkflowNode:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
n = WorkflowNode(
|
||||
_id=oid, id="n1", workflow_id="w1", type=NodeType.AGENT
|
||||
)
|
||||
assert n.mongo_id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
n = WorkflowNode(
|
||||
id="n1",
|
||||
workflow_id="w1",
|
||||
type=NodeType.AGENT,
|
||||
title="My Node",
|
||||
description="desc",
|
||||
position={"x": 10, "y": 20},
|
||||
config={"key": "val"},
|
||||
)
|
||||
doc = n.to_mongo_doc()
|
||||
assert doc == {
|
||||
"id": "n1",
|
||||
"workflow_id": "w1",
|
||||
"type": "agent",
|
||||
"title": "My Node",
|
||||
"description": "desc",
|
||||
"position": {"x": 10.0, "y": 20.0},
|
||||
"config": {"key": "val"},
|
||||
}
|
||||
|
||||
|
||||
# ── WorkflowCreate / Workflow ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
w = WorkflowCreate()
|
||||
assert w.name == "New Workflow"
|
||||
assert w.description is None
|
||||
assert w.user is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
w = WorkflowCreate(name="Test", description="d", user="u1")
|
||||
assert w.name == "Test"
|
||||
assert w.description == "d"
|
||||
assert w.user == "u1"
|
||||
|
||||
|
||||
class TestWorkflow:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
w = Workflow(_id=oid)
|
||||
assert w.id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_id(self):
|
||||
w = Workflow(_id="abc")
|
||||
assert w.id == "abc"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_id(self):
|
||||
w = Workflow()
|
||||
assert w.id is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_datetime_defaults(self):
|
||||
before = datetime.now(timezone.utc)
|
||||
w = Workflow()
|
||||
after = datetime.now(timezone.utc)
|
||||
assert before <= w.created_at <= after
|
||||
assert before <= w.updated_at <= after
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
w = Workflow(name="W", description="d", user="u1")
|
||||
doc = w.to_mongo_doc()
|
||||
assert doc["name"] == "W"
|
||||
assert doc["description"] == "d"
|
||||
assert doc["user"] == "u1"
|
||||
assert "created_at" in doc
|
||||
assert "updated_at" in doc
|
||||
|
||||
|
||||
# ── WorkflowGraph ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowGraph:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
workflow = Workflow(name="test")
|
||||
nodes = [
|
||||
WorkflowNode(id="start", workflow_id="w1", type=NodeType.START),
|
||||
WorkflowNode(id="agent1", workflow_id="w1", type=NodeType.AGENT),
|
||||
WorkflowNode(id="end", workflow_id="w1", type=NodeType.END),
|
||||
]
|
||||
edges = [
|
||||
WorkflowEdge(
|
||||
id="e1", workflow_id="w1", source="start", target="agent1"
|
||||
),
|
||||
WorkflowEdge(
|
||||
id="e2", workflow_id="w1", source="agent1", target="end"
|
||||
),
|
||||
]
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_node_by_id_found(self, graph):
|
||||
node = graph.get_node_by_id("agent1")
|
||||
assert node is not None
|
||||
assert node.id == "agent1"
|
||||
assert node.type == NodeType.AGENT
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_node_by_id_not_found(self, graph):
|
||||
assert graph.get_node_by_id("nonexistent") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_start_node(self, graph):
|
||||
start = graph.get_start_node()
|
||||
assert start is not None
|
||||
assert start.id == "start"
|
||||
assert start.type == NodeType.START
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_start_node_missing(self):
|
||||
g = WorkflowGraph(
|
||||
workflow=Workflow(),
|
||||
nodes=[
|
||||
WorkflowNode(id="a", workflow_id="w", type=NodeType.AGENT),
|
||||
],
|
||||
)
|
||||
assert g.get_start_node() is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_outgoing_edges(self, graph):
|
||||
edges = graph.get_outgoing_edges("start")
|
||||
assert len(edges) == 1
|
||||
assert edges[0].target_id == "agent1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_outgoing_edges_none(self, graph):
|
||||
edges = graph.get_outgoing_edges("end")
|
||||
assert edges == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_graph(self):
|
||||
g = WorkflowGraph(workflow=Workflow())
|
||||
assert g.nodes == []
|
||||
assert g.edges == []
|
||||
assert g.get_start_node() is None
|
||||
|
||||
|
||||
# ── NodeExecutionLog ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNodeExecutionLog:
|
||||
@pytest.mark.unit
|
||||
def test_required_fields(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=now,
|
||||
)
|
||||
assert log.node_id == "n1"
|
||||
assert log.completed_at is None
|
||||
assert log.error is None
|
||||
assert log.state_snapshot == {}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_full_log(self):
|
||||
started = datetime.now(timezone.utc)
|
||||
completed = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=started,
|
||||
completed_at=completed,
|
||||
error=None,
|
||||
state_snapshot={"key": "value"},
|
||||
)
|
||||
assert log.completed_at == completed
|
||||
assert log.state_snapshot == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
NodeExecutionLog(
|
||||
node_id="n",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.PENDING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
extra="no",
|
||||
)
|
||||
|
||||
|
||||
# ── WorkflowRunCreate / WorkflowRun ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowRunCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
r = WorkflowRunCreate(workflow_id="w1")
|
||||
assert r.workflow_id == "w1"
|
||||
assert r.inputs == {}
|
||||
|
||||
|
||||
class TestWorkflowRun:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
r = WorkflowRun(workflow_id="w1")
|
||||
assert r.status == ExecutionStatus.PENDING
|
||||
assert r.inputs == {}
|
||||
assert r.outputs == {}
|
||||
assert r.steps == []
|
||||
assert r.completed_at is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
r = WorkflowRun(_id=oid, workflow_id="w1")
|
||||
assert r.id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=now,
|
||||
)
|
||||
r = WorkflowRun(
|
||||
workflow_id="w1",
|
||||
status=ExecutionStatus.RUNNING,
|
||||
inputs={"q": "hello"},
|
||||
outputs={"a": "world"},
|
||||
steps=[log],
|
||||
)
|
||||
doc = r.to_mongo_doc()
|
||||
assert doc["workflow_id"] == "w1"
|
||||
assert doc["status"] == "running"
|
||||
assert doc["inputs"] == {"q": "hello"}
|
||||
assert doc["outputs"] == {"a": "world"}
|
||||
assert len(doc["steps"]) == 1
|
||||
assert doc["steps"][0]["node_id"] == "n1"
|
||||
assert doc["completed_at"] is None
|
||||
0
tests/agents/tools/__init__.py
Normal file
0
tests/agents/tools/__init__.py
Normal file
620
tests/agents/tools/test_api_body_serializer.py
Normal file
620
tests/agents/tools/test_api_body_serializer.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""Comprehensive tests for application/agents/tools/api_body_serializer.py
|
||||
|
||||
Covers: ContentType enum, RequestBodySerializer (JSON, form-urlencoded,
|
||||
multipart, text/plain, XML, octet-stream, unknown types), encoding rules,
|
||||
helper methods (_percent_encode, _escape_xml, _dict_to_xml).
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# ContentType Enum
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContentTypeEnum:
|
||||
|
||||
def test_json_value(self):
|
||||
assert ContentType.JSON == "application/json"
|
||||
|
||||
def test_form_urlencoded_value(self):
|
||||
assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"
|
||||
|
||||
def test_multipart_value(self):
|
||||
assert ContentType.MULTIPART_FORM_DATA == "multipart/form-data"
|
||||
|
||||
def test_text_plain_value(self):
|
||||
assert ContentType.TEXT_PLAIN == "text/plain"
|
||||
|
||||
def test_xml_value(self):
|
||||
assert ContentType.XML == "application/xml"
|
||||
|
||||
def test_octet_stream_value(self):
|
||||
assert ContentType.OCTET_STREAM == "application/octet-stream"
|
||||
|
||||
def test_str_enum(self):
|
||||
assert isinstance(ContentType.JSON, str)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# JSON Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeJson:
|
||||
|
||||
def test_basic_json(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"key": "value"}, ContentType.JSON
|
||||
)
|
||||
assert json.loads(body) == {"key": "value"}
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
def test_nested_json(self):
|
||||
data = {"user": {"name": "Alice", "age": 30}}
|
||||
body, headers = RequestBodySerializer.serialize(data, ContentType.JSON)
|
||||
assert json.loads(body) == data
|
||||
|
||||
def test_empty_body_returns_none(self):
|
||||
body, headers = RequestBodySerializer.serialize({}, ContentType.JSON)
|
||||
assert body is None
|
||||
assert headers == {}
|
||||
|
||||
def test_none_body(self):
|
||||
body, headers = RequestBodySerializer.serialize(None, ContentType.JSON)
|
||||
assert body is None
|
||||
|
||||
def test_unknown_content_type_falls_back_to_json(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"k": "v"}, "application/vnd.custom+json"
|
||||
)
|
||||
assert json.loads(body) == {"k": "v"}
|
||||
|
||||
def test_content_type_with_charset_suffix(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"k": "v"}, "application/json; charset=utf-8"
|
||||
)
|
||||
assert json.loads(body) == {"k": "v"}
|
||||
|
||||
def test_compact_json_format(self):
|
||||
body, _ = RequestBodySerializer.serialize(
|
||||
{"a": 1, "b": 2}, ContentType.JSON
|
||||
)
|
||||
# Should use compact separators
|
||||
assert " " not in body
|
||||
|
||||
def test_unicode_json(self):
|
||||
body, _ = RequestBodySerializer.serialize(
|
||||
{"name": "Heisenberg"}, ContentType.JSON
|
||||
)
|
||||
parsed = json.loads(body)
|
||||
assert parsed["name"] == "Heisenberg"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Form URL-Encoded Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeFormUrlencoded:
|
||||
|
||||
def test_basic_form(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
|
||||
)
|
||||
assert "name=Alice" in body
|
||||
assert "age=30" in body
|
||||
assert headers["Content-Type"] == "application/x-www-form-urlencoded"
|
||||
|
||||
def test_none_values_skipped(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
|
||||
)
|
||||
assert "name=Alice" in body
|
||||
assert "skip" not in body
|
||||
|
||||
def test_list_explode_true(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"tags": ["a", "b"]},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"tags": {"style": "form", "explode": True}},
|
||||
)
|
||||
assert "tags=a" in body
|
||||
assert "tags=b" in body
|
||||
|
||||
def test_list_explode_false(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"tags": ["a", "b"]},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"tags": {"style": "form", "explode": False}},
|
||||
)
|
||||
assert "tags=" in body
|
||||
assert "a" in body and "b" in body
|
||||
|
||||
def test_dict_value_json_content_type(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"metadata": {"key": "val"}},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"metadata": {"contentType": "application/json"}},
|
||||
)
|
||||
assert "metadata" in body
|
||||
|
||||
def test_dict_value_xml_content_type(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"data": {"name": "test"}},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"data": {"contentType": "application/xml"}},
|
||||
)
|
||||
assert "data" in body
|
||||
|
||||
def test_dict_value_deep_object_explode(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"filter": {"status": "active", "type": "doc"}},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={
|
||||
"filter": {"style": "deepObject", "explode": True}
|
||||
},
|
||||
)
|
||||
assert "filter" in body
|
||||
|
||||
def test_dict_value_non_exploded(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"obj": {"a": "1", "b": "2"}},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"obj": {"style": "form", "explode": False}},
|
||||
)
|
||||
assert "obj" in body
|
||||
|
||||
def test_default_explode_for_form_style(self):
|
||||
"""Default explode should be True when style is 'form'."""
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"items": ["x", "y"]},
|
||||
ContentType.FORM_URLENCODED,
|
||||
encoding_rules={"items": {"style": "form"}},
|
||||
)
|
||||
# explode defaults to True for form style => separate params
|
||||
assert "items=x" in body
|
||||
assert "items=y" in body
|
||||
|
||||
def test_special_characters_encoded(self):
|
||||
body, _ = RequestBodySerializer.serialize(
|
||||
{"q": "hello world&more"}, ContentType.FORM_URLENCODED
|
||||
)
|
||||
assert "hello" in body
|
||||
assert "q=" in body
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Text Plain Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeTextPlain:
|
||||
|
||||
def test_single_value(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"message": "hello"}, ContentType.TEXT_PLAIN
|
||||
)
|
||||
assert body == "hello"
|
||||
assert headers["Content-Type"] == "text/plain"
|
||||
|
||||
def test_multiple_values(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
|
||||
)
|
||||
assert "name: Alice" in body
|
||||
assert "age: 30" in body
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# XML Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeXml:
|
||||
|
||||
def test_basic_xml(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"name": "Alice"}, ContentType.XML
|
||||
)
|
||||
assert '<?xml version="1.0"' in body
|
||||
assert "<name>Alice</name>" in body
|
||||
assert headers["Content-Type"] == "application/xml"
|
||||
|
||||
def test_nested_xml(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"user": {"name": "Alice"}}, ContentType.XML
|
||||
)
|
||||
assert "<user>" in body
|
||||
assert "<name>Alice</name>" in body
|
||||
|
||||
def test_xml_escapes_special_chars(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"data": "<script>alert('xss')</script>"}, ContentType.XML
|
||||
)
|
||||
assert "<script>" in body
|
||||
|
||||
def test_xml_with_list(self):
|
||||
body, _ = RequestBodySerializer.serialize(
|
||||
{"items": [1, 2, 3]}, ContentType.XML
|
||||
)
|
||||
assert "<item>1</item>" in body
|
||||
assert "<item>2</item>" in body
|
||||
assert "<item>3</item>" in body
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Octet Stream Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeOctetStream:
|
||||
|
||||
def test_dict_body(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"key": "val"}, ContentType.OCTET_STREAM
|
||||
)
|
||||
assert isinstance(body, bytes)
|
||||
assert headers["Content-Type"] == "application/octet-stream"
|
||||
|
||||
def test_bytes_body(self):
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream(b"\x00\x01")
|
||||
assert body == b"\x00\x01"
|
||||
assert headers["Content-Type"] == "application/octet-stream"
|
||||
|
||||
def test_string_body(self):
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream("hello")
|
||||
assert body == b"hello"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Multipart Form Data Serialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeMultipartFormData:
|
||||
|
||||
def test_basic_multipart(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"field": "value"}, ContentType.MULTIPART_FORM_DATA
|
||||
)
|
||||
assert isinstance(body, bytes)
|
||||
assert "multipart/form-data" in headers["Content-Type"]
|
||||
assert "boundary=" in headers["Content-Type"]
|
||||
|
||||
def test_none_values_skipped(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "field" in body_str
|
||||
assert "empty" not in body_str
|
||||
|
||||
def test_multipart_with_bytes(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"file": b"\x00\x01\x02"}, ContentType.MULTIPART_FORM_DATA
|
||||
)
|
||||
assert isinstance(body, bytes)
|
||||
|
||||
def test_multipart_with_dict_json(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"meta": {"key": "val"}},
|
||||
ContentType.MULTIPART_FORM_DATA,
|
||||
encoding_rules={"meta": {"contentType": "application/json"}},
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "meta" in body_str
|
||||
assert "application/json" in body_str
|
||||
|
||||
def test_multipart_with_dict_xml(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"data": {"name": "test"}},
|
||||
ContentType.MULTIPART_FORM_DATA,
|
||||
encoding_rules={"data": {"contentType": "application/xml"}},
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "data" in body_str
|
||||
|
||||
def test_multipart_octet_stream_bytes(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"bin": b"\xff\xfe"},
|
||||
ContentType.MULTIPART_FORM_DATA,
|
||||
encoding_rules={"bin": {"contentType": "application/octet-stream"}},
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "bin" in body_str
|
||||
assert "Content-Transfer-Encoding: base64" in body_str
|
||||
|
||||
def test_multipart_string_with_json_content_type(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"json_str": '{"a": 1}'},
|
||||
ContentType.MULTIPART_FORM_DATA,
|
||||
encoding_rules={"json_str": {"contentType": "application/json"}},
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "json_str" in body_str
|
||||
|
||||
def test_multipart_string_with_non_text_content_type(self):
|
||||
body, headers = RequestBodySerializer.serialize(
|
||||
{"custom": "data"},
|
||||
ContentType.MULTIPART_FORM_DATA,
|
||||
encoding_rules={"custom": {"contentType": "application/custom"}},
|
||||
)
|
||||
body_str = body.decode("utf-8", errors="replace")
|
||||
assert "custom" in body_str
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Helper Methods
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHelpers:
|
||||
|
||||
def test_percent_encode_space(self):
|
||||
assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
|
||||
|
||||
def test_percent_encode_slash(self):
|
||||
assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
|
||||
|
||||
def test_percent_encode_safe_chars(self):
|
||||
assert RequestBodySerializer._percent_encode("a/b", safe_chars="/") == "a/b"
|
||||
|
||||
def test_escape_xml_ampersand(self):
|
||||
assert "&" in RequestBodySerializer._escape_xml("&")
|
||||
|
||||
def test_escape_xml_lt(self):
|
||||
assert "<" in RequestBodySerializer._escape_xml("<")
|
||||
|
||||
def test_escape_xml_gt(self):
|
||||
assert ">" in RequestBodySerializer._escape_xml(">")
|
||||
|
||||
def test_escape_xml_quote(self):
|
||||
assert """ in RequestBodySerializer._escape_xml('"')
|
||||
|
||||
def test_escape_xml_apos(self):
|
||||
assert "'" in RequestBodySerializer._escape_xml("'")
|
||||
|
||||
def test_dict_to_xml_list(self):
|
||||
xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
|
||||
assert "<item>1</item>" in xml
|
||||
assert "<item>2</item>" in xml
|
||||
|
||||
def test_dict_to_xml_custom_root(self):
|
||||
xml = RequestBodySerializer._dict_to_xml({"key": "val"}, root_name="data")
|
||||
assert "<data>" in xml
|
||||
assert "<key>val</key>" in xml
|
||||
|
||||
def test_dict_to_xml_deeply_nested(self):
|
||||
xml = RequestBodySerializer._dict_to_xml({"a": {"b": {"c": "deep"}}})
|
||||
assert "<c>deep</c>" in xml
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Error Handling
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializationErrors:
|
||||
|
||||
def test_serialize_raises_on_internal_error(self):
|
||||
"""Test that serialization errors are wrapped in ValueError."""
|
||||
# Patch _serialize_json to raise
|
||||
with pytest.raises(ValueError, match="Failed to serialize"):
|
||||
RequestBodySerializer.serialize(
|
||||
{"key": object()}, # object() is not JSON-serializable
|
||||
ContentType.JSON,
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 145, 155, 159, 162, 166, 271)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeFormValueGaps:
|
||||
|
||||
def test_dict_explode_without_deep_object(self):
|
||||
"""Cover line 145: dict with explode=True but style != deepObject."""
|
||||
result = RequestBodySerializer._serialize_form_value(
|
||||
value={"a": "1", "b": "2"},
|
||||
style="form",
|
||||
explode=True,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
key="data",
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_list_explode_true(self):
|
||||
"""Cover line 155: list with explode=True."""
|
||||
result = RequestBodySerializer._serialize_form_value(
|
||||
value=["x", "y", "z"],
|
||||
style="form",
|
||||
explode=True,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
key="items",
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_list_explode_false(self):
|
||||
"""Cover line 159: list with explode=False."""
|
||||
result = RequestBodySerializer._serialize_form_value(
|
||||
value=["x", "y"],
|
||||
style="form",
|
||||
explode=False,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
key="items",
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
# comma-joined and percent-encoded
|
||||
assert "x" in result
|
||||
assert "y" in result
|
||||
|
||||
def test_primitive_value(self):
|
||||
"""Cover line 162: primitive string value."""
|
||||
result = RequestBodySerializer._serialize_form_value(
|
||||
value="hello world",
|
||||
style="form",
|
||||
explode=False,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
key="name",
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert "hello" in result
|
||||
|
||||
def test_dict_no_explode(self):
|
||||
"""Cover line 166 area: dict with explode=False returns comma-joined."""
|
||||
result = RequestBodySerializer._serialize_form_value(
|
||||
value={"k1": "v1", "k2": "v2"},
|
||||
style="form",
|
||||
explode=False,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
key="data",
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert "k1" in result
|
||||
|
||||
def test_octet_stream_string_input(self):
|
||||
"""Cover line 271: _serialize_octet_stream with string input."""
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream("hello bytes")
|
||||
assert body == b"hello bytes"
|
||||
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
|
||||
|
||||
def test_octet_stream_dict_input(self):
|
||||
"""Cover: _serialize_octet_stream with dict input (fallback to JSON)."""
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream({"key": "val"})
|
||||
assert isinstance(body, bytes)
|
||||
import json
|
||||
|
||||
parsed = json.loads(body.decode("utf-8"))
|
||||
assert parsed == {"key": "val"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 226, 229, 271, 275, 279
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestApiBodySerializerMultipartParts:
|
||||
|
||||
def test_multipart_dict_unknown_content_type(self):
|
||||
"""Cover line 226: dict with unknown content type uses str()."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
result = RequestBodySerializer._create_multipart_part(
|
||||
name="field",
|
||||
value={"key": "val"},
|
||||
content_type="text/csv",
|
||||
headers_rule={},
|
||||
)
|
||||
assert "text/csv" in result
|
||||
assert "key" in result
|
||||
|
||||
def test_multipart_string_json_content_type(self):
|
||||
"""Cover line 229: string value with application/json content type."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
result = RequestBodySerializer._create_multipart_part(
|
||||
name="data",
|
||||
value='{"a": 1}',
|
||||
content_type="application/json",
|
||||
headers_rule={},
|
||||
)
|
||||
assert "application/json" in result
|
||||
assert '{"a": 1}' in result
|
||||
|
||||
def test_multipart_string_xml_content_type(self):
|
||||
"""Cover line 229: string value with application/xml content type."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
result = RequestBodySerializer._create_multipart_part(
|
||||
name="data",
|
||||
value="<root/>",
|
||||
content_type="application/xml",
|
||||
headers_rule={},
|
||||
)
|
||||
assert "application/xml" in result
|
||||
assert "<root/>" in result
|
||||
|
||||
def test_multipart_string_unknown_content_type(self):
|
||||
"""Cover line 229: string with unknown content type falls through."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
result = RequestBodySerializer._create_multipart_part(
|
||||
name="data",
|
||||
value="some text",
|
||||
content_type="application/custom",
|
||||
headers_rule={},
|
||||
)
|
||||
assert "application/custom" in result
|
||||
assert "some text" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestApiBodySerializerOctetStreamCoverage:
|
||||
|
||||
def test_octet_stream_bytes_input(self):
|
||||
"""Cover line 271: _serialize_octet_stream with bytes input."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream(b"raw bytes")
|
||||
assert body == b"raw bytes"
|
||||
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
|
||||
|
||||
def test_octet_stream_string_input(self):
|
||||
"""Cover line 275: _serialize_octet_stream with string input."""
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream("text data")
|
||||
assert body == b"text data"
|
||||
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
|
||||
|
||||
def test_octet_stream_dict_input(self):
|
||||
"""Cover line 279: _serialize_octet_stream with dict input (fallback to JSON)."""
|
||||
import json
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
|
||||
body, headers = RequestBodySerializer._serialize_octet_stream({"k": "v"})
|
||||
assert isinstance(body, bytes)
|
||||
parsed = json.loads(body.decode("utf-8"))
|
||||
assert parsed == {"k": "v"}
|
||||
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
|
||||
516
tests/agents/tools/test_api_tool.py
Normal file
516
tests/agents/tools/test_api_tool.py
Normal file
@@ -0,0 +1,516 @@
|
||||
"""Comprehensive tests for application/agents/tools/api_tool.py
|
||||
|
||||
Covers: APITool initialization, all HTTP methods, path param substitution,
|
||||
SSRF validation, error handling, response parsing, body serialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_tool import APITool, DEFAULT_TIMEOUT
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_tool():
|
||||
return APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/data",
|
||||
"method": "GET",
|
||||
"headers": {"Accept": "application/json"},
|
||||
"query_params": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def post_tool():
|
||||
return APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "POST",
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Initialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAPIToolInit:
|
||||
|
||||
def test_default_values(self):
|
||||
tool = APITool(config={})
|
||||
assert tool.url == ""
|
||||
assert tool.method == "GET"
|
||||
assert tool.headers == {}
|
||||
assert tool.query_params == {}
|
||||
assert tool.body_content_type == "application/json"
|
||||
assert tool.body_encoding_rules == {}
|
||||
|
||||
def test_custom_config(self):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.test.com",
|
||||
"method": "POST",
|
||||
"headers": {"X-Key": "val"},
|
||||
"query_params": {"page": "1"},
|
||||
"body_content_type": "application/xml",
|
||||
"body_encoding_rules": {"field": {"style": "form"}},
|
||||
})
|
||||
assert tool.url == "https://api.test.com"
|
||||
assert tool.method == "POST"
|
||||
assert tool.headers == {"X-Key": "val"}
|
||||
assert tool.query_params == {"page": "1"}
|
||||
assert tool.body_content_type == "application/xml"
|
||||
|
||||
def test_default_timeout_constant(self):
|
||||
assert DEFAULT_TIMEOUT == 90
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# HTTP Methods
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMakeApiCall:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_successful_get(self, mock_get, mock_validate, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"result": "ok"}
|
||||
mock_resp.content = b'{"result":"ok"}'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any_action")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["data"] == {"result": "ok"}
|
||||
assert result["message"] == "API call successful."
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_successful_post(self, mock_post, mock_validate, post_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"id": 1}
|
||||
mock_resp.content = b'{"id":1}'
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = post_tool.execute_action("create", name="test")
|
||||
assert result["status_code"] == 201
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.put")
|
||||
def test_put_method(self, mock_put, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_put.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("update", name="new")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.delete")
|
||||
def test_delete_method(self, mock_delete, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 204
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_delete.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("delete")
|
||||
assert result["status_code"] == 204
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.patch")
|
||||
def test_patch_method(self, mock_patch, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"patched": True}
|
||||
mock_resp.content = b'{"patched":true}'
|
||||
mock_patch.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("patch", field="val")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.head")
|
||||
def test_head_method(self, mock_head, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_unsupported_method(self, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Unsupported" in result["message"]
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SSRF Validation
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSSRFValidation:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
|
||||
from application.core.url_validation import SSRFError
|
||||
|
||||
mock_validate.side_effect = SSRFError("blocked")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
|
||||
from application.core.url_validation import SSRFError
|
||||
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/{host}/data",
|
||||
"method": "GET",
|
||||
"query_params": {"host": "169.254.169.254"},
|
||||
})
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def side_effect(url):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 2:
|
||||
raise SSRFError("blocked after substitution")
|
||||
|
||||
mock_validate.side_effect = side_effect
|
||||
result = tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "URL validation error" in result["message"]
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Error Handling
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestErrorHandling:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_timeout_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.Timeout()
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "timeout" in result["message"].lower()
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_connection_error(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] is None
|
||||
assert "Connection error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 422
|
||||
mock_resp.json.return_value = {"error": "invalid_field"}
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 422
|
||||
assert result["data"] == {"error": "invalid_field"}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
|
||||
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
||||
response=mock_resp
|
||||
)
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
result = get_tool.execute_action("any")
|
||||
assert result["status_code"] == 404
|
||||
assert result["data"] == "Not Found"
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_request_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("something")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "API call failed" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
|
||||
mock_get.side_effect = RuntimeError("unexpected")
|
||||
result = get_tool.execute_action("any")
|
||||
assert "Unexpected error" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_body_serialization_error(self, mock_post, mock_validate):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
"body_content_type": "application/json",
|
||||
})
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.api_tool.RequestBodySerializer.serialize",
|
||||
side_effect=ValueError("serialize fail"),
|
||||
):
|
||||
result = tool.execute_action("any", key="val")
|
||||
assert "serialization error" in result["message"].lower()
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Path Param Substitution
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "42", "post_id": "7"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_remaining_query_params_appended(self, mock_get, mock_validate):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items",
|
||||
"method": "GET",
|
||||
"query_params": {"page": "2", "limit": "10"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "page=2" in called_url
|
||||
assert "limit=10" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_query_params_append_with_existing_query_string(
|
||||
self, mock_get, mock_validate
|
||||
):
|
||||
tool = APITool(config={
|
||||
"url": "https://api.example.com/items?existing=true",
|
||||
"method": "GET",
|
||||
"query_params": {"page": "1"},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "&page=1" in called_url
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_empty_body_no_serialization(self, mock_post, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "POST"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("create")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Parse Response
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseResponse:
|
||||
|
||||
def test_json_response(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"key": "val"}
|
||||
mock_resp.content = b'{"key":"val"}'
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result == {"key": "val"}
|
||||
|
||||
def test_json_decode_error_falls_back_to_text(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
|
||||
mock_resp.text = "not valid json"
|
||||
mock_resp.content = b"not valid json"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result == "not valid json"
|
||||
|
||||
def test_text_response(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.text = "plain text"
|
||||
mock_resp.content = b"plain text"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result == "plain text"
|
||||
|
||||
def test_xml_response(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/xml"}
|
||||
mock_resp.text = "<root><item>1</item></root>"
|
||||
mock_resp.content = b"<root><item>1</item></root>"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert "<root>" in result
|
||||
|
||||
def test_html_response(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.text = "<html><body>Hi</body></html>"
|
||||
mock_resp.content = b"<html><body>Hi</body></html>"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert "<html>" in result
|
||||
|
||||
def test_empty_content(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.content = b""
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result is None
|
||||
|
||||
def test_binary_response(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/octet-stream"}
|
||||
mock_resp.text = "binary_text"
|
||||
mock_resp.content = b"\x00\x01\x02"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result is not None
|
||||
|
||||
def test_text_xml_content_type(self, get_tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/xml"}
|
||||
mock_resp.text = "<data/>"
|
||||
mock_resp.content = b"<data/>"
|
||||
|
||||
result = get_tool._parse_response(mock_resp)
|
||||
assert result == "<data/>"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Metadata
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAPIToolMetadata:
|
||||
|
||||
def test_actions_metadata_empty(self, get_tool):
|
||||
assert get_tool.get_actions_metadata() == []
|
||||
|
||||
def test_config_requirements_empty(self, get_tool):
|
||||
assert get_tool.get_config_requirements() == {}
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.post")
|
||||
def test_content_type_set_for_post_with_no_headers(
|
||||
self, mock_post, mock_validate
|
||||
):
|
||||
tool = APITool(config={
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
"headers": {},
|
||||
})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.content = b'{}'
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action("create")
|
||||
call_headers = mock_post.call_args[1]["headers"]
|
||||
assert "Content-Type" in call_headers
|
||||
596
tests/agents/tools/test_internal_search.py
Normal file
596
tests/agents/tools/test_internal_search.py
Normal file
@@ -0,0 +1,596 @@
|
||||
"""Comprehensive tests for application/agents/tools/internal_search.py
|
||||
|
||||
Covers: InternalSearchTool (search, list_files, path_filter, error handling,
|
||||
directory structure loading), build helpers, add_internal_search_tool,
|
||||
sources_have_directory_structure.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ENTRY,
|
||||
INTERNAL_TOOL_ID,
|
||||
InternalSearchTool,
|
||||
add_internal_search_tool,
|
||||
build_internal_tool_config,
|
||||
build_internal_tool_entry,
|
||||
sources_have_directory_structure,
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# InternalSearchTool - Search
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def _make_tool(**config_overrides):
|
||||
config = {"source": {}, "retriever_name": "classic", "chunks": 2}
|
||||
config.update(config_overrides)
|
||||
return InternalSearchTool(config)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolSearch:
|
||||
|
||||
def test_search_no_query_returns_error(self):
|
||||
tool = _make_tool()
|
||||
result = tool.execute_action("search", query="")
|
||||
assert "required" in result.lower()
|
||||
|
||||
def test_search_returns_formatted_docs(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{
|
||||
"text": "Hello world",
|
||||
"title": "Doc1",
|
||||
"source": "test",
|
||||
"filename": "doc1.md",
|
||||
},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="hello")
|
||||
assert "doc1.md" in result
|
||||
assert "Hello world" in result
|
||||
assert len(tool.retrieved_docs) == 1
|
||||
|
||||
def test_search_no_results(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = []
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="nonexistent")
|
||||
assert "No documents found" in result
|
||||
|
||||
def test_search_accumulates_docs(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "D1", "source": "s1"},
|
||||
]
|
||||
tool.execute_action("search", query="first")
|
||||
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "B", "title": "D2", "source": "s2"},
|
||||
]
|
||||
tool.execute_action("search", query="second")
|
||||
|
||||
assert len(tool.retrieved_docs) == 2
|
||||
|
||||
def test_search_deduplicates_docs(self):
|
||||
tool = _make_tool()
|
||||
doc = {"text": "Same", "title": "Same", "source": "same"}
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [doc]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
tool.execute_action("search", query="q1")
|
||||
tool.execute_action("search", query="q2")
|
||||
|
||||
assert len(tool.retrieved_docs) == 1
|
||||
|
||||
def test_search_with_path_filter(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "T", "source": "src/main.py", "filename": "main.py"},
|
||||
{"text": "B", "title": "T", "source": "docs/readme.md", "filename": "readme.md"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="code", path_filter="src/")
|
||||
assert "main.py" in result
|
||||
assert "readme.md" not in result
|
||||
|
||||
def test_search_path_filter_matches_title(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "src/main.py", "source": "other", "filename": ""},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="code", path_filter="src/main")
|
||||
assert "src/main.py" in result
|
||||
|
||||
def test_search_path_filter_no_match(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "T", "source": "other/file.txt"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="code", path_filter="src/")
|
||||
assert "No documents found" in result
|
||||
|
||||
def test_search_retriever_error(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.side_effect = Exception("Connection error")
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="test")
|
||||
assert "failed" in result.lower() or "error" in result.lower()
|
||||
|
||||
def test_unknown_action(self):
|
||||
tool = _make_tool()
|
||||
result = tool.execute_action("nonexistent")
|
||||
assert "Unknown action" in result
|
||||
|
||||
def test_search_formats_with_separator(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "D1", "source": "s1", "filename": "f1.md"},
|
||||
{"text": "B", "title": "D2", "source": "s2", "filename": "f2.md"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="test")
|
||||
assert "---" in result
|
||||
assert "[1]" in result
|
||||
assert "[2]" in result
|
||||
|
||||
def test_search_uses_title_when_no_filename(self):
|
||||
tool = _make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "Content", "title": "My Title", "source": "src", "filename": ""},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="q")
|
||||
assert "My Title" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# InternalSearchTool - List Files
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolListFiles:
|
||||
|
||||
def test_list_files_no_structure(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = None
|
||||
|
||||
result = tool.execute_action("list_files")
|
||||
assert "No file structure" in result
|
||||
|
||||
def test_list_files_root(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"src": {"main.py": {}},
|
||||
"README.md": {"type": "md", "token_count": 100},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files")
|
||||
assert "src/" in result
|
||||
assert "README.md" in result
|
||||
|
||||
def test_list_files_nested_path(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"src": {
|
||||
"utils": {"helper.py": {}},
|
||||
},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files", path="src")
|
||||
assert "utils/" in result
|
||||
|
||||
def test_list_files_invalid_path(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {"src": {}}
|
||||
|
||||
result = tool.execute_action("list_files", path="nonexistent")
|
||||
assert "not found" in result
|
||||
|
||||
def test_list_files_empty_directory(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {"empty_dir": {}}
|
||||
|
||||
result = tool.execute_action("list_files", path="empty_dir")
|
||||
assert "(empty)" in result
|
||||
|
||||
def test_list_files_file_with_metadata(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"data.csv": {
|
||||
"type": "text/csv",
|
||||
"size_bytes": 1024,
|
||||
"token_count": 500,
|
||||
},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files")
|
||||
assert "data.csv" in result
|
||||
assert "500 tokens" in result
|
||||
assert "text/csv" in result
|
||||
|
||||
def test_list_files_file_is_not_directory(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"src": {
|
||||
"main.py": "plain_file_value",
|
||||
},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files", path="src/main.py")
|
||||
assert "is a file" in result
|
||||
|
||||
def test_list_files_deep_nested_path_with_slashes(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"a": {"b": {"c": {"file.txt": {"type": "text"}}}},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files", path="a/b/c")
|
||||
assert "file.txt" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Count Files Helper
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCountFiles:
|
||||
|
||||
def test_count_files_nested(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
node = {
|
||||
"file1.txt": {"type": "text"},
|
||||
"dir": {
|
||||
"file2.txt": {"type": "text"},
|
||||
"file3.txt": "plain_value",
|
||||
},
|
||||
}
|
||||
assert tool._count_files(node) == 3
|
||||
|
||||
def test_count_files_empty(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
assert tool._count_files({}) == 0
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Directory Structure Loading
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetDirectoryStructure:
|
||||
|
||||
def test_loads_from_mongo(self):
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": ["507f1f77bcf86cd799439011"]},
|
||||
})
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "test_source",
|
||||
"directory_structure": {"src": {"main.py": {}}},
|
||||
}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
assert result is not None
|
||||
assert "src" in result
|
||||
|
||||
def test_returns_none_without_active_docs(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
result = tool._get_directory_structure()
|
||||
assert result is None
|
||||
assert tool._dir_structure_loaded is True
|
||||
|
||||
def test_caches_after_first_load(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {"cached": True}
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
assert result == {"cached": True}
|
||||
|
||||
def test_handles_json_string_structure(self):
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": ["507f1f77bcf86cd799439011"]},
|
||||
})
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "test_source",
|
||||
"directory_structure": json.dumps({"src": {"app.py": {}}}),
|
||||
}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
assert result is not None
|
||||
assert "src" in result
|
||||
|
||||
def test_handles_string_active_docs(self):
|
||||
tool = InternalSearchTool({
|
||||
"source": {"active_docs": "507f1f77bcf86cd799439011"},
|
||||
})
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"directory_structure": {"dir": {}},
|
||||
}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
assert result is not None
|
||||
|
||||
def test_merges_multiple_sources(self):
|
||||
tool = InternalSearchTool({
|
||||
"source": {
|
||||
"active_docs": [
|
||||
"507f1f77bcf86cd799439011",
|
||||
"507f1f77bcf86cd799439012",
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.side_effect = [
|
||||
{"name": "src1", "directory_structure": {"a": {}}},
|
||||
{"name": "src2", "directory_structure": {"b": {}}},
|
||||
]
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = tool._get_directory_structure()
|
||||
assert "src1" in result
|
||||
assert "src2" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Metadata
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolMetadata:
|
||||
|
||||
def test_actions_without_directory_structure(self):
|
||||
tool = InternalSearchTool({"has_directory_structure": False})
|
||||
meta = tool.get_actions_metadata()
|
||||
|
||||
action_names = [a["name"] for a in meta]
|
||||
assert "search" in action_names
|
||||
assert "list_files" not in action_names
|
||||
|
||||
search = meta[0]
|
||||
assert "path_filter" not in search["parameters"]["properties"]
|
||||
|
||||
def test_actions_with_directory_structure(self):
|
||||
tool = InternalSearchTool({"has_directory_structure": True})
|
||||
meta = tool.get_actions_metadata()
|
||||
|
||||
action_names = [a["name"] for a in meta]
|
||||
assert "search" in action_names
|
||||
assert "list_files" in action_names
|
||||
|
||||
search = next(a for a in meta if a["name"] == "search")
|
||||
assert "path_filter" in search["parameters"]["properties"]
|
||||
|
||||
def test_config_requirements_empty(self):
|
||||
tool = InternalSearchTool({})
|
||||
assert tool.get_config_requirements() == {}
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Build Helpers
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildHelpers:
|
||||
|
||||
def test_build_entry_without_directory_structure(self):
|
||||
entry = build_internal_tool_entry(has_directory_structure=False)
|
||||
assert entry["name"] == "internal_search"
|
||||
action_names = [a["name"] for a in entry["actions"]]
|
||||
assert "search" in action_names
|
||||
assert "list_files" not in action_names
|
||||
assert entry["actions"][0].get("active") is True
|
||||
|
||||
def test_build_entry_with_directory_structure(self):
|
||||
entry = build_internal_tool_entry(has_directory_structure=True)
|
||||
action_names = [a["name"] for a in entry["actions"]]
|
||||
assert "list_files" in action_names
|
||||
# path_filter should be in search params
|
||||
search_action = next(a for a in entry["actions"] if a["name"] == "search")
|
||||
assert "path_filter" in search_action["parameters"]["properties"]
|
||||
|
||||
def test_build_config(self):
|
||||
config = build_internal_tool_config(
|
||||
source={"active_docs": ["abc"]},
|
||||
retriever_name="semantic",
|
||||
chunks=4,
|
||||
)
|
||||
assert config["source"] == {"active_docs": ["abc"]}
|
||||
assert config["retriever_name"] == "semantic"
|
||||
assert config["chunks"] == 4
|
||||
|
||||
def test_build_config_defaults(self):
|
||||
config = build_internal_tool_config(source={"active_docs": ["abc"]})
|
||||
assert config["retriever_name"] == "classic"
|
||||
assert config["chunks"] == 2
|
||||
assert config["doc_token_limit"] == 50000
|
||||
|
||||
def test_internal_tool_id(self):
|
||||
assert INTERNAL_TOOL_ID == "internal"
|
||||
|
||||
def test_internal_tool_entry_constant(self):
|
||||
assert INTERNAL_TOOL_ENTRY["name"] == "internal_search"
|
||||
|
||||
def test_add_internal_search_tool_with_sources(self):
|
||||
tools_dict = {}
|
||||
retriever_config = {
|
||||
"source": {"active_docs": ["abc"]},
|
||||
"retriever_name": "classic",
|
||||
"chunks": 2,
|
||||
"model_id": "gpt-4",
|
||||
"llm_name": "openai",
|
||||
"api_key": "key",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.internal_search.sources_have_directory_structure",
|
||||
return_value=False,
|
||||
):
|
||||
add_internal_search_tool(tools_dict, retriever_config)
|
||||
|
||||
assert INTERNAL_TOOL_ID in tools_dict
|
||||
assert tools_dict[INTERNAL_TOOL_ID]["name"] == "internal_search"
|
||||
assert "config" in tools_dict[INTERNAL_TOOL_ID]
|
||||
|
||||
def test_add_internal_search_tool_no_sources(self):
|
||||
tools_dict = {}
|
||||
retriever_config = {"source": {}}
|
||||
|
||||
add_internal_search_tool(tools_dict, retriever_config)
|
||||
assert INTERNAL_TOOL_ID not in tools_dict
|
||||
|
||||
def test_add_internal_search_tool_empty_config(self):
|
||||
tools_dict = {}
|
||||
add_internal_search_tool(tools_dict, {})
|
||||
assert INTERNAL_TOOL_ID not in tools_dict
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# sources_have_directory_structure
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSourcesHaveDirectoryStructure:
|
||||
|
||||
def test_no_active_docs(self):
|
||||
assert sources_have_directory_structure({}) is False
|
||||
|
||||
def test_with_directory_structure(self):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"directory_structure": {"src": {}},
|
||||
}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = sources_have_directory_structure(
|
||||
{"active_docs": ["507f1f77bcf86cd799439011"]}
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_without_directory_structure(self):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {"directory_structure": None}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = sources_have_directory_structure(
|
||||
{"active_docs": ["507f1f77bcf86cd799439011"]}
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_handles_exception_gracefully(self):
|
||||
with patch(
|
||||
"application.core.mongo_db.MongoDB.get_client",
|
||||
side_effect=Exception("DB down"),
|
||||
):
|
||||
result = sources_have_directory_structure(
|
||||
{"active_docs": ["507f1f77bcf86cd799439011"]}
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_string_active_docs(self):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"directory_structure": {"a": {}},
|
||||
}
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_mongo.get_client.return_value = mock_client
|
||||
|
||||
result = sources_have_directory_structure(
|
||||
{"active_docs": "507f1f77bcf86cd799439011"}
|
||||
)
|
||||
assert result is True
|
||||
2414
tests/agents/tools/test_mcp_tool.py
Normal file
2414
tests/agents/tools/test_mcp_tool.py
Normal file
File diff suppressed because it is too large
Load Diff
548
tests/agents/tools/test_memory.py
Normal file
548
tests/agents/tools/test_memory.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""Comprehensive tests for application/agents/tools/memory.py
|
||||
|
||||
Covers: MemoryTool initialization, path validation, all actions
|
||||
(view, create, str_replace, insert, delete, rename), directory operations,
|
||||
error handling, and metadata.
|
||||
"""
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
|
||||
|
||||
def _get_settings():
|
||||
from application.core.settings import settings
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory_db(monkeypatch):
|
||||
"""Set up a mongomock-based memory collection."""
|
||||
settings = _get_settings()
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_db = mock_client[settings.MONGO_DB_NAME]
|
||||
|
||||
def get_mock_client():
|
||||
return {settings.MONGO_DB_NAME: mock_db}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", get_mock_client
|
||||
)
|
||||
return mock_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_tool(mock_memory_db):
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
return MemoryTool(
|
||||
tool_config={"tool_id": "test_tool_001"},
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Initialization
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolInit:
|
||||
|
||||
def test_init_with_config(self, mock_memory_db):
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
tool = MemoryTool(
|
||||
tool_config={"tool_id": "custom_id"}, user_id="user1"
|
||||
)
|
||||
assert tool.tool_id == "custom_id"
|
||||
assert tool.user_id == "user1"
|
||||
|
||||
def test_init_fallback_to_user_id(self, mock_memory_db):
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
tool = MemoryTool(tool_config={}, user_id="user1")
|
||||
assert tool.tool_id == "default_user1"
|
||||
|
||||
def test_init_no_user_no_config(self, mock_memory_db):
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
tool = MemoryTool()
|
||||
assert tool.tool_id is not None # UUID fallback
|
||||
assert tool.user_id is None
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Path Validation
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathValidation:
|
||||
|
||||
def test_valid_path(self, memory_tool):
|
||||
assert memory_tool._validate_path("/notes.txt") == "/notes.txt"
|
||||
|
||||
def test_adds_leading_slash(self, memory_tool):
|
||||
assert memory_tool._validate_path("notes.txt") == "/notes.txt"
|
||||
|
||||
def test_empty_path_returns_none(self, memory_tool):
|
||||
assert memory_tool._validate_path("") is None
|
||||
|
||||
def test_double_dots_rejected(self, memory_tool):
|
||||
assert memory_tool._validate_path("/../../etc/passwd") is None
|
||||
|
||||
def test_double_slash_rejected(self, memory_tool):
|
||||
assert memory_tool._validate_path("//path") is None
|
||||
|
||||
def test_preserves_trailing_slash(self, memory_tool):
|
||||
result = memory_tool._validate_path("/project/")
|
||||
assert result.endswith("/")
|
||||
|
||||
def test_root_path(self, memory_tool):
|
||||
assert memory_tool._validate_path("/") == "/"
|
||||
|
||||
def test_whitespace_stripped(self, memory_tool):
|
||||
result = memory_tool._validate_path(" /notes.txt ")
|
||||
assert result == "/notes.txt"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Execute Action - No User
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNoUser:
|
||||
|
||||
def test_requires_user_id(self, mock_memory_db):
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
|
||||
tool = MemoryTool(tool_config={"tool_id": "t"}, user_id=None)
|
||||
result = tool.execute_action("view", path="/")
|
||||
assert "Error" in result
|
||||
assert "user_id" in result
|
||||
|
||||
def test_unknown_action(self, memory_tool):
|
||||
result = memory_tool.execute_action("fly")
|
||||
assert "Unknown action" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# View Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestViewAction:
|
||||
|
||||
def test_view_empty_directory(self, memory_tool):
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "Directory: /" in result
|
||||
assert "(empty)" in result
|
||||
|
||||
def test_view_directory_with_files(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/notes.txt", file_text="content")
|
||||
memory_tool.execute_action("create", path="/todo.txt", file_text="tasks")
|
||||
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "notes.txt" in result
|
||||
assert "todo.txt" in result
|
||||
|
||||
def test_view_file_content(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/hello.txt", file_text="Hello World")
|
||||
result = memory_tool.execute_action("view", path="/hello.txt")
|
||||
assert "Hello World" in result
|
||||
|
||||
def test_view_nonexistent_file(self, memory_tool):
|
||||
result = memory_tool.execute_action("view", path="/missing.txt")
|
||||
assert "Error" in result
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_view_file_with_range(self, memory_tool):
|
||||
memory_tool.execute_action(
|
||||
"create", path="/lines.txt", file_text="line1\nline2\nline3\nline4"
|
||||
)
|
||||
result = memory_tool.execute_action(
|
||||
"view", path="/lines.txt", view_range=[2, 3]
|
||||
)
|
||||
assert "line2" in result
|
||||
assert "line3" in result
|
||||
|
||||
def test_view_file_range_out_of_bounds(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/short.txt", file_text="only")
|
||||
result = memory_tool.execute_action(
|
||||
"view", path="/short.txt", view_range=[100, 200]
|
||||
)
|
||||
assert "out of bounds" in result.lower()
|
||||
|
||||
def test_view_invalid_path(self, memory_tool):
|
||||
result = memory_tool.execute_action("view", path="")
|
||||
assert "Error" in result
|
||||
|
||||
def test_view_subdirectory(self, memory_tool):
|
||||
memory_tool.execute_action(
|
||||
"create", path="/project/src/main.py", file_text="code"
|
||||
)
|
||||
result = memory_tool.execute_action("view", path="/project/")
|
||||
assert "src/main.py" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Create Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCreateAction:
|
||||
|
||||
def test_create_file(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"create", path="/test.txt", file_text="content"
|
||||
)
|
||||
assert "File created" in result
|
||||
|
||||
content = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "content" in content
|
||||
|
||||
def test_overwrite_file(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="old")
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="new")
|
||||
|
||||
content = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "new" in content
|
||||
|
||||
def test_create_at_directory_path(self, memory_tool):
|
||||
result = memory_tool.execute_action("create", path="/dir/", file_text="text")
|
||||
assert "Error" in result
|
||||
assert "directory path" in result.lower()
|
||||
|
||||
def test_create_invalid_path(self, memory_tool):
|
||||
result = memory_tool.execute_action("create", path="", file_text="text")
|
||||
assert "Error" in result
|
||||
|
||||
def test_create_nested_path(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"create", path="/a/b/c/file.txt", file_text="deep"
|
||||
)
|
||||
assert "File created" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# String Replace Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStrReplaceAction:
|
||||
|
||||
def test_replace_text(self, memory_tool):
|
||||
memory_tool.execute_action(
|
||||
"create", path="/doc.txt", file_text="Hello World"
|
||||
)
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace", path="/doc.txt", old_str="Hello", new_str="Hi"
|
||||
)
|
||||
assert "File updated" in result
|
||||
|
||||
content = memory_tool.execute_action("view", path="/doc.txt")
|
||||
assert "Hi World" in content
|
||||
|
||||
def test_replace_not_found(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/doc.txt", file_text="Hello")
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace", path="/doc.txt", old_str="Missing", new_str="X"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_replace_empty_old_str(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/doc.txt", file_text="Hello")
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace", path="/doc.txt", old_str="", new_str="X"
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
def test_replace_file_not_found(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace", path="/missing.txt", old_str="a", new_str="b"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_replace_case_insensitive(self, memory_tool):
|
||||
memory_tool.execute_action(
|
||||
"create", path="/doc.txt", file_text="Hello World"
|
||||
)
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace", path="/doc.txt", old_str="hello", new_str="Hi"
|
||||
)
|
||||
assert "File updated" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Insert Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInsertAction:
|
||||
|
||||
def test_insert_text(self, memory_tool):
|
||||
memory_tool.execute_action(
|
||||
"create", path="/doc.txt", file_text="line1\nline2"
|
||||
)
|
||||
result = memory_tool.execute_action(
|
||||
"insert", path="/doc.txt", insert_line=2, insert_text="inserted"
|
||||
)
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
content = memory_tool.execute_action("view", path="/doc.txt")
|
||||
assert "inserted" in content
|
||||
|
||||
def test_insert_empty_text(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/doc.txt", file_text="line1")
|
||||
result = memory_tool.execute_action(
|
||||
"insert", path="/doc.txt", insert_line=1, insert_text=""
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
def test_insert_file_not_found(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"insert", path="/missing.txt", insert_line=1, insert_text="text"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_insert_invalid_line_number(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/doc.txt", file_text="line1")
|
||||
result = memory_tool.execute_action(
|
||||
"insert", path="/doc.txt", insert_line=-5, insert_text="text"
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Delete Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteAction:
|
||||
|
||||
def test_delete_file(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/test.txt", file_text="data")
|
||||
result = memory_tool.execute_action("delete", path="/test.txt")
|
||||
assert "Deleted" in result
|
||||
|
||||
content = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "not found" in content.lower()
|
||||
|
||||
def test_delete_nonexistent_file(self, memory_tool):
|
||||
result = memory_tool.execute_action("delete", path="/missing.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_delete_root_clears_all(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/a.txt", file_text="a")
|
||||
memory_tool.execute_action("create", path="/b.txt", file_text="b")
|
||||
|
||||
result = memory_tool.execute_action("delete", path="/")
|
||||
assert "Deleted" in result
|
||||
assert "2" in result
|
||||
|
||||
def test_delete_directory(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/dir/f1.txt", file_text="1")
|
||||
memory_tool.execute_action("create", path="/dir/f2.txt", file_text="2")
|
||||
|
||||
result = memory_tool.execute_action("delete", path="/dir/")
|
||||
assert "Deleted" in result
|
||||
|
||||
def test_delete_directory_without_trailing_slash(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/dir/f1.txt", file_text="1")
|
||||
|
||||
result = memory_tool.execute_action("delete", path="/dir")
|
||||
assert "Deleted" in result
|
||||
|
||||
def test_delete_invalid_path(self, memory_tool):
|
||||
result = memory_tool.execute_action("delete", path="")
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Rename Action
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRenameAction:
|
||||
|
||||
def test_rename_file(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/old.txt", file_text="data")
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/old.txt", new_path="/new.txt"
|
||||
)
|
||||
assert "Renamed" in result
|
||||
|
||||
content = memory_tool.execute_action("view", path="/new.txt")
|
||||
assert "data" in content
|
||||
|
||||
def test_rename_file_not_found(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/missing.txt", new_path="/new.txt"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_rename_target_exists(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/a.txt", file_text="a")
|
||||
memory_tool.execute_action("create", path="/b.txt", file_text="b")
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/a.txt", new_path="/b.txt"
|
||||
)
|
||||
assert "already exists" in result.lower()
|
||||
|
||||
def test_rename_root_rejected(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/", new_path="/new/"
|
||||
)
|
||||
assert "Cannot rename root" in result
|
||||
|
||||
def test_rename_directory(self, memory_tool):
|
||||
memory_tool.execute_action("create", path="/old/f.txt", file_text="data")
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/old/", new_path="/new/"
|
||||
)
|
||||
assert "Renamed" in result
|
||||
|
||||
content = memory_tool.execute_action("view", path="/new/f.txt")
|
||||
assert "data" in content
|
||||
|
||||
def test_rename_directory_not_found(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="/missing/", new_path="/new/"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
def test_rename_invalid_path(self, memory_tool):
|
||||
result = memory_tool.execute_action(
|
||||
"rename", old_path="", new_path="/new.txt"
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Metadata
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolMetadata:
|
||||
|
||||
def test_actions_metadata(self, memory_tool):
|
||||
meta = memory_tool.get_actions_metadata()
|
||||
action_names = [a["name"] for a in meta]
|
||||
assert "view" in action_names
|
||||
assert "create" in action_names
|
||||
assert "str_replace" in action_names
|
||||
assert "insert" in action_names
|
||||
assert "delete" in action_names
|
||||
assert "rename" in action_names
|
||||
assert len(meta) == 6
|
||||
|
||||
def test_config_requirements(self, memory_tool):
|
||||
assert memory_tool.get_config_requirements() == {}
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 254, 257, 271, 275, 279)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolValidatePath:
|
||||
|
||||
def test_validate_path_with_traversal_returns_none(self, memory_tool):
|
||||
"""Cover line 244-245: path with .. returns None."""
|
||||
result = memory_tool._validate_path("/some/../etc/passwd")
|
||||
assert result is None
|
||||
|
||||
def test_validate_path_with_directory_trailing_slash(self, memory_tool):
|
||||
"""Cover line 257-258: trailing slash is preserved."""
|
||||
result = memory_tool._validate_path("/some/dir/")
|
||||
assert result is not None
|
||||
assert result.endswith("/")
|
||||
|
||||
def test_validate_path_empty_returns_none(self, memory_tool):
|
||||
"""Cover: empty path returns None."""
|
||||
result = memory_tool._validate_path("")
|
||||
assert result is None
|
||||
|
||||
def test_validate_path_none_returns_none(self, memory_tool):
|
||||
"""Cover: None path returns None."""
|
||||
result = memory_tool._validate_path(None)
|
||||
assert result is None
|
||||
|
||||
def test_validate_path_relative_gets_prefixed(self, memory_tool):
|
||||
"""Cover line 241: relative path gets / prepended."""
|
||||
result = memory_tool._validate_path("relative/path")
|
||||
assert result == "/relative/path"
|
||||
|
||||
def test_validate_path_double_slash_returns_none(self, memory_tool):
|
||||
"""Cover line 244: double slash returns None."""
|
||||
result = memory_tool._validate_path("//etc/passwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolViewDirectory:
|
||||
|
||||
def test_view_with_directory_path(self, memory_tool):
|
||||
"""Cover line 271-275: _view with directory path delegates to _view_directory."""
|
||||
result = memory_tool._view("/")
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_view_with_file_path(self, memory_tool):
|
||||
"""Cover line 279: _view with non-directory path delegates to _view_file."""
|
||||
# _view on a non-existent file path still exercises the _view_file path
|
||||
result = memory_tool._view("/nonexistent.txt")
|
||||
assert "Error" in result or "not found" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 254, 257, 271, 275, 279
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolValidatePathCoverage:
|
||||
|
||||
def test_validate_path_traversal_returns_none(self, memory_tool):
|
||||
"""Cover line 244: path with directory traversal returns None."""
|
||||
result = memory_tool._validate_path("/etc/../passwd")
|
||||
assert result is None
|
||||
|
||||
def test_validate_path_directory_appends_slash(self, memory_tool):
|
||||
"""Cover line 257: path ending with / preserves trailing slash."""
|
||||
result = memory_tool._validate_path("/some/dir/")
|
||||
assert result is not None
|
||||
assert result.endswith("/")
|
||||
|
||||
def test_validate_path_root_directory(self, memory_tool):
|
||||
"""Cover line 257: root directory preserved as-is."""
|
||||
result = memory_tool._validate_path("/")
|
||||
assert result == "/"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMemoryToolViewCoverage:
|
||||
|
||||
def test_view_invalid_path_returns_error(self, memory_tool):
|
||||
"""Cover line 271: _view with invalid path returns error."""
|
||||
result = memory_tool._view("//bad//path")
|
||||
assert "Error" in result
|
||||
|
||||
def test_view_root_directory(self, memory_tool):
|
||||
"""Cover line 275: _view with root directory."""
|
||||
result = memory_tool._view("/")
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_view_file_path(self, memory_tool):
|
||||
"""Cover line 279: _view with file path delegates to _view_file."""
|
||||
result = memory_tool._view("/some/file.txt")
|
||||
assert isinstance(result, str)
|
||||
235
tests/api/answer/routes/test_answer.py
Normal file
235
tests/api/answer/routes/test_answer.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Tests for application/api/answer/routes/answer.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_processor():
|
||||
"""Create a mock StreamProcessor."""
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.StreamProcessor"
|
||||
) as MockProcessor:
|
||||
processor = MagicMock()
|
||||
processor.decoded_token = {"sub": "test_user"}
|
||||
processor.conversation_id = str(ObjectId())
|
||||
processor.agent_config = {}
|
||||
processor.agent_id = str(ObjectId())
|
||||
processor.is_shared_usage = False
|
||||
processor.shared_token = None
|
||||
processor.model_id = "gpt-4"
|
||||
processor.build_agent.return_value = MagicMock()
|
||||
MockProcessor.return_value = processor
|
||||
yield processor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def answer_client(mock_mongo_db, flask_app):
|
||||
"""Create a test client with the answer route registered."""
|
||||
from flask_restx import Api
|
||||
|
||||
from application.api.answer.routes.answer import answer_ns
|
||||
|
||||
api = Api(flask_app)
|
||||
api.add_namespace(answer_ns)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app.test_client()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAnswerResourcePost:
|
||||
def test_missing_question_returns_400(self, answer_client, mock_stream_processor):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_successful_answer(self, answer_client, mock_stream_processor):
|
||||
conv_id = str(ObjectId())
|
||||
with patch.object(
|
||||
mock_stream_processor.build_agent.return_value,
|
||||
"gen",
|
||||
return_value=iter([]),
|
||||
):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter(
|
||||
[
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello"})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(conv_id, "Hello", [], [], "", None),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "What is Python?"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["answer"] == "Hello"
|
||||
assert data["conversation_id"] == conv_id
|
||||
|
||||
def test_unauthorized_returns_401(self, answer_client, mock_stream_processor):
|
||||
mock_stream_processor.decoded_token = None
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert resp.get_json()["error"] == "Unauthorized"
|
||||
|
||||
def test_usage_exceeded_returns_error(self, answer_client, mock_stream_processor):
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
) as mock_check:
|
||||
with flask_app_context(answer_client):
|
||||
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
|
||||
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_stream_error_returns_400(self, answer_client, mock_stream_processor):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(None, None, None, None, None, "Stream error"),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.get_json()["error"] == "Stream error"
|
||||
|
||||
def test_exception_returns_500(self, answer_client, mock_stream_processor):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
side_effect=RuntimeError("unexpected"),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
assert "error" in resp.get_json()
|
||||
|
||||
def test_structured_info_merged_into_result(
|
||||
self, answer_client, mock_stream_processor
|
||||
):
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
'{"key": "val"}',
|
||||
[],
|
||||
[],
|
||||
"",
|
||||
None,
|
||||
{"structured": True, "schema": {"type": "object"}},
|
||||
),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["structured"] is True
|
||||
assert data["schema"] == {"type": "object"}
|
||||
|
||||
def test_result_contains_all_expected_fields(
|
||||
self, answer_client, mock_stream_processor
|
||||
):
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
"answer text",
|
||||
[{"title": "src"}],
|
||||
[{"tool": "t"}],
|
||||
"thinking...",
|
||||
None,
|
||||
),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.get_json()
|
||||
assert data["conversation_id"] == conv_id
|
||||
assert data["answer"] == "answer text"
|
||||
assert data["sources"] == [{"title": "src"}]
|
||||
assert data["tool_calls"] == [{"tool": "t"}]
|
||||
assert data["thought"] == "thinking..."
|
||||
|
||||
|
||||
def flask_app_context(client):
|
||||
"""Helper to get app context from test client."""
|
||||
return client.application.app_context()
|
||||
195
tests/api/answer/routes/test_stream.py
Normal file
195
tests/api/answer/routes/test_stream.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for application/api/answer/routes/stream.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_processor():
|
||||
"""Create a mock StreamProcessor for stream tests."""
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamProcessor"
|
||||
) as MockProcessor:
|
||||
processor = MagicMock()
|
||||
processor.decoded_token = {"sub": "test_user"}
|
||||
processor.conversation_id = str(ObjectId())
|
||||
processor.agent_config = {}
|
||||
processor.agent_id = str(ObjectId())
|
||||
processor.is_shared_usage = False
|
||||
processor.shared_token = None
|
||||
processor.model_id = "gpt-4"
|
||||
processor.build_agent.return_value = MagicMock()
|
||||
MockProcessor.return_value = processor
|
||||
yield processor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_client(mock_mongo_db, flask_app):
|
||||
"""Create a test client with the stream route registered."""
|
||||
from flask_restx import Api
|
||||
|
||||
from application.api.answer.routes.stream import answer_ns
|
||||
|
||||
api = Api(flask_app)
|
||||
api.add_namespace(answer_ns)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app.test_client()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamResourcePost:
|
||||
def test_missing_question_returns_400(self, stream_client, mock_stream_processor):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_successful_stream(self, stream_client, mock_stream_processor):
|
||||
def fake_stream(*args, **kwargs):
|
||||
yield f'data: {json.dumps({"type": "answer", "answer": "Hi"})}\n\n'
|
||||
yield f'data: {json.dumps({"type": "end"})}\n\n'
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.complete_stream",
|
||||
side_effect=fake_stream,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "What is Python?"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert '"type": "answer"' in data
|
||||
assert '"answer": "Hi"' in data
|
||||
|
||||
def test_unauthorized_returns_401_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.decoded_token = None
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Unauthorized" in data
|
||||
|
||||
def test_usage_exceeded_returns_error(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
) as mock_check:
|
||||
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_value_error_returns_400_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.build_agent.side_effect = ValueError("bad data")
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Malformed request body" in data
|
||||
|
||||
def test_general_exception_returns_400_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.build_agent.side_effect = RuntimeError("crash")
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Unknown error occurred" in data
|
||||
|
||||
def test_index_in_data_requires_conversation_id(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
"""When 'index' is present, validate_request is called with require_conversation_id=True."""
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test", "index": 0}),
|
||||
content_type="application/json",
|
||||
)
|
||||
# Should get 400 since conversation_id is missing
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_stream_passes_attachments_and_index(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
"""Verify attachments and index params are forwarded to complete_stream."""
|
||||
|
||||
def fake_stream(*args, **kwargs):
|
||||
yield f'data: {json.dumps({"type": "end"})}\n\n'
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.complete_stream",
|
||||
side_effect=fake_stream,
|
||||
) as mock_complete:
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps(
|
||||
{
|
||||
"question": "test",
|
||||
"conversation_id": conv_id,
|
||||
"index": 3,
|
||||
"attachments": ["att1", "att2"],
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
call_kwargs = mock_complete.call_args
|
||||
assert call_kwargs.kwargs.get("index") == 3
|
||||
assert call_kwargs.kwargs.get("attachment_ids") == ["att1", "att2"]
|
||||
0
tests/api/answer/services/compression/__init__.py
Normal file
0
tests/api/answer/services/compression/__init__.py
Normal file
303
tests/api/answer/services/compression/test_message_builder.py
Normal file
303
tests/api/answer/services/compression/test_message_builder.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Tests for application/api/answer/services/compression/message_builder.py"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.message_builder import MessageBuilder
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildFromCompressedContext:
|
||||
def test_no_compression_returns_system_only(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="You are helpful.",
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "system"
|
||||
assert messages[0]["content"] == "You are helpful."
|
||||
|
||||
def test_with_recent_queries_no_compression(self):
|
||||
queries = [
|
||||
{"prompt": "Hello", "response": "Hi there!"},
|
||||
{"prompt": "How are you?", "response": "I'm fine."},
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System prompt",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
)
|
||||
# system + 2 * (user + assistant) = 5
|
||||
assert len(messages) == 5
|
||||
assert messages[1] == {"role": "user", "content": "Hello"}
|
||||
assert messages[2] == {"role": "assistant", "content": "Hi there!"}
|
||||
assert messages[3] == {"role": "user", "content": "How are you?"}
|
||||
assert messages[4] == {"role": "assistant", "content": "I'm fine."}
|
||||
|
||||
def test_with_compressed_summary_appended_to_system(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="You are helpful.",
|
||||
compressed_summary="Previous: user asked about Python.",
|
||||
recent_queries=[{"prompt": "More?", "response": "Sure."}],
|
||||
)
|
||||
system_content = messages[0]["content"]
|
||||
assert "This session is being continued" in system_content
|
||||
assert "Previous: user asked about Python." in system_content
|
||||
|
||||
def test_mid_execution_context_type(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary="Summary here",
|
||||
recent_queries=[{"prompt": "q", "response": "r"}],
|
||||
context_type="mid_execution",
|
||||
)
|
||||
system_content = messages[0]["content"]
|
||||
assert "Context window limit reached" in system_content
|
||||
|
||||
def test_include_tool_calls(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "Search for X",
|
||||
"response": "Found X",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "call-1",
|
||||
"action_name": "search",
|
||||
"arguments": {"q": "X"},
|
||||
"result": "X found",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
# system + user + assistant + tool_call_assistant + tool_response = 5
|
||||
assert len(messages) == 5
|
||||
assert messages[3]["role"] == "assistant"
|
||||
assert "function_call" in messages[3]["content"][0]
|
||||
assert messages[4]["role"] == "tool"
|
||||
assert "function_response" in messages[4]["content"][0]
|
||||
|
||||
def test_tool_calls_not_included_by_default(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "Search",
|
||||
"response": "Found",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"action_name": "search",
|
||||
"arguments": {},
|
||||
"result": "ok",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
# system + user + assistant = 3 (no tool messages)
|
||||
assert len(messages) == 3
|
||||
|
||||
def test_tool_call_without_call_id_generates_uuid(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"action_name": "act",
|
||||
"arguments": {},
|
||||
"result": "res",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="S",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
tool_msg = messages[3]["content"][0]
|
||||
call_id = tool_msg["function_call"]["call_id"]
|
||||
assert call_id is not None
|
||||
assert len(call_id) > 0
|
||||
|
||||
def test_continuation_message_when_no_recent_queries_but_has_summary(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary="Everything was compressed",
|
||||
recent_queries=[],
|
||||
)
|
||||
# system + continuation user message = 2
|
||||
assert len(messages) == 2
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "continue" in messages[1]["content"].lower()
|
||||
|
||||
def test_no_continuation_when_no_summary(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(messages) == 1
|
||||
|
||||
def test_queries_without_prompt_or_response_skipped(self):
|
||||
queries = [
|
||||
{"other_field": "value"},
|
||||
{"prompt": "real", "response": "answer"},
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="S",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
)
|
||||
# system + 1 valid query (user + assistant) = 3
|
||||
assert len(messages) == 3
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAppendCompressionContext:
|
||||
def test_pre_request_context(self):
|
||||
result = MessageBuilder._append_compression_context(
|
||||
"Original prompt", "Summary text", "pre_request"
|
||||
)
|
||||
assert "This session is being continued" in result
|
||||
assert "Summary text" in result
|
||||
assert result.startswith("Original prompt")
|
||||
|
||||
def test_mid_execution_context(self):
|
||||
result = MessageBuilder._append_compression_context(
|
||||
"Original prompt", "Summary text", "mid_execution"
|
||||
)
|
||||
assert "Context window limit reached" in result
|
||||
assert "Summary text" in result
|
||||
|
||||
def test_removes_existing_compression_context(self):
|
||||
prompt_with_existing = (
|
||||
"Original prompt\n\n---\n\nThis session is being continued from old"
|
||||
)
|
||||
result = MessageBuilder._append_compression_context(
|
||||
prompt_with_existing, "New summary", "pre_request"
|
||||
)
|
||||
# Should not contain old context twice
|
||||
assert result.count("This session is being continued") == 1
|
||||
assert "New summary" in result
|
||||
|
||||
def test_removes_mid_execution_context(self):
|
||||
prompt_with_existing = (
|
||||
"Original\n\n---\n\nContext window limit reached during execution. Old."
|
||||
)
|
||||
result = MessageBuilder._append_compression_context(
|
||||
prompt_with_existing, "New", "mid_execution"
|
||||
)
|
||||
assert result.count("Context window limit reached") == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRebuildMessagesAfterCompression:
|
||||
def test_basic_rebuild(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "old message"},
|
||||
{"role": "assistant", "content": "old reply"},
|
||||
]
|
||||
recent = [{"prompt": "new q", "response": "new r"}]
|
||||
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="Everything was compressed.",
|
||||
recent_queries=recent,
|
||||
)
|
||||
assert result is not None
|
||||
# system + user + assistant = 3
|
||||
assert len(result) == 3
|
||||
assert "Context window limit reached" in result[0]["content"]
|
||||
assert result[1] == {"role": "user", "content": "new q"}
|
||||
assert result[2] == {"role": "assistant", "content": "new r"}
|
||||
|
||||
def test_returns_none_without_system_message(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="summary",
|
||||
recent_queries=[],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_no_summary_keeps_system_unchanged(self):
|
||||
messages = [{"role": "system", "content": "Be helpful."}]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert result is not None
|
||||
assert result[0]["content"] == "Be helpful."
|
||||
|
||||
def test_include_tool_calls_in_rebuild(self):
|
||||
messages = [{"role": "system", "content": "S"}]
|
||||
recent = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"action_name": "act",
|
||||
"arguments": {"a": 1},
|
||||
"result": "done",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="s",
|
||||
recent_queries=recent,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
# system + user + assistant + tool_call + tool_response = 5
|
||||
assert len(result) == 5
|
||||
|
||||
def test_continuation_added_when_no_recent_queries(self):
|
||||
messages = [{"role": "system", "content": "S"}]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="All compressed",
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(result) == 2
|
||||
assert result[1]["role"] == "user"
|
||||
assert "continue" in result[1]["content"].lower()
|
||||
|
||||
def test_include_current_execution_preserves_extra_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "r1"},
|
||||
{"role": "user", "content": "current execution msg"},
|
||||
]
|
||||
recent = [{"prompt": "q1", "response": "r1"}]
|
||||
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="summary",
|
||||
recent_queries=recent,
|
||||
include_current_execution=True,
|
||||
)
|
||||
assert result is not None
|
||||
# Should include the current execution message
|
||||
contents = [m.get("content") for m in result]
|
||||
assert "current execution msg" in contents
|
||||
447
tests/api/answer/services/compression/test_orchestrator.py
Normal file
447
tests/api/answer/services/compression/test_orchestrator.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Tests for application/api/answer/services/compression/orchestrator.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
CompressionResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_service():
|
||||
svc = MagicMock()
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_threshold_checker():
|
||||
checker = MagicMock()
|
||||
return checker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(mock_conversation_service, mock_threshold_checker):
|
||||
return CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"queries": [
|
||||
{"prompt": "q0", "response": "r0"},
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {},
|
||||
"agent_id": "agent-1",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def decoded_token():
|
||||
return {"sub": "user123"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressIfNeeded:
|
||||
def test_conversation_not_found_returns_failure(
|
||||
self, orchestrator, mock_conversation_service
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = None
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_no_compression_needed(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = False
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is False
|
||||
assert len(result.recent_queries) == 3
|
||||
|
||||
def test_compression_performed_successfully(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = True
|
||||
|
||||
mock_metadata = MagicMock(spec=CompressionMetadata)
|
||||
mock_metadata.compression_ratio = 5.0
|
||||
mock_metadata.original_token_count = 1000
|
||||
mock_metadata.compressed_token_count = 200
|
||||
mock_metadata.to_dict.return_value = {"query_index": 2}
|
||||
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_with_compression(
|
||||
"compressed summary",
|
||||
[{"prompt": "q2", "response": "r2"}],
|
||||
mock_metadata,
|
||||
)
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is True
|
||||
assert result.compressed_summary == "compressed summary"
|
||||
mock_perform.assert_called_once()
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
):
|
||||
mock_conversation_service.get_conversation.side_effect = RuntimeError("DB down")
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "DB down" in result.error
|
||||
|
||||
def test_custom_query_tokens(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = False
|
||||
|
||||
orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
current_query_tokens=1000,
|
||||
)
|
||||
|
||||
mock_threshold_checker.should_compress.assert_called_once_with(
|
||||
sample_conversation, "gpt-4", 1000
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPerformCompression:
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_successful_compression(
|
||||
self,
|
||||
mock_settings,
|
||||
MockCompressionService,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
mock_metadata = MagicMock(spec=CompressionMetadata)
|
||||
mock_metadata.compression_ratio = 5.0
|
||||
mock_metadata.original_token_count = 500
|
||||
mock_metadata.compressed_token_count = 100
|
||||
|
||||
mock_svc_instance = MagicMock()
|
||||
mock_svc_instance.compress_and_save.return_value = mock_metadata
|
||||
mock_svc_instance.get_compressed_context.return_value = (
|
||||
"compressed text",
|
||||
[{"prompt": "q2", "response": "r2"}],
|
||||
)
|
||||
MockCompressionService.return_value = mock_svc_instance
|
||||
|
||||
# After compression, reload conversation
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
|
||||
result = orch._perform_compression(
|
||||
"conv1", sample_conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is True
|
||||
assert result.compressed_summary == "compressed text"
|
||||
mock_svc_instance.compress_and_save.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_uses_compression_model_override(
|
||||
self,
|
||||
mock_settings,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = "gpt-3.5-turbo"
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
conversation = {"queries": [{"prompt": "q", "response": "r"}], "agent_id": "a"}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.orchestrator.CompressionService"
|
||||
) as MockCS:
|
||||
mock_svc = MagicMock()
|
||||
mock_svc.compress_and_save.return_value = MagicMock(
|
||||
compression_ratio=3.0,
|
||||
original_token_count=300,
|
||||
compressed_token_count=100,
|
||||
)
|
||||
mock_svc.get_compressed_context.return_value = ("s", [])
|
||||
MockCS.return_value = mock_svc
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
|
||||
|
||||
# Verify the override model was used
|
||||
mock_get_provider.assert_called_with("gpt-3.5-turbo")
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_no_queries_returns_no_compression(
|
||||
self,
|
||||
mock_settings,
|
||||
MockCompressionService,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
conversation = {"queries": [], "agent_id": "a"}
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
result = orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is False
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q", "response": "r"}],
|
||||
"agent_id": "a",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.orchestrator.settings"
|
||||
) as mock_settings, patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id",
|
||||
side_effect=RuntimeError("provider error"),
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
result = orch._perform_compression(
|
||||
"c1", conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "provider error" in result.error
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressMidExecution:
|
||||
def test_with_provided_conversation(
|
||||
self,
|
||||
orchestrator,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_no_compression([])
|
||||
|
||||
orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
current_conversation=sample_conversation,
|
||||
)
|
||||
|
||||
mock_perform.assert_called_once_with(
|
||||
"conv1", sample_conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
def test_loads_conversation_when_not_provided(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_no_compression([])
|
||||
|
||||
orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
mock_conversation_service.get_conversation.assert_called_once_with(
|
||||
"conv1", "user1"
|
||||
)
|
||||
mock_perform.assert_called_once()
|
||||
|
||||
def test_conversation_not_found_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = None
|
||||
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.side_effect = RuntimeError("fail")
|
||||
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "fail" in result.error
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOrchestratorInit:
|
||||
def test_default_threshold_checker(self, mock_conversation_service):
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service
|
||||
)
|
||||
assert orch.threshold_checker is not None
|
||||
assert orch.conversation_service is mock_conversation_service
|
||||
|
||||
def test_custom_threshold_checker(
|
||||
self, mock_conversation_service, mock_threshold_checker
|
||||
):
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
assert orch.threshold_checker is mock_threshold_checker
|
||||
423
tests/api/answer/services/compression/test_service.py
Normal file
423
tests/api/answer/services/compression/test_service.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Tests for application/api/answer/services/compression/service.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import CompressionMetadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
llm = MagicMock()
|
||||
llm.gen.return_value = "<summary>Compressed summary content</summary>"
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_service():
|
||||
svc = MagicMock()
|
||||
svc.update_compression_metadata = MagicMock()
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"queries": [
|
||||
{"prompt": "What is Python?", "response": "A programming language."},
|
||||
{"prompt": "Tell me more.", "response": "It's versatile and popular."},
|
||||
{
|
||||
"prompt": "What about tools?",
|
||||
"response": "Python has many tools.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web_search",
|
||||
"arguments": {"q": "python tools"},
|
||||
"result": "Found 10 results",
|
||||
"status": "success",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionServiceInit:
|
||||
@patch("application.api.answer.services.compression.service.settings")
|
||||
def test_default_prompt_builder(self, mock_settings, mock_llm):
|
||||
mock_settings.COMPRESSION_PROMPT_VERSION = "v1.0"
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.CompressionPromptBuilder"
|
||||
):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
assert svc.llm is mock_llm
|
||||
assert svc.model_id == "gpt-4"
|
||||
|
||||
def test_custom_prompt_builder(self, mock_llm):
|
||||
custom_builder = MagicMock()
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=custom_builder
|
||||
)
|
||||
assert svc.prompt_builder is custom_builder
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressConversation:
|
||||
def test_successful_compression(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "Compress"},
|
||||
{"role": "user", "content": "Conversation..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 1000
|
||||
MockTC.count_message_tokens.return_value = 100
|
||||
|
||||
result = svc.compress_conversation(sample_conversation, 2)
|
||||
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
assert result.query_index == 2
|
||||
assert result.compressed_summary == "Compressed summary content"
|
||||
assert result.original_token_count == 1000
|
||||
assert result.compressed_token_count == 100
|
||||
assert result.compression_ratio == 10.0
|
||||
assert result.model_used == "gpt-4"
|
||||
assert result.compression_prompt_version == "v1.0"
|
||||
|
||||
def test_invalid_index_negative(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
|
||||
svc.compress_conversation(sample_conversation, -1)
|
||||
|
||||
def test_invalid_index_too_large(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
|
||||
svc.compress_conversation(sample_conversation, 10)
|
||||
|
||||
def test_with_existing_compressions(self, mock_llm):
|
||||
conversation = {
|
||||
"queries": [
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {
|
||||
"compression_points": [
|
||||
{
|
||||
"query_index": 0,
|
||||
"compressed_summary": "Previous summary",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "Compress"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 500
|
||||
MockTC.count_message_tokens.return_value = 50
|
||||
|
||||
result = svc.compress_conversation(conversation, 1)
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
# Verify existing compressions were passed to prompt builder
|
||||
call_args = mock_builder.build_prompt.call_args
|
||||
assert call_args[0][1] == [
|
||||
{"query_index": 0, "compressed_summary": "Previous summary"}
|
||||
]
|
||||
|
||||
def test_zero_compressed_tokens_ratio(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 1000
|
||||
MockTC.count_message_tokens.return_value = 0
|
||||
|
||||
result = svc.compress_conversation(sample_conversation, 2)
|
||||
assert result.compression_ratio == 0
|
||||
|
||||
def test_llm_error_propagates(self, sample_conversation):
|
||||
llm = MagicMock()
|
||||
llm.gen.side_effect = RuntimeError("LLM error")
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 100
|
||||
with pytest.raises(RuntimeError, match="LLM error"):
|
||||
svc.compress_conversation(sample_conversation, 2)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressAndSave:
|
||||
def test_saves_metadata_to_db(
|
||||
self, mock_llm, mock_conversation_service, sample_conversation
|
||||
):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm,
|
||||
model_id="gpt-4",
|
||||
conversation_service=mock_conversation_service,
|
||||
prompt_builder=mock_builder,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 500
|
||||
MockTC.count_message_tokens.return_value = 50
|
||||
|
||||
result = svc.compress_and_save("conv_123", sample_conversation, 2)
|
||||
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
mock_conversation_service.update_compression_metadata.assert_called_once_with(
|
||||
"conv_123", result.to_dict()
|
||||
)
|
||||
|
||||
def test_raises_without_conversation_service(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="conversation_service required"):
|
||||
svc.compress_and_save("conv_123", sample_conversation, 2)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetCompressedContext:
|
||||
def test_no_compression_returns_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q1", "response": "r1"}],
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
|
||||
assert summary is None
|
||||
assert queries == [{"prompt": "q1", "response": "r1"}]
|
||||
|
||||
def test_no_compression_points_returns_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q1", "response": "r1"}],
|
||||
"compression_metadata": {"is_compressed": True, "compression_points": []},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert len(queries) == 1
|
||||
|
||||
def test_with_compression_returns_summary_and_recent(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [
|
||||
{"prompt": "q0", "response": "r0"},
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": [
|
||||
{
|
||||
"query_index": 1,
|
||||
"compressed_summary": "Summary of q0 and q1",
|
||||
"compressed_token_count": 50,
|
||||
"original_token_count": 500,
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
|
||||
assert summary == "Summary of q0 and q1"
|
||||
assert len(queries) == 1
|
||||
assert queries[0]["prompt"] == "q2"
|
||||
|
||||
def test_none_queries_returns_empty(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": None,
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == []
|
||||
|
||||
def test_exception_falls_back_to_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q", "response": "r"}],
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": "invalid", # This will cause an error
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == [{"prompt": "q", "response": "r"}]
|
||||
|
||||
def test_exception_with_none_queries_returns_empty(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
# Force exception by making compression_points non-iterable
|
||||
conversation = {
|
||||
"queries": None,
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": "bad",
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractSummary:
|
||||
def test_extracts_from_summary_tags(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<analysis>Some analysis</analysis><summary>The actual summary</summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "The actual summary"
|
||||
|
||||
def test_removes_analysis_tags_when_no_summary(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<analysis>analysis text</analysis>Raw summary text here"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Raw summary text here"
|
||||
|
||||
def test_returns_full_response_when_no_tags(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "Just a plain text response"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Just a plain text response"
|
||||
|
||||
def test_multiline_summary(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<summary>Line 1\nLine 2\nLine 3</summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert "Line 1" in result
|
||||
assert "Line 3" in result
|
||||
|
||||
def test_strips_whitespace(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<summary> Trimmed </summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Trimmed"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLogToolCallStats:
|
||||
def test_no_tool_calls(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [{"prompt": "q", "response": "r"}]
|
||||
# Should not raise
|
||||
svc._log_tool_call_stats(queries)
|
||||
|
||||
def test_with_tool_calls(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web",
|
||||
"result": "result text",
|
||||
},
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web",
|
||||
"result": "more text",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
# Should not raise - just logs
|
||||
svc._log_tool_call_stats(queries)
|
||||
|
||||
def test_empty_queries(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
svc._log_tool_call_stats([])
|
||||
|
||||
def test_tool_call_with_none_result(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "t",
|
||||
"action_name": "a",
|
||||
"result": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
svc._log_tool_call_stats(queries)
|
||||
@@ -0,0 +1,46 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionThresholdChecker:
|
||||
|
||||
def _make_checker(self, pct=0.7):
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
|
||||
return CompressionThresholdChecker(threshold_percentage=pct)
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.threshold_checker.get_token_limit",
|
||||
return_value=8000,
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
|
||||
return_value=6000,
|
||||
)
|
||||
def test_check_message_tokens_above_threshold(self, mock_count, mock_limit):
|
||||
checker = self._make_checker(0.7)
|
||||
assert checker.check_message_tokens([{"role": "user"}], "gpt-4") is True
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.threshold_checker.get_token_limit",
|
||||
return_value=8000,
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
|
||||
return_value=1000,
|
||||
)
|
||||
def test_check_message_tokens_below_threshold(self, mock_count, mock_limit):
|
||||
checker = self._make_checker(0.7)
|
||||
assert checker.check_message_tokens([{"role": "user"}], "gpt-4") is False
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
|
||||
side_effect=Exception("Token error"),
|
||||
)
|
||||
def test_check_message_tokens_exception_returns_false(self, mock_count):
|
||||
checker = self._make_checker(0.7)
|
||||
assert checker.check_message_tokens([], "gpt-4") is False
|
||||
131
tests/api/answer/services/compression/test_types.py
Normal file
131
tests/api/answer/services/compression/test_types.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for application/api/answer/services/compression/types.py"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
CompressionResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionMetadata:
|
||||
def _make_metadata(self, **overrides):
|
||||
defaults = dict(
|
||||
timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
query_index=5,
|
||||
compressed_summary="Summary of conversation",
|
||||
original_token_count=5000,
|
||||
compressed_token_count=500,
|
||||
compression_ratio=10.0,
|
||||
model_used="gpt-4",
|
||||
compression_prompt_version="v1.0",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return CompressionMetadata(**defaults)
|
||||
|
||||
def test_to_dict_contains_all_fields(self):
|
||||
meta = self._make_metadata()
|
||||
d = meta.to_dict()
|
||||
|
||||
assert d["timestamp"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
assert d["query_index"] == 5
|
||||
assert d["compressed_summary"] == "Summary of conversation"
|
||||
assert d["original_token_count"] == 5000
|
||||
assert d["compressed_token_count"] == 500
|
||||
assert d["compression_ratio"] == 10.0
|
||||
assert d["model_used"] == "gpt-4"
|
||||
assert d["compression_prompt_version"] == "v1.0"
|
||||
|
||||
def test_to_dict_returns_dict_type(self):
|
||||
meta = self._make_metadata()
|
||||
assert isinstance(meta.to_dict(), dict)
|
||||
|
||||
def test_to_dict_field_count(self):
|
||||
meta = self._make_metadata()
|
||||
d = meta.to_dict()
|
||||
assert len(d) == 8
|
||||
|
||||
def test_attributes_accessible(self):
|
||||
meta = self._make_metadata(query_index=10, compression_ratio=5.5)
|
||||
assert meta.query_index == 10
|
||||
assert meta.compression_ratio == 5.5
|
||||
|
||||
def test_zero_compressed_tokens(self):
|
||||
meta = self._make_metadata(compressed_token_count=0, compression_ratio=0)
|
||||
d = meta.to_dict()
|
||||
assert d["compressed_token_count"] == 0
|
||||
assert d["compression_ratio"] == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionResult:
|
||||
def test_success_with_compression(self):
|
||||
meta = CompressionMetadata(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
query_index=3,
|
||||
compressed_summary="summary",
|
||||
original_token_count=1000,
|
||||
compressed_token_count=100,
|
||||
compression_ratio=10.0,
|
||||
model_used="gpt-4",
|
||||
compression_prompt_version="v1.0",
|
||||
)
|
||||
queries = [{"prompt": "q1", "response": "r1"}]
|
||||
result = CompressionResult.success_with_compression("summary", queries, meta)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compressed_summary == "summary"
|
||||
assert result.recent_queries == queries
|
||||
assert result.metadata is meta
|
||||
assert result.compression_performed is True
|
||||
assert result.error is None
|
||||
|
||||
def test_success_no_compression(self):
|
||||
queries = [{"prompt": "q1", "response": "r1"}]
|
||||
result = CompressionResult.success_no_compression(queries)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compressed_summary is None
|
||||
assert result.recent_queries == queries
|
||||
assert result.metadata is None
|
||||
assert result.compression_performed is False
|
||||
assert result.error is None
|
||||
|
||||
def test_failure(self):
|
||||
result = CompressionResult.failure("something went wrong")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "something went wrong"
|
||||
assert result.compression_performed is False
|
||||
assert result.compressed_summary is None
|
||||
assert result.recent_queries == []
|
||||
assert result.metadata is None
|
||||
|
||||
def test_as_history_extracts_prompt_response(self):
|
||||
queries = [
|
||||
{"prompt": "Hello", "response": "Hi", "extra": "ignored"},
|
||||
{"prompt": "How?", "response": "Fine"},
|
||||
]
|
||||
result = CompressionResult.success_no_compression(queries)
|
||||
history = result.as_history()
|
||||
|
||||
assert len(history) == 2
|
||||
assert history[0] == {"prompt": "Hello", "response": "Hi"}
|
||||
assert history[1] == {"prompt": "How?", "response": "Fine"}
|
||||
|
||||
def test_as_history_empty_queries(self):
|
||||
result = CompressionResult.success_no_compression([])
|
||||
assert result.as_history() == []
|
||||
|
||||
def test_default_recent_queries_is_empty_list(self):
|
||||
result = CompressionResult(success=True)
|
||||
assert result.recent_queries == []
|
||||
assert result.as_history() == []
|
||||
|
||||
def test_success_no_compression_with_empty_list(self):
|
||||
result = CompressionResult.success_no_compression([])
|
||||
assert result.success is True
|
||||
assert result.recent_queries == []
|
||||
578
tests/api/answer/test_base_routes.py
Normal file
578
tests/api/answer/test_base_routes.py
Normal file
@@ -0,0 +1,578 @@
|
||||
"""Unit tests for application/api/answer/routes/base.py — BaseAnswerResource.
|
||||
|
||||
Additional coverage beyond tests/api/answer/routes/test_base.py:
|
||||
- _prepare_tool_calls_for_logging: truncation, non-dict items
|
||||
- complete_stream: tool_calls, thoughts, structured output, metadata,
|
||||
isNoneDoc, GeneratorExit handling, compression metadata
|
||||
- process_response_stream: structured answer, incomplete stream
|
||||
- error_stream_generate: format
|
||||
- check_usage: string boolean parsing ("True" strings)
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPrepareToolCallsForLogging:
|
||||
|
||||
def test_empty_list(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
assert resource._prepare_tool_calls_for_logging([]) == []
|
||||
|
||||
def test_none_returns_empty(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
assert resource._prepare_tool_calls_for_logging(None) == []
|
||||
|
||||
def test_truncates_long_result(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
tool_calls = [{"result": "x" * 20000}]
|
||||
prepared = resource._prepare_tool_calls_for_logging(tool_calls, max_chars=100)
|
||||
assert len(prepared[0]["result"]) == 100
|
||||
|
||||
def test_truncates_result_full(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
tool_calls = [{"result_full": "y" * 20000}]
|
||||
prepared = resource._prepare_tool_calls_for_logging(tool_calls, max_chars=50)
|
||||
assert len(prepared[0]["result_full"]) == 50
|
||||
|
||||
def test_non_dict_items_wrapped(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
tool_calls = ["string_item", 42]
|
||||
prepared = resource._prepare_tool_calls_for_logging(tool_calls)
|
||||
assert prepared[0] == {"result": "string_item"}
|
||||
assert prepared[1] == {"result": "42"}
|
||||
|
||||
def test_preserves_short_results(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
tool_calls = [{"tool_name": "search", "result": "short text"}]
|
||||
prepared = resource._prepare_tool_calls_for_logging(tool_calls)
|
||||
assert prepared[0]["result"] == "short text"
|
||||
assert prepared[0]["tool_name"] == "search"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamToolCalls:
|
||||
|
||||
def test_streams_tool_calls(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Using tool..."},
|
||||
{"tool_calls": [{"name": "search", "result": "found"}]},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Search for X",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
tool_chunks = [s for s in stream if '"type": "tool_calls"' in s]
|
||||
assert len(tool_chunks) == 1
|
||||
|
||||
def test_streams_thought_events(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"thought": "Let me think..."},
|
||||
{"answer": "Here is the answer"},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
thought_chunks = [s for s in stream if '"type": "thought"' in s]
|
||||
assert len(thought_chunks) == 1
|
||||
assert "Let me think" in thought_chunks[0]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamStructuredOutput:
|
||||
|
||||
def test_streams_structured_answer(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{
|
||||
"answer": '{"key": "value"}',
|
||||
"structured": True,
|
||||
"schema": {"type": "object"},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
structured_chunks = [
|
||||
s for s in stream if '"type": "structured_answer"' in s
|
||||
]
|
||||
assert len(structured_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamMetadata:
|
||||
|
||||
def test_metadata_collected(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"metadata": {"search_query": "test"}},
|
||||
{"answer": "result"},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
# Should not crash, metadata handled silently
|
||||
answer_chunks = [s for s in stream if '"type": "answer"' in s]
|
||||
assert len(answer_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamIsNoneDoc:
|
||||
|
||||
def test_isNoneDoc_sets_source_to_none(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "answer"},
|
||||
{"sources": [{"text": "doc", "source": "real_source"}]},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
isNoneDoc=True,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
# Verify stream completes without error
|
||||
end_chunks = [s for s in stream if '"type": "end"' in s]
|
||||
assert len(end_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamErrorType:
|
||||
|
||||
def test_error_type_event_sanitized(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"type": "error", "error": "API key invalid: sk-xxx"},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
error_chunks = [s for s in stream if '"type": "error"' in s]
|
||||
assert len(error_chunks) == 1
|
||||
|
||||
def test_non_error_type_event_passed_through(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"type": "custom_event", "data": "value"},
|
||||
]
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
custom_chunks = [s for s in stream if '"type": "custom_event"' in s]
|
||||
assert len(custom_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessResponseStreamExtended:
|
||||
|
||||
def test_handles_structured_answer(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "structured_answer", "answer": "{}", "structured": True, "schema": None})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": str(ObjectId())})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[1] == "{}"
|
||||
# Structured output adds extra tuple element
|
||||
assert len(result) == 7
|
||||
assert result[6]["structured"] is True
|
||||
|
||||
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "answer", "answer": "result"})}\n\n',
|
||||
f'data: {json.dumps({"type": "tool_calls", "tool_calls": [{"name": "t1"}]})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": "conv1"})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[3] == [{"name": "t1"}]
|
||||
|
||||
def test_incomplete_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "Stream ended unexpectedly"
|
||||
|
||||
def test_handles_thought_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "thought", "thought": "thinking..."})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "thinking..."
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckUsageStringBooleans:
|
||||
|
||||
def test_string_true_parsed_correctly(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"key": "str_bool_key",
|
||||
"limited_token_mode": "True",
|
||||
"token_limit": 1000000,
|
||||
"limited_request_mode": "True",
|
||||
"request_limit": 1000000,
|
||||
}
|
||||
)
|
||||
resource = BaseAnswerResource()
|
||||
result = resource.check_usage({"user_api_key": "str_bool_key"})
|
||||
# Should not exceed limits, so returns None
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamCompressionMetadata:
|
||||
"""Cover lines 307-319 (compression metadata persistence in complete_stream)."""
|
||||
|
||||
def test_compression_metadata_persisted(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "compressed answer"},
|
||||
]
|
||||
)
|
||||
mock_agent.compression_metadata = {"ratio": 2.5}
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv123"
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify compression metadata was persisted
|
||||
resource.conversation_service.update_compression_metadata.assert_called_once_with(
|
||||
"conv123", {"ratio": 2.5}
|
||||
)
|
||||
resource.conversation_service.append_compression_message.assert_called_once()
|
||||
assert mock_agent.compression_saved is True
|
||||
end_chunks = [s for s in stream if '"type": "end"' in s]
|
||||
assert len(end_chunks) == 1
|
||||
|
||||
def test_compression_metadata_error_handled(self, mock_mongo_db, flask_app):
|
||||
"""Cover lines 318-322: compression metadata persistence error."""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter([{"answer": "answer"}])
|
||||
mock_agent.compression_metadata = {"ratio": 2.5}
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv123"
|
||||
resource.conversation_service.update_compression_metadata.side_effect = (
|
||||
Exception("db error")
|
||||
)
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# Stream should still complete despite compression error
|
||||
end_chunks = [s for s in stream if '"type": "end"' in s]
|
||||
assert len(end_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamLogTruncation:
|
||||
"""Cover line 354: log data truncation for long values."""
|
||||
|
||||
def test_long_response_truncated_in_log(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
long_answer = "x" * 20000
|
||||
mock_agent.gen.return_value = iter([{"answer": long_answer}])
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
end_chunks = [s for s in stream if '"type": "end"' in s]
|
||||
assert len(end_chunks) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamGeneratorExit:
|
||||
"""Cover lines 360-416 (GeneratorExit handling in complete_stream)."""
|
||||
|
||||
def test_generator_exit_saves_partial_response(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_with_answers():
|
||||
yield {"answer": "partial"}
|
||||
yield {"answer": " answer"}
|
||||
# Simulating a long stream that gets interrupted
|
||||
yield {"answer": " more"}
|
||||
|
||||
mock_agent.gen.return_value = gen_with_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
|
||||
# Read first chunk and then close (simulating client disconnect)
|
||||
chunk = next(gen)
|
||||
assert "partial" in chunk
|
||||
gen.close() # This triggers GeneratorExit
|
||||
|
||||
def test_generator_exit_with_compression_metadata(self, mock_mongo_db, flask_app):
|
||||
"""Cover lines 393-411: GeneratorExit with compression metadata."""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial answer"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = {"ratio": 3.0}
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
isNoneDoc=True,
|
||||
)
|
||||
|
||||
next(gen)
|
||||
gen.close()
|
||||
|
||||
def test_generator_exit_save_error_handled(self, mock_mongo_db, flask_app):
|
||||
"""Cover lines 412-415: exception during partial save."""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.side_effect = Exception(
|
||||
"save error"
|
||||
)
|
||||
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
|
||||
next(gen)
|
||||
gen.close() # Should not crash even with save error
|
||||
478
tests/api/answer/test_conversation_service.py
Normal file
478
tests/api/answer/test_conversation_service.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""Unit tests for application/api/answer/services/conversation_service.py.
|
||||
|
||||
Additional coverage beyond tests/api/answer/services/test_conversation_service.py:
|
||||
- save_conversation: index-based update, metadata persistence, agent key tracking
|
||||
- update_compression_metadata
|
||||
- append_compression_message
|
||||
- get_compression_metadata
|
||||
- Edge cases: None token, empty summary, shared_with access
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceGetExtended:
|
||||
|
||||
def test_returns_conversation_for_shared_user(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "owner_123",
|
||||
"shared_with": ["shared_user"],
|
||||
"name": "Shared Conv",
|
||||
"queries": [],
|
||||
}
|
||||
)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "shared_user")
|
||||
assert result is not None
|
||||
assert result["name"] == "Shared Conv"
|
||||
|
||||
def test_handles_exception_gracefully(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
# Pass an invalid ObjectId
|
||||
result = service.get_conversation("not-an-objectid", "user_123")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSaveConversationExtended:
|
||||
|
||||
def test_raises_for_none_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
with pytest.raises(ValueError, match="Invalid or missing authentication"):
|
||||
service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Q",
|
||||
response="A",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=Mock(),
|
||||
model_id="m",
|
||||
decoded_token=None,
|
||||
)
|
||||
|
||||
def test_update_existing_at_index(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Conv",
|
||||
"queries": [
|
||||
{
|
||||
"prompt": "Q1",
|
||||
"response": "A1",
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
},
|
||||
{
|
||||
"prompt": "Q2",
|
||||
"response": "A2",
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result = service.save_conversation(
|
||||
conversation_id=str(conv_id),
|
||||
question="Q1_updated",
|
||||
response="A1_updated",
|
||||
thought="thinking",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=Mock(),
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
index=0,
|
||||
)
|
||||
assert result == str(conv_id)
|
||||
|
||||
saved = collection.find_one({"_id": conv_id})
|
||||
assert saved["queries"][0]["prompt"] == "Q1_updated"
|
||||
assert saved["queries"][0]["response"] == "A1_updated"
|
||||
|
||||
def test_update_at_index_unauthorized(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "owner",
|
||||
"queries": [{"prompt": "Q", "response": "A"}],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found or unauthorized"):
|
||||
service.save_conversation(
|
||||
conversation_id=str(conv_id),
|
||||
question="Hack",
|
||||
response="Attempt",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=Mock(),
|
||||
model_id="m",
|
||||
decoded_token={"sub": "hacker"},
|
||||
index=0,
|
||||
)
|
||||
|
||||
def test_saves_metadata(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Title"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Q",
|
||||
response="A",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
model_id="m",
|
||||
decoded_token={"sub": "user_123"},
|
||||
metadata={"search_query": "rewritten query"},
|
||||
)
|
||||
|
||||
saved = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved["queries"][0]["metadata"] == {"search_query": "rewritten query"}
|
||||
|
||||
def test_no_metadata_when_none(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Title"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Q",
|
||||
response="A",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
model_id="m",
|
||||
decoded_token={"sub": "user_123"},
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
saved = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert "metadata" not in saved["queries"][0]
|
||||
|
||||
def test_saves_with_api_key_and_agent(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{"_id": agent_id, "key": "agent_key_123", "user": "user_123"}
|
||||
)
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Title"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Q",
|
||||
response="A",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
model_id="m",
|
||||
decoded_token={"sub": "user_123"},
|
||||
api_key="agent_key_123",
|
||||
agent_id=str(agent_id),
|
||||
)
|
||||
|
||||
saved = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved["api_key"] == "agent_key_123"
|
||||
assert saved["agent_id"] == str(agent_id)
|
||||
|
||||
def test_empty_completion_uses_question_prefix(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = " " # whitespace only
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="What is the meaning of life in programming?",
|
||||
response="42",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
model_id="m",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
saved = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved["name"] == "What is the meaning of life in programming?"[:50]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUpdateCompressionMetadata:
|
||||
|
||||
def test_updates_compression_fields(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "u", "queries": []}
|
||||
)
|
||||
|
||||
meta = {
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"compressed_summary": "Summary of conversation",
|
||||
"model_used": "gpt-4",
|
||||
}
|
||||
|
||||
service.update_compression_metadata(str(conv_id), meta)
|
||||
|
||||
saved = collection.find_one({"_id": conv_id})
|
||||
assert saved["compression_metadata"]["is_compressed"] is True
|
||||
assert len(saved["compression_metadata"]["compression_points"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAppendCompressionMessage:
|
||||
|
||||
def test_appends_summary_query(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "u", "queries": []}
|
||||
)
|
||||
|
||||
meta = {
|
||||
"compressed_summary": "This is the summary",
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"model_used": "gpt-4",
|
||||
}
|
||||
|
||||
service.append_compression_message(str(conv_id), meta)
|
||||
|
||||
saved = collection.find_one({"_id": conv_id})
|
||||
assert len(saved["queries"]) == 1
|
||||
assert saved["queries"][0]["prompt"] == "[Context Compression Summary]"
|
||||
assert saved["queries"][0]["response"] == "This is the summary"
|
||||
|
||||
def test_empty_summary_does_nothing(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "u", "queries": []}
|
||||
)
|
||||
|
||||
service.append_compression_message(str(conv_id), {"compressed_summary": ""})
|
||||
|
||||
saved = collection.find_one({"_id": conv_id})
|
||||
assert len(saved["queries"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetCompressionMetadata:
|
||||
|
||||
def test_returns_metadata(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "u",
|
||||
"compression_metadata": {"is_compressed": True},
|
||||
}
|
||||
)
|
||||
|
||||
result = service.get_compression_metadata(str(conv_id))
|
||||
assert result is not None
|
||||
assert result["is_compressed"] is True
|
||||
|
||||
def test_returns_none_for_no_metadata(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "u"})
|
||||
|
||||
result = service.get_compression_metadata(str(conv_id))
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_missing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_compression_metadata(str(ObjectId()))
|
||||
assert result is None
|
||||
|
||||
def test_handles_invalid_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_compression_metadata("invalid-id")
|
||||
assert result is None
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 233-237, 258, 261)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceGaps:
|
||||
|
||||
def test_update_compression_metadata_exception_raises(self, mock_mongo_db):
|
||||
"""Cover lines 233-237: exception during update raises."""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
service.conversations_collection = MagicMock()
|
||||
service.conversations_collection.update_one.side_effect = Exception("db error")
|
||||
|
||||
with pytest.raises(Exception, match="db error"):
|
||||
service.update_compression_metadata(
|
||||
str(ObjectId()),
|
||||
{
|
||||
"compressed_summary": "summary",
|
||||
"query_index": 5,
|
||||
"compressed_token_count": 100,
|
||||
"original_token_count": 1000,
|
||||
},
|
||||
)
|
||||
|
||||
def test_append_compression_message_with_summary(self, mock_mongo_db):
|
||||
"""Cover lines 258, 261: appends compression message to conversation."""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
service.conversations_collection = MagicMock()
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
metadata = {
|
||||
"compressed_summary": "This is a summary of the conversation.",
|
||||
"timestamp": "2024-01-01T00:00:00",
|
||||
"model_used": "gpt-4",
|
||||
}
|
||||
service.append_compression_message(conv_id, metadata)
|
||||
service.conversations_collection.update_one.assert_called_once()
|
||||
|
||||
def test_append_compression_message_empty_summary_skips(self, mock_mongo_db):
|
||||
"""Cover: empty summary does not insert."""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
service.conversations_collection = MagicMock()
|
||||
|
||||
service.append_compression_message(str(ObjectId()), {"compressed_summary": ""})
|
||||
service.conversations_collection.update_one.assert_not_called()
|
||||
3213
tests/api/answer/test_stream_processor.py
Normal file
3213
tests/api/answer/test_stream_processor.py
Normal file
File diff suppressed because it is too large
Load Diff
821
tests/api/test_connector_routes.py
Normal file
821
tests/api/test_connector_routes.py
Normal file
@@ -0,0 +1,821 @@
|
||||
"""Tests for application/api/connector/routes.py"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
|
||||
from application.app import app as flask_app
|
||||
flask_app.config["TESTING"] = True
|
||||
yield flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions(monkeypatch):
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_db = mock_client["docsgpt"]
|
||||
sessions = mock_db["connector_sessions"]
|
||||
sources = mock_db["sources"]
|
||||
monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions)
|
||||
monkeypatch.setattr("application.api.connector.routes.sources_collection", sources)
|
||||
return {"sessions": sessions, "sources": sources}
|
||||
|
||||
|
||||
class TestConnectorAuth:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_provider(self, client):
|
||||
resp = client.get("/api/connectors/auth")
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unsupported_provider(self, client):
|
||||
resp = client.get("/api/connectors/auth?provider=dropbox")
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unauthorized(self, client, app):
|
||||
with patch("application.app.handle_auth", return_value=None):
|
||||
resp = client.get("/api/connectors/auth?provider=google_drive")
|
||||
data = json.loads(resp.data)
|
||||
# decoded_token is None -> 401
|
||||
assert resp.status_code == 401 or data.get("error") == "Unauthorized"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success(self, client, mock_sessions):
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
|
||||
resp = client.get("/api/connectors/auth?provider=google_drive")
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
assert "authorization_url" in data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_returns_500(self, client, mock_sessions):
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
MockCC.create_auth.side_effect = Exception("oauth fail")
|
||||
resp = client.get("/api/connectors/auth?provider=google_drive")
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
class TestConnectorFiles:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_params(self, client):
|
||||
resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_session(self, client, mock_sessions):
|
||||
resp = client.post("/api/connectors/files", json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "bad_token",
|
||||
})
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "valid_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
})
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.doc_id = "f1"
|
||||
mock_doc.extra_info = {
|
||||
"file_name": "test.pdf",
|
||||
"mime_type": "application/pdf",
|
||||
"size": 1024,
|
||||
"modified_time": "2025-01-01T12:00:00.000Z",
|
||||
"is_folder": False,
|
||||
}
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load_data.return_value = [mock_doc]
|
||||
mock_loader.next_page_token = None
|
||||
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_connector.return_value = mock_loader
|
||||
resp = client.post("/api/connectors/files", json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "valid_tok",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
assert len(data["files"]) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_modified_time(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "tok2",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
})
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.doc_id = "f1"
|
||||
mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"}
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load_data.return_value = [mock_doc]
|
||||
mock_loader.next_page_token = None
|
||||
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_connector.return_value = mock_loader
|
||||
resp = client.post("/api/connectors/files", json={
|
||||
"provider": "google_drive", "session_token": "tok2",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
class TestConnectorValidateSession:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_params(self, client):
|
||||
resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_session(self, client, mock_sessions):
|
||||
resp = client.post("/api/connectors/validate-session", json={
|
||||
"provider": "google_drive", "session_token": "bad",
|
||||
})
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_non_expired(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "valid",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
"token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
|
||||
"user_email": "user@example.com",
|
||||
})
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.is_token_expired.return_value = False
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
resp = client.post("/api/connectors/validate-session", json={
|
||||
"provider": "google_drive", "session_token": "valid",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
assert data["expired"] is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_with_refresh(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "expired_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
"token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
|
||||
})
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.is_token_expired.return_value = True
|
||||
mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
|
||||
mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
resp = client.post("/api/connectors/validate-session", json={
|
||||
"provider": "google_drive", "session_token": "expired_tok",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_no_refresh(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "exp_no_ref",
|
||||
"user": "test_user",
|
||||
"token_info": {"access_token": "at", "expiry": 100},
|
||||
})
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.is_token_expired.return_value = True
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
resp = client.post("/api/connectors/validate-session", json={
|
||||
"provider": "google_drive", "session_token": "exp_no_ref",
|
||||
})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestConnectorDisconnect:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_provider(self, client):
|
||||
resp = client.post("/api/connectors/disconnect", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success_with_session(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"})
|
||||
resp = client.post("/api/connectors/disconnect", json={
|
||||
"provider": "google_drive", "session_token": "del_me",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success_without_session(self, client, mock_sessions):
|
||||
resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
class TestConnectorSync:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_params(self, client, mock_sessions):
|
||||
resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_source_not_found(self, client, mock_sessions):
|
||||
from bson.objectid import ObjectId
|
||||
resp = client.post("/api/connectors/sync", json={
|
||||
"source_id": str(ObjectId()), "session_token": "tok",
|
||||
})
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unauthorized_source(self, client, mock_sessions):
|
||||
sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
|
||||
resp = client.post("/api/connectors/sync", json={
|
||||
"source_id": str(sid), "session_token": "tok",
|
||||
})
|
||||
assert resp.status_code == 403
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_provider_in_remote_data(self, client, mock_sessions):
|
||||
sid = mock_sessions["sources"].insert_one({
|
||||
"user": "test_user", "name": "src", "remote_data": json.dumps({}),
|
||||
}).inserted_id
|
||||
resp = client.post("/api/connectors/sync", json={
|
||||
"source_id": str(sid), "session_token": "tok",
|
||||
})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success(self, client, mock_sessions):
|
||||
sid = mock_sessions["sources"].insert_one({
|
||||
"user": "test_user",
|
||||
"name": "src",
|
||||
"remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
|
||||
}).inserted_id
|
||||
mock_task = MagicMock()
|
||||
mock_task.id = "task_123"
|
||||
with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
|
||||
mock_ingest.delay.return_value = mock_task
|
||||
resp = client.post("/api/connectors/sync", json={
|
||||
"source_id": str(sid), "session_token": "tok",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["task_id"] == "task_123"
|
||||
|
||||
|
||||
class TestConnectorCallbackStatus:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_success_status(self, client):
|
||||
resp = client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com")
|
||||
assert resp.status_code == 200
|
||||
assert b"success" in resp.data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_error_status(self, client):
|
||||
resp = client.get("/api/connectors/callback-status?status=error&message=Failed")
|
||||
assert resp.status_code == 200
|
||||
assert b"error" in resp.data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cancelled_status(self, client):
|
||||
resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
|
||||
assert resp.status_code == 200
|
||||
assert b"cancelled" in resp.data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_status_defaults_to_error(self, client):
|
||||
resp = client.get("/api/connectors/callback-status?status=badvalue")
|
||||
assert resp.status_code == 200
|
||||
assert b"error" in resp.data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_escaping(self, client):
|
||||
resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
|
||||
assert resp.status_code == 200
|
||||
# The raw <script> tag should be escaped (not executable)
|
||||
assert b"<script>alert(1)</script>" not in resp.data
|
||||
|
||||
|
||||
class TestBuildCallbackRedirect:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_builds_url(self):
|
||||
from application.api.connector.routes import build_callback_redirect
|
||||
url = build_callback_redirect({"status": "success", "message": "OK"})
|
||||
assert url.startswith("/api/connectors/callback-status?")
|
||||
assert "status=success" in url
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorsCallback:
|
||||
"""Tests for the ConnectorsCallback OAuth callback route."""
|
||||
|
||||
def _encode_state(self, state_dict):
|
||||
return base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
def _patch_connector_creator(self):
|
||||
"""Patch ConnectorCreator at both module-level and local-import locations."""
|
||||
return patch(
|
||||
"application.parser.connectors.connector_creator.ConnectorCreator",
|
||||
)
|
||||
|
||||
def test_callback_invalid_provider_redirects_error(self, client, mock_sessions):
|
||||
state = self._encode_state({"provider": "dropbox", "object_id": "abc123"})
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = False
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?code=auth_code&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "error" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_access_denied_redirects_cancelled(self, client, mock_sessions):
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": "abc123"}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?error=access_denied&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "cancelled" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_other_error_redirects_error(self, client, mock_sessions):
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": "abc123"}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?error=server_error&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "error" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_missing_code_redirects_error(self, client, mock_sessions):
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": "abc123"}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
resp = client.get(f"/api/connectors/callback?state={state}")
|
||||
assert resp.status_code == 302
|
||||
assert "error" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_success_google_drive(self, client, mock_sessions):
|
||||
oid = mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"provider": "google_drive",
|
||||
"user": "test_user",
|
||||
"status": "pending",
|
||||
}
|
||||
).inserted_id
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": str(oid)}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.exchange_code_for_tokens.return_value = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
}
|
||||
mock_creds = MagicMock()
|
||||
mock_auth.create_credentials_from_token_info.return_value = mock_creds
|
||||
mock_service = MagicMock()
|
||||
mock_service.about.return_value.get.return_value.execute.return_value = {
|
||||
"user": {"emailAddress": "user@example.com"}
|
||||
}
|
||||
mock_auth.build_drive_service.return_value = mock_service
|
||||
mock_auth.sanitize_token_info.return_value = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
}
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?code=auth_code&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "success" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_success_non_google_provider(self, client, mock_sessions):
|
||||
oid = mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"provider": "other_provider",
|
||||
"user": "test_user",
|
||||
"status": "pending",
|
||||
}
|
||||
).inserted_id
|
||||
state = self._encode_state(
|
||||
{"provider": "other_provider", "object_id": str(oid)}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.exchange_code_for_tokens.return_value = {
|
||||
"access_token": "at",
|
||||
"user_info": {"email": "other@example.com"},
|
||||
}
|
||||
mock_auth.sanitize_token_info.return_value = {"access_token": "at"}
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?code=auth_code&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "success" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_exchange_tokens_fails(self, client, mock_sessions):
|
||||
oid = mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"provider": "google_drive",
|
||||
"user": "test_user",
|
||||
"status": "pending",
|
||||
}
|
||||
).inserted_id
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": str(oid)}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.exchange_code_for_tokens.side_effect = Exception("token error")
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?code=auth_code&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "error" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_bad_state_returns_error(self, client, mock_sessions):
|
||||
resp = client.get("/api/connectors/callback?code=auth_code&state=badbase64!!!")
|
||||
assert resp.status_code == 302
|
||||
assert "error" in resp.headers.get("Location", "")
|
||||
|
||||
def test_callback_user_info_fails_gracefully(self, client, mock_sessions):
|
||||
oid = mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"provider": "google_drive",
|
||||
"user": "test_user",
|
||||
"status": "pending",
|
||||
}
|
||||
).inserted_id
|
||||
state = self._encode_state(
|
||||
{"provider": "google_drive", "object_id": str(oid)}
|
||||
)
|
||||
with self._patch_connector_creator() as MockCC:
|
||||
MockCC.is_supported.return_value = True
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.exchange_code_for_tokens.return_value = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
}
|
||||
mock_auth.create_credentials_from_token_info.side_effect = Exception(
|
||||
"cred error"
|
||||
)
|
||||
mock_auth.sanitize_token_info.return_value = {
|
||||
"access_token": "at",
|
||||
}
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
|
||||
resp = client.get(
|
||||
f"/api/connectors/callback?code=auth_code&state={state}"
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
assert "success" in resp.headers.get("Location", "")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorFilesAdditional:
|
||||
"""Additional tests for ConnectorFiles."""
|
||||
|
||||
def test_unauthorized_user(self, client, mock_sessions):
|
||||
with patch("application.app.handle_auth", return_value=None):
|
||||
resp = client.post(
|
||||
"/api/connectors/files",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_files_with_pagination(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"session_token": "pag_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
}
|
||||
)
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.doc_id = "f1"
|
||||
mock_doc.extra_info = {
|
||||
"file_name": "test.pdf",
|
||||
"mime_type": "application/pdf",
|
||||
"size": 1024,
|
||||
"modified_time": "2025-01-01T12:00:00.000Z",
|
||||
"is_folder": False,
|
||||
}
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load_data.return_value = [mock_doc]
|
||||
mock_loader.next_page_token = "next_token_123"
|
||||
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_connector.return_value = mock_loader
|
||||
resp = client.post(
|
||||
"/api/connectors/files",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "pag_tok",
|
||||
"page_token": "prev_token",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["has_more"] is True
|
||||
assert data["next_page_token"] == "next_token_123"
|
||||
|
||||
def test_files_exception_returns_500(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"session_token": "err_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
}
|
||||
)
|
||||
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_connector.side_effect = Exception("connector error")
|
||||
resp = client.post(
|
||||
"/api/connectors/files",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "err_tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorFilesSearchQuery:
|
||||
"""Test ConnectorFiles with search_query parameter."""
|
||||
|
||||
def test_files_with_search_query(self, client, mock_sessions):
|
||||
mock_sessions["sessions"].insert_one(
|
||||
{
|
||||
"session_token": "search_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
}
|
||||
)
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.doc_id = "f1"
|
||||
mock_doc.extra_info = {
|
||||
"file_name": "result.pdf",
|
||||
"mime_type": "application/pdf",
|
||||
"size": 512,
|
||||
"is_folder": False,
|
||||
}
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load_data.return_value = [mock_doc]
|
||||
mock_loader.next_page_token = None
|
||||
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_connector.return_value = mock_loader
|
||||
resp = client.post(
|
||||
"/api/connectors/files",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "search_tok",
|
||||
"search_query": "test search",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
# Verify search_query was passed in input_config
|
||||
call_args = mock_loader.load_data.call_args[0][0]
|
||||
assert call_args.get("search_query") == "test search"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorValidateSessionAdditional:
|
||||
"""Cover uncovered branches in ConnectorValidateSession."""
|
||||
|
||||
def test_unauthorized_returns_401(self, client, mock_sessions):
|
||||
"""Line 288: decoded_token is None -> 401."""
|
||||
with patch("application.app.handle_auth", return_value=None):
|
||||
resp = client.post(
|
||||
"/api/connectors/validate-session",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_token_failure_still_expired(self, client, mock_sessions):
|
||||
"""Lines 299-310: refresh attempt fails, token stays expired."""
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "rf_fail_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
"token_info": {
|
||||
"access_token": "old_at",
|
||||
"refresh_token": "rt",
|
||||
"expiry": 100,
|
||||
},
|
||||
})
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.is_token_expired.return_value = True
|
||||
mock_auth.refresh_access_token.side_effect = Exception("refresh failed")
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
resp = client.post(
|
||||
"/api/connectors/validate-session",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "rf_fail_tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
data = json.loads(resp.data)
|
||||
assert data["expired"] is True
|
||||
|
||||
def test_provider_extras_in_response(self, client, mock_sessions):
|
||||
"""Lines 319-327: provider_extras are included in response."""
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "extras_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
"token_info": {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"token_uri": "uri",
|
||||
"expiry": None,
|
||||
"custom_field": "custom_value",
|
||||
},
|
||||
"user_email": "user@test.com",
|
||||
})
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.is_token_expired.return_value = False
|
||||
MockCC.create_auth.return_value = mock_auth
|
||||
resp = client.post(
|
||||
"/api/connectors/validate-session",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "extras_tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = json.loads(resp.data)
|
||||
assert data["success"] is True
|
||||
assert data["custom_field"] == "custom_value"
|
||||
assert data["user_email"] == "user@test.com"
|
||||
|
||||
def test_exception_returns_500(self, client, mock_sessions):
|
||||
"""Lines 331-333: general exception -> 500."""
|
||||
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
|
||||
MockCC.create_auth.side_effect = Exception("total failure")
|
||||
mock_sessions["sessions"].insert_one({
|
||||
"session_token": "err_tok",
|
||||
"user": "test_user",
|
||||
"provider": "google_drive",
|
||||
"token_info": {"access_token": "at"},
|
||||
})
|
||||
resp = client.post(
|
||||
"/api/connectors/validate-session",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "err_tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorDisconnectAdditional:
|
||||
"""Cover uncovered branches in ConnectorDisconnect."""
|
||||
|
||||
def test_exception_returns_500(self, client, mock_sessions):
|
||||
"""Lines 353-355: exception in disconnect -> 500."""
|
||||
with patch(
|
||||
"application.api.connector.routes.sessions_collection"
|
||||
) as mock_col:
|
||||
mock_col.delete_one.side_effect = Exception("db down")
|
||||
resp = client.post(
|
||||
"/api/connectors/disconnect",
|
||||
json={
|
||||
"provider": "google_drive",
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
|
||||
def test_unauthorized_still_works(self, client, mock_sessions):
|
||||
"""ConnectorDisconnect doesn't check decoded_token, just data parsing.
|
||||
No auth check branch to cover, but confirm basic flow."""
|
||||
resp = client.post(
|
||||
"/api/connectors/disconnect",
|
||||
json={"provider": "google_drive"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConnectorSyncAdditional:
|
||||
"""Cover uncovered branches in ConnectorSync."""
|
||||
|
||||
def test_unauthorized_returns_401(self, client, mock_sessions):
|
||||
"""Line 373: decoded_token is None -> 401."""
|
||||
from bson.objectid import ObjectId as ObjId
|
||||
|
||||
with patch("application.app.handle_auth", return_value=None):
|
||||
resp = client.post(
|
||||
"/api/connectors/sync",
|
||||
json={
|
||||
"source_id": str(ObjId()),
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_exception_returns_400(self, client, mock_sessions):
|
||||
"""Lines 453-464: general exception returns 400."""
|
||||
sid = mock_sessions["sources"].insert_one({
|
||||
"user": "test_user",
|
||||
"name": "src",
|
||||
"remote_data": json.dumps({
|
||||
"provider": "google_drive",
|
||||
"file_ids": ["f1"],
|
||||
}),
|
||||
}).inserted_id
|
||||
with patch(
|
||||
"application.api.connector.routes.ingest_connector_task"
|
||||
) as mock_ingest:
|
||||
mock_ingest.delay.side_effect = Exception("task error")
|
||||
resp = client.post(
|
||||
"/api/connectors/sync",
|
||||
json={
|
||||
"source_id": str(sid),
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_invalid_remote_data_json(self, client, mock_sessions):
|
||||
"""Line 411-413: invalid remote_data JSON."""
|
||||
sid = mock_sessions["sources"].insert_one({
|
||||
"user": "test_user",
|
||||
"name": "src",
|
||||
"remote_data": "not-valid-json{",
|
||||
}).inserted_id
|
||||
resp = client.post(
|
||||
"/api/connectors/sync",
|
||||
json={
|
||||
"source_id": str(sid),
|
||||
"session_token": "tok",
|
||||
},
|
||||
)
|
||||
# remote_data parsing fails, remote_data = {}, no provider -> 400
|
||||
assert resp.status_code == 400
|
||||
579
tests/api/test_internal_routes.py
Normal file
579
tests/api/test_internal_routes.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Unit tests for application/api/internal/routes.py.
|
||||
|
||||
Covers:
|
||||
- verify_internal_key: key validation
|
||||
- /api/download: file download
|
||||
- /api/upload_index: index file upload (existing & new entries)
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def internal_app(monkeypatch, mock_mongo_db):
|
||||
"""Create a Flask app with the internal blueprint registered."""
|
||||
from flask import Flask
|
||||
|
||||
# Patch module-level MongoDB references before importing routes
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.conversations_collection",
|
||||
db["conversations"],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.sources_collection",
|
||||
db["sources"],
|
||||
)
|
||||
|
||||
from application.api.internal.routes import internal
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(internal)
|
||||
app.config["TESTING"] = True
|
||||
return app, db
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# verify_internal_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestVerifyInternalKey:
|
||||
|
||||
def test_no_internal_key_configured_allows_access(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
# download will fail for missing file but should not be 401
|
||||
resp = client.get("/api/download?user=u&name=n&file=f")
|
||||
assert resp.status_code != 401
|
||||
|
||||
def test_missing_key_returns_401(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY="secret123",
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.get("/api/download?user=u&name=n&file=f")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_wrong_key_returns_401(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY="secret123",
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.get(
|
||||
"/api/download?user=u&name=n&file=f",
|
||||
headers={"X-Internal-Key": "wrong"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_correct_key_allows_access(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY="secret123",
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
# Will 404 for missing file, but should pass auth check
|
||||
resp = client.get(
|
||||
"/api/download?user=u&name=n&file=f",
|
||||
headers={"X-Internal-Key": "secret123"},
|
||||
)
|
||||
assert resp.status_code != 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /api/upload_index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUploadIndex:
|
||||
|
||||
def _make_settings(self, vector_store="faiss"):
|
||||
return MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE=vector_store,
|
||||
EMBEDDINGS_NAME="test_embeddings",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
)
|
||||
|
||||
def test_missing_user_returns_no_user(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={})
|
||||
assert resp.json["status"] == "no user"
|
||||
|
||||
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={"user": "testuser"})
|
||||
assert resp.json["status"] == "no name"
|
||||
|
||||
def test_creates_new_source_entry(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "testuser",
|
||||
"name": "testjob",
|
||||
"tokens": "100",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry is not None
|
||||
assert entry["user"] == "testuser"
|
||||
assert entry["name"] == "testjob"
|
||||
|
||||
def test_updates_existing_source_entry(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
doc_id = ObjectId()
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
# Insert existing entry
|
||||
db["sources"].insert_one(
|
||||
{"_id": doc_id, "user": "old_user", "name": "old_name"}
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "new_user",
|
||||
"name": "new_name",
|
||||
"tokens": "200",
|
||||
"retriever": "hybrid",
|
||||
"id": str(doc_id),
|
||||
"type": "remote",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": doc_id})
|
||||
assert entry["user"] == "new_user"
|
||||
assert entry["name"] == "new_name"
|
||||
|
||||
def test_parses_directory_structure_json(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
dir_struct = {"root": {"files": ["a.txt", "b.txt"]}}
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"directory_structure": json.dumps(dir_struct),
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry["directory_structure"] == dir_struct
|
||||
|
||||
def test_invalid_directory_structure_defaults_empty(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"directory_structure": "not valid json",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry["directory_structure"] == {}
|
||||
|
||||
def test_file_name_map_parsed(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
fmap = {"hash1": "file1.txt"}
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry["file_name_map"] == fmap
|
||||
|
||||
def test_faiss_missing_files_returns_no_file(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="faiss")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
def test_faiss_empty_filename_returns_no_file_name(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="faiss")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b""), ""),
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
def test_remote_data_and_sync_frequency(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "remote",
|
||||
"remote_data": '{"url":"http://example.com"}',
|
||||
"sync_frequency": "daily",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry["sync_frequency"] == "daily"
|
||||
assert entry["remote_data"] == '{"url":"http://example.com"}'
|
||||
|
||||
def test_faiss_upload_with_valid_files(self, internal_app, monkeypatch):
|
||||
"""Cover lines 93-104: FAISS upload with both faiss and pkl files."""
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="faiss")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
|
||||
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
mock_storage.save_file.assert_called()
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert entry is not None
|
||||
|
||||
def test_faiss_pkl_missing_returns_no_file(self, internal_app, monkeypatch):
|
||||
"""Cover lines 93-95: FAISS upload with faiss file but no pkl file."""
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="faiss")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
def test_faiss_pkl_empty_name_returns_no_file_name(self, internal_app, monkeypatch):
|
||||
"""Cover lines 97-98: FAISS upload with pkl but empty filename."""
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="faiss")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
|
||||
"file_pkl": (io.BytesIO(b""), ""),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
|
||||
"""Cover line 124: update existing entry with file_name_map."""
|
||||
app, db = internal_app
|
||||
doc_id = ObjectId()
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
db["sources"].insert_one({"_id": doc_id, "user": "old_user", "name": "old"})
|
||||
|
||||
fmap = {"hash1": "file1.txt"}
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": str(doc_id),
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": doc_id})
|
||||
assert entry["file_name_map"] == fmap
|
||||
|
||||
def test_invalid_file_name_map_defaults_none(self, internal_app, monkeypatch):
|
||||
"""Cover lines 77-79: invalid file_name_map JSON defaults to None."""
|
||||
app, db = internal_app
|
||||
doc_id = str(ObjectId())
|
||||
settings_mock = self._make_settings(vector_store="other")
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings", settings_mock
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.StorageCreator",
|
||||
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
|
||||
)
|
||||
|
||||
with app.test_client() as client:
|
||||
resp = client.post(
|
||||
"/api/upload_index",
|
||||
data={
|
||||
"user": "u",
|
||||
"name": "n",
|
||||
"tokens": "0",
|
||||
"retriever": "classic",
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
"file_name_map": "not valid json{{{",
|
||||
},
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
|
||||
assert "file_name_map" not in entry
|
||||
0
tests/api/user/__init__.py
Normal file
0
tests/api/user/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
879
tests/api/user/sources/test_chunks.py
Normal file
879
tests/api/user/sources/test_chunks.py
Normal file
@@ -0,0 +1,879 @@
|
||||
"""Tests for source chunk management routes."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
def _status(response):
|
||||
if isinstance(response, tuple):
|
||||
return response[1]
|
||||
return response.status_code
|
||||
|
||||
|
||||
def _json(response):
|
||||
if isinstance(response, tuple):
|
||||
return response[0].json
|
||||
return response.json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GetChunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetChunks:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
with app.test_request_context("/api/get_chunks?id=abc"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetChunks().get()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_for_invalid_doc_id(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
with app.test_request_context("/api/get_chunks?id=invalid"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Invalid doc_id" in _json(response)["error"]
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
assert _status(response) == 404
|
||||
assert "not found" in _json(response)["error"]
|
||||
|
||||
def test_returns_paginated_chunks(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
|
||||
chunks = [
|
||||
{"text": f"chunk {i}", "metadata": {}, "doc_id": f"c{i}"}
|
||||
for i in range(25)
|
||||
]
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = chunks
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_chunks?id={doc_id}&page=2&per_page=10"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
data = _json(response)
|
||||
assert data["total"] == 25
|
||||
assert data["page"] == 2
|
||||
assert data["per_page"] == 10
|
||||
assert len(data["chunks"]) == 10
|
||||
|
||||
def test_filters_chunks_by_path(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
|
||||
chunks = [
|
||||
{"text": "a", "metadata": {"source": "inputs/dir/file.pdf"}, "doc_id": "c1"},
|
||||
{"text": "b", "metadata": {"source": "inputs/other.txt"}, "doc_id": "c2"},
|
||||
{"text": "c", "metadata": {"file_path": "guides/setup.md"}, "doc_id": "c3"},
|
||||
]
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = chunks
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_chunks?id={doc_id}&path=file.pdf"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["total"] == 1
|
||||
assert data["chunks"][0]["text"] == "a"
|
||||
assert data["path"] == "file.pdf"
|
||||
|
||||
def test_filters_chunks_by_file_path_metadata(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
|
||||
chunks = [
|
||||
{"text": "a", "metadata": {"source": "inputs/dir/file.pdf"}, "doc_id": "c1"},
|
||||
{"text": "c", "metadata": {"file_path": "guides/setup.md"}, "doc_id": "c3"},
|
||||
]
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = chunks
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_chunks?id={doc_id}&path=setup.md"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["total"] == 1
|
||||
assert data["chunks"][0]["text"] == "c"
|
||||
|
||||
def test_filters_chunks_by_search_term(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
|
||||
chunks = [
|
||||
{"text": "Python is great", "metadata": {"title": "intro"}, "doc_id": "c1"},
|
||||
{"text": "Java tutorial", "metadata": {"title": "java guide"}, "doc_id": "c2"},
|
||||
{"text": "Hello world", "metadata": {"title": "Python Basics"}, "doc_id": "c3"},
|
||||
]
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = chunks
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_chunks?id={doc_id}&search=python"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["total"] == 2
|
||||
assert data["search"] == "python"
|
||||
|
||||
def test_combines_path_and_search_filters(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
|
||||
chunks = [
|
||||
{"text": "Python intro", "metadata": {"source": "dir/intro.md", "title": ""}, "doc_id": "c1"},
|
||||
{"text": "Python deep", "metadata": {"source": "dir/deep.md", "title": ""}, "doc_id": "c2"},
|
||||
{"text": "Java intro", "metadata": {"source": "dir/intro.md", "title": ""}, "doc_id": "c3"},
|
||||
]
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = chunks
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_chunks?id={doc_id}&path=intro.md&search=python"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["total"] == 1
|
||||
assert data["chunks"][0]["doc_id"] == "c1"
|
||||
|
||||
def test_returns_500_on_store_error(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.side_effect = Exception("Store error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
assert _status(response) == 500
|
||||
|
||||
def test_no_path_or_search_returns_null_fields(self, app):
|
||||
from application.api.user.sources.chunks import GetChunks
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = []
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = GetChunks().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["path"] is None
|
||||
assert data["search"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AddChunk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAddChunk:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST", json={"id": "abc", "text": "hi"}
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_missing_required_fields(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST", json={"id": str(ObjectId())}
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
# check_required_fields returns a tuple (response, status)
|
||||
assert response is not None
|
||||
|
||||
def test_returns_400_for_invalid_doc_id(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST", json={"id": "bad", "text": "hi"}
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Invalid doc_id" in _json(response)["error"]
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST",
|
||||
json={"id": doc_id, "text": "hello"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 404
|
||||
|
||||
def test_adds_chunk_successfully(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.add_chunk.return_value = "new-chunk-id"
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.num_tokens_from_string",
|
||||
return_value=5,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST",
|
||||
json={"id": doc_id, "text": "hello world"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 201
|
||||
data = _json(response)
|
||||
assert data["chunk_id"] == "new-chunk-id"
|
||||
assert "successfully" in data["message"]
|
||||
call_args = mock_store.add_chunk.call_args
|
||||
assert call_args[0][0] == "hello world"
|
||||
assert call_args[0][1]["token_count"] == 5
|
||||
|
||||
def test_adds_chunk_with_custom_metadata(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.add_chunk.return_value = "cid"
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.num_tokens_from_string",
|
||||
return_value=3,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST",
|
||||
json={
|
||||
"id": doc_id,
|
||||
"text": "hi",
|
||||
"metadata": {"source": "test.pdf"},
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 201
|
||||
meta = mock_store.add_chunk.call_args[0][1]
|
||||
assert meta["source"] == "test.pdf"
|
||||
assert meta["token_count"] == 3
|
||||
|
||||
def test_returns_500_on_store_error(self, app):
|
||||
from application.api.user.sources.chunks import AddChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.add_chunk.side_effect = Exception("fail")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.num_tokens_from_string",
|
||||
return_value=1,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/add_chunk", method="POST",
|
||||
json={"id": doc_id, "text": "hello"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = AddChunk().post()
|
||||
|
||||
assert _status(response) == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeleteChunk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteChunk:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
with app.test_request_context("/api/delete_chunk?id=abc&chunk_id=xyz"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_for_invalid_doc_id(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
with app.test_request_context("/api/delete_chunk?id=bad&chunk_id=xyz"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Invalid doc_id" in _json(response)["error"]
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 404
|
||||
|
||||
def test_deletes_chunk_successfully(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.delete_chunk.return_value = True
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 200
|
||||
assert "successfully" in _json(response)["message"]
|
||||
mock_store.delete_chunk.assert_called_once_with("cid")
|
||||
|
||||
def test_returns_404_when_chunk_not_deleted(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.delete_chunk.return_value = False
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/delete_chunk?id={doc_id}&chunk_id=missing"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 404
|
||||
assert "not found" in _json(response)["message"]
|
||||
|
||||
def test_returns_500_on_store_error(self, app):
|
||||
from application.api.user.sources.chunks import DeleteChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.delete_chunk.side_effect = Exception("boom")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteChunk().delete()
|
||||
|
||||
assert _status(response) == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UpdateChunk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUpdateChunk:
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": "abc", "chunk_id": "cid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_missing_required_fields(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT", json={"id": str(ObjectId())}
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert response is not None
|
||||
|
||||
def test_returns_400_for_invalid_doc_id(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": "bad", "chunk_id": "cid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "cid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 404
|
||||
|
||||
def test_returns_404_when_chunk_not_found(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = [
|
||||
{"doc_id": "other", "text": "x", "metadata": {}},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "missing"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 404
|
||||
assert "Chunk not found" in _json(response)["error"]
|
||||
|
||||
def test_updates_chunk_text_successfully(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = [
|
||||
{"doc_id": "cid", "text": "old text", "metadata": {"source": "f.pdf"}},
|
||||
]
|
||||
mock_store.add_chunk.return_value = "new-cid"
|
||||
mock_store.delete_chunk.return_value = True
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.num_tokens_from_string",
|
||||
return_value=7,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "cid", "text": "new text"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 200
|
||||
data = _json(response)
|
||||
assert data["chunk_id"] == "new-cid"
|
||||
assert data["original_chunk_id"] == "cid"
|
||||
# Verify add was called with new text and merged metadata
|
||||
add_call = mock_store.add_chunk.call_args
|
||||
assert add_call[0][0] == "new text"
|
||||
assert add_call[0][1]["source"] == "f.pdf"
|
||||
assert add_call[0][1]["token_count"] == 7
|
||||
mock_store.delete_chunk.assert_called_once_with("cid")
|
||||
|
||||
def test_updates_chunk_metadata_only(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = [
|
||||
{"doc_id": "cid", "text": "keep me", "metadata": {"source": "f.pdf"}},
|
||||
]
|
||||
mock_store.add_chunk.return_value = "new-cid"
|
||||
mock_store.delete_chunk.return_value = True
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={
|
||||
"id": doc_id,
|
||||
"chunk_id": "cid",
|
||||
"metadata": {"title": "new title"},
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 200
|
||||
add_call = mock_store.add_chunk.call_args
|
||||
# text should be preserved
|
||||
assert add_call[0][0] == "keep me"
|
||||
# metadata should be merged
|
||||
assert add_call[0][1]["source"] == "f.pdf"
|
||||
assert add_call[0][1]["title"] == "new title"
|
||||
|
||||
def test_update_warns_when_old_chunk_delete_fails(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = [
|
||||
{"doc_id": "cid", "text": "text", "metadata": {}},
|
||||
]
|
||||
mock_store.add_chunk.return_value = "new-cid"
|
||||
mock_store.delete_chunk.return_value = False # delete fails
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "cid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
# Still returns 200 with a warning logged
|
||||
assert _status(response) == 200
|
||||
|
||||
def test_returns_500_when_add_chunk_fails(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.return_value = [
|
||||
{"doc_id": "cid", "text": "text", "metadata": {}},
|
||||
]
|
||||
mock_store.add_chunk.side_effect = Exception("add failed")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "cid", "text": "new"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 500
|
||||
assert "addition failed" in _json(response)["error"]
|
||||
|
||||
def test_returns_500_on_general_store_error(self, app):
|
||||
from application.api.user.sources.chunks import UpdateChunk
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
|
||||
mock_store = Mock()
|
||||
mock_store.get_chunks.side_effect = Exception("connection lost")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.chunks.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.chunks.get_vector_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_chunk", method="PUT",
|
||||
json={"id": doc_id, "chunk_id": "cid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = UpdateChunk().put()
|
||||
|
||||
assert _status(response) == 500
|
||||
965
tests/api/user/sources/test_source_routes.py
Normal file
965
tests/api/user/sources/test_source_routes.py
Normal file
@@ -0,0 +1,965 @@
|
||||
"""Tests for source management routes (CombinedJson, PaginatedSources,
|
||||
DeleteByIds, DeleteOldIndexes, ManageSync, DirectoryStructure).
|
||||
|
||||
Note: SyncSource and _get_provider_from_remote_data are already covered in
|
||||
test_routes.py and are NOT duplicated here.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
def _status(response):
|
||||
if isinstance(response, tuple):
|
||||
return response[1]
|
||||
return response.status_code
|
||||
|
||||
|
||||
def _json(response):
|
||||
if isinstance(response, tuple):
|
||||
return response[0].json
|
||||
return response.json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CombinedJson (/api/sources)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCombinedJson:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import CombinedJson
|
||||
|
||||
with app.test_request_context("/api/sources"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = CombinedJson().get()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_default_source_plus_user_sources(self, app):
|
||||
from application.api.user.sources.routes import CombinedJson
|
||||
|
||||
src_id = ObjectId()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = [
|
||||
{
|
||||
"_id": src_id,
|
||||
"name": "My Doc",
|
||||
"date": "2024-01-01",
|
||||
"tokens": "100",
|
||||
"retriever": "classic",
|
||||
"sync_frequency": "daily",
|
||||
"remote_data": json.dumps({"provider": "github"}),
|
||||
"directory_structure": None,
|
||||
"type": "file",
|
||||
}
|
||||
]
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = CombinedJson().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
data = _json(response)
|
||||
# First entry is always the Default
|
||||
assert data[0]["name"] == "Default"
|
||||
assert data[0]["date"] == "default"
|
||||
# Second entry is user source
|
||||
assert data[1]["id"] == str(src_id)
|
||||
assert data[1]["name"] == "My Doc"
|
||||
assert data[1]["provider"] == "github"
|
||||
assert data[1]["syncFrequency"] == "daily"
|
||||
assert data[1]["is_nested"] is False
|
||||
|
||||
def test_is_nested_true_when_directory_structure_present(self, app):
|
||||
from application.api.user.sources.routes import CombinedJson
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = [
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"name": "Nested",
|
||||
"date": "2024-01-01",
|
||||
"directory_structure": {"files": ["a.txt"]},
|
||||
}
|
||||
]
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = CombinedJson().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data[1]["is_nested"] is True
|
||||
|
||||
def test_returns_400_on_db_error(self, app):
|
||||
from application.api.user.sources.routes import CombinedJson
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.side_effect = Exception("db err")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = CombinedJson().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
def test_type_defaults_to_file(self, app):
|
||||
from application.api.user.sources.routes import CombinedJson
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = [
|
||||
{"_id": ObjectId(), "name": "X", "date": "d"}
|
||||
]
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = CombinedJson().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data[1]["type"] == "file"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaginatedSources (/api/sources/paginated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPaginatedSources:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
with app.test_request_context("/api/sources/paginated"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = PaginatedSources().get()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_paginated_results(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
ids = [ObjectId() for _ in range(3)]
|
||||
docs = [
|
||||
{"_id": ids[i], "name": f"Doc{i}", "date": f"2024-0{i + 1}-01"}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = mock_cursor
|
||||
mock_cursor.skip.return_value = mock_cursor
|
||||
mock_cursor.limit.return_value = docs
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
mock_collection.count_documents.return_value = 3
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/sources/paginated?page=1&rows=10&sort=date&order=desc"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
data = _json(response)
|
||||
assert data["total"] == 3
|
||||
assert data["totalPages"] == 1
|
||||
assert data["currentPage"] == 1
|
||||
assert len(data["paginated"]) == 3
|
||||
|
||||
def test_search_filter_applies_regex(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = mock_cursor
|
||||
mock_cursor.skip.return_value = mock_cursor
|
||||
mock_cursor.limit.return_value = []
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
mock_collection.count_documents.return_value = 0
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/sources/paginated?search=test%20doc"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
# Verify search query was passed
|
||||
query_arg = mock_collection.count_documents.call_args[0][0]
|
||||
assert query_arg["name"]["$regex"] == "test doc"
|
||||
assert query_arg["name"]["$options"] == "i"
|
||||
|
||||
def test_ascending_sort_order(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = mock_cursor
|
||||
mock_cursor.skip.return_value = mock_cursor
|
||||
mock_cursor.limit.return_value = []
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
mock_collection.count_documents.return_value = 0
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/sources/paginated?order=asc&sort=name"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
mock_cursor.sort.assert_called_once_with("name", 1)
|
||||
|
||||
def test_page_clamped_to_valid_range(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = mock_cursor
|
||||
mock_cursor.skip.return_value = mock_cursor
|
||||
mock_cursor.limit.return_value = []
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
mock_collection.count_documents.return_value = 5 # 1 page with default 10 rows
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/sources/paginated?page=999"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["currentPage"] == 1 # clamped
|
||||
|
||||
def test_returns_400_on_db_error(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.count_documents.return_value = 0
|
||||
mock_collection.find.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources/paginated"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
def test_paginated_includes_provider_and_is_nested(self, app):
|
||||
from application.api.user.sources.routes import PaginatedSources
|
||||
|
||||
doc = {
|
||||
"_id": ObjectId(),
|
||||
"name": "S3 Src",
|
||||
"date": "2024-01-01",
|
||||
"remote_data": {"provider": "s3"},
|
||||
"directory_structure": {"dirs": ["a"]},
|
||||
"type": "s3",
|
||||
}
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.sort.return_value = mock_cursor
|
||||
mock_cursor.skip.return_value = mock_cursor
|
||||
mock_cursor.limit.return_value = [doc]
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
mock_collection.count_documents.return_value = 1
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/sources/paginated"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = PaginatedSources().get()
|
||||
|
||||
data = _json(response)
|
||||
entry = data["paginated"][0]
|
||||
assert entry["provider"] == "s3"
|
||||
assert entry["isNested"] is True
|
||||
assert entry["type"] == "s3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeleteByIds (/api/delete_by_ids)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteByIds:
|
||||
|
||||
def test_returns_400_when_path_missing(self, app):
|
||||
from application.api.user.sources.routes import DeleteByIds
|
||||
|
||||
with app.test_request_context("/api/delete_by_ids"):
|
||||
response = DeleteByIds().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Missing" in _json(response)["message"]
|
||||
|
||||
def test_returns_200_on_successful_delete(self, app):
|
||||
from application.api.user.sources.routes import DeleteByIds
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.delete_index.return_value = True
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/delete_by_ids?path=id1,id2"):
|
||||
response = DeleteByIds().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
assert _json(response)["success"] is True
|
||||
mock_collection.delete_index.assert_called_once_with(ids="id1,id2")
|
||||
|
||||
def test_returns_400_when_delete_returns_false(self, app):
|
||||
from application.api.user.sources.routes import DeleteByIds
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.delete_index.return_value = False
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/delete_by_ids?path=id1"):
|
||||
response = DeleteByIds().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
def test_returns_400_on_exception(self, app):
|
||||
from application.api.user.sources.routes import DeleteByIds
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.delete_index.side_effect = Exception("fail")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/delete_by_ids?path=id1"):
|
||||
response = DeleteByIds().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeleteOldIndexes (/api/delete_old)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteOldIndexes:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
with app.test_request_context("/api/delete_old?source_id=abc"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_when_source_id_missing(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
with app.test_request_context("/api/delete_old"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(f"/api/delete_old?source_id={source_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 404
|
||||
|
||||
def test_deletes_faiss_index_and_file(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": source_id,
|
||||
"user": "u1",
|
||||
"file_path": "uploads/u1/doc.pdf",
|
||||
}
|
||||
mock_storage = Mock()
|
||||
mock_storage.file_exists.return_value = True
|
||||
mock_storage.is_directory.return_value = False
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.StorageCreator.get_storage",
|
||||
return_value=mock_storage,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.VECTOR_STORE = "faiss"
|
||||
with app.test_request_context(
|
||||
f"/api/delete_old?source_id={source_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
assert _json(response)["success"] is True
|
||||
# Should have checked and deleted faiss files
|
||||
assert mock_storage.delete_file.call_count >= 1
|
||||
mock_collection.delete_one.assert_called_once()
|
||||
|
||||
def test_deletes_non_faiss_vector_index(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": source_id,
|
||||
"user": "u1",
|
||||
}
|
||||
mock_storage = Mock()
|
||||
mock_vectorstore = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.StorageCreator.get_storage",
|
||||
return_value=mock_storage,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.VectorCreator.create_vectorstore",
|
||||
return_value=mock_vectorstore,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.VECTOR_STORE = "elasticsearch"
|
||||
with app.test_request_context(
|
||||
f"/api/delete_old?source_id={source_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
mock_vectorstore.delete_index.assert_called_once()
|
||||
|
||||
def test_deletes_directory_of_files(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": source_id,
|
||||
"user": "u1",
|
||||
"file_path": "uploads/u1/mydir",
|
||||
}
|
||||
mock_storage = Mock()
|
||||
mock_storage.is_directory.return_value = True
|
||||
mock_storage.list_files.return_value = ["uploads/u1/mydir/a.txt", "uploads/u1/mydir/b.txt"]
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.StorageCreator.get_storage",
|
||||
return_value=mock_storage,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.VECTOR_STORE = "faiss"
|
||||
mock_storage.file_exists.return_value = False
|
||||
with app.test_request_context(
|
||||
f"/api/delete_old?source_id={source_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
# Each file in directory should be deleted
|
||||
assert mock_storage.delete_file.call_count == 2
|
||||
|
||||
def test_handles_file_not_found_gracefully(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": source_id,
|
||||
"user": "u1",
|
||||
"file_path": "uploads/missing.pdf",
|
||||
}
|
||||
mock_storage = Mock()
|
||||
mock_storage.is_directory.side_effect = FileNotFoundError()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.StorageCreator.get_storage",
|
||||
return_value=mock_storage,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.VECTOR_STORE = "faiss"
|
||||
mock_storage.file_exists.return_value = False
|
||||
with app.test_request_context(
|
||||
f"/api/delete_old?source_id={source_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
mock_collection.delete_one.assert_called_once()
|
||||
|
||||
def test_returns_400_on_general_error(self, app):
|
||||
from application.api.user.sources.routes import DeleteOldIndexes
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": source_id,
|
||||
"user": "u1",
|
||||
}
|
||||
mock_storage = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.StorageCreator.get_storage",
|
||||
return_value=mock_storage,
|
||||
), patch(
|
||||
"application.api.user.sources.routes.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.VECTOR_STORE = "faiss"
|
||||
mock_storage.file_exists.side_effect = RuntimeError("disk error")
|
||||
with app.test_request_context(
|
||||
f"/api/delete_old?source_id={source_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DeleteOldIndexes().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ManageSync (/api/manage_sync)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestManageSync:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": "x", "sync_frequency": "daily"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = ManageSync().post()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": "abc"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = ManageSync().post()
|
||||
|
||||
assert response is not None
|
||||
|
||||
def test_returns_400_for_invalid_frequency(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": str(ObjectId()), "sync_frequency": "hourly"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = ManageSync().post()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Invalid frequency" in _json(response)["message"]
|
||||
|
||||
def test_updates_sync_frequency_successfully(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
source_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": source_id, "sync_frequency": "weekly"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = ManageSync().post()
|
||||
|
||||
assert _status(response) == 200
|
||||
assert _json(response)["success"] is True
|
||||
call_args = mock_collection.update_one.call_args
|
||||
assert call_args[0][0]["_id"] == ObjectId(source_id)
|
||||
assert call_args[0][0]["user"] == "u1"
|
||||
assert call_args[0][1]["$set"]["sync_frequency"] == "weekly"
|
||||
|
||||
def test_accepts_all_valid_frequencies(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
mock_collection = Mock()
|
||||
|
||||
for freq in ["never", "daily", "weekly", "monthly"]:
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": str(ObjectId()), "sync_frequency": freq},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = ManageSync().post()
|
||||
|
||||
assert _status(response) == 200
|
||||
|
||||
def test_returns_400_on_db_error(self, app):
|
||||
from application.api.user.sources.routes import ManageSync
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.update_one.side_effect = Exception("db err")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/manage_sync", method="POST",
|
||||
json={"source_id": str(ObjectId()), "sync_frequency": "daily"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = ManageSync().post()
|
||||
|
||||
assert _status(response) == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RedirectToSources (/api/combine)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRedirectToSources:
|
||||
|
||||
def test_redirects_to_sources(self, app):
|
||||
from application.api.user.sources.routes import RedirectToSources
|
||||
|
||||
with app.test_request_context("/api/combine"):
|
||||
response = RedirectToSources().get()
|
||||
|
||||
assert response.status_code == 301
|
||||
assert response.location == "/api/sources"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DirectoryStructure (/api/directory_structure)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDirectoryStructure:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
with app.test_request_context("/api/directory_structure?id=abc"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 401
|
||||
|
||||
def test_returns_400_when_id_missing(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
with app.test_request_context("/api/directory_structure"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "required" in _json(response)["error"]
|
||||
|
||||
def test_returns_400_for_invalid_doc_id(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
with app.test_request_context("/api/directory_structure?id=invalid"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 400
|
||||
assert "Invalid" in _json(response)["error"]
|
||||
|
||||
def test_returns_404_when_doc_not_found(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(f"/api/directory_structure?id={doc_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 404
|
||||
assert "not found" in _json(response)["error"]
|
||||
|
||||
def test_returns_directory_structure(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
doc_id = ObjectId()
|
||||
dir_struct = {"dirs": ["a", "b"], "files": ["c.txt"]}
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": doc_id,
|
||||
"user": "u1",
|
||||
"directory_structure": dir_struct,
|
||||
"file_path": "uploads/u1/mydir",
|
||||
"remote_data": json.dumps({"provider": "github"}),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/directory_structure?id={doc_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 200
|
||||
data = _json(response)
|
||||
assert data["success"] is True
|
||||
assert data["directory_structure"] == dir_struct
|
||||
assert data["base_path"] == "uploads/u1/mydir"
|
||||
assert data["provider"] == "github"
|
||||
|
||||
def test_returns_none_provider_when_no_remote_data(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
doc_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": doc_id,
|
||||
"user": "u1",
|
||||
"directory_structure": {},
|
||||
"file_path": "path",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/directory_structure?id={doc_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["provider"] is None
|
||||
|
||||
def test_handles_invalid_remote_data_json(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
doc_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": doc_id,
|
||||
"user": "u1",
|
||||
"directory_structure": {},
|
||||
"file_path": "path",
|
||||
"remote_data": "not-valid-json{",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/directory_structure?id={doc_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
data = _json(response)
|
||||
assert data["success"] is True
|
||||
assert data["provider"] is None
|
||||
|
||||
def test_returns_500_on_general_error(self, app):
|
||||
from application.api.user.sources.routes import DirectoryStructure
|
||||
|
||||
doc_id = str(ObjectId())
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sources.routes.sources_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/directory_structure?id={doc_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "u1"}
|
||||
response = DirectoryStructure().get()
|
||||
|
||||
assert _status(response) == 500
|
||||
1723
tests/api/user/sources/test_upload.py
Normal file
1723
tests/api/user/sources/test_upload.py
Normal file
File diff suppressed because it is too large
Load Diff
3620
tests/api/user/test_agents_routes.py
Normal file
3620
tests/api/user/test_agents_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
768
tests/api/user/test_agents_sharing.py
Normal file
768
tests/api/user/test_agents_sharing.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""Tests for application.api.user.agents.sharing module."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import DBRef, ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SharedAgent (GET /shared_agent)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSharedAgent:
|
||||
|
||||
def test_returns_400_missing_token(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
with app.test_request_context("/api/shared_agent"):
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_404_agent_not_found(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=abc123"):
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_returns_shared_agent_data(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Shared Agent",
|
||||
"description": "A shared agent",
|
||||
"chunks": "5",
|
||||
"retriever": "classic",
|
||||
"prompt_id": "default",
|
||||
"tools": [],
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
"shared_publicly": True,
|
||||
"shared_token": "abc123",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=abc123"):
|
||||
from flask import request
|
||||
|
||||
# No decoded_token -> anonymous access
|
||||
request.decoded_token = None
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
assert data["id"] == str(agent_id)
|
||||
assert data["name"] == "Shared Agent"
|
||||
assert data["shared"] is True
|
||||
|
||||
def test_adds_to_shared_with_me_for_different_user(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"tools": [],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "abc123",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
mock_ensure = Mock(return_value={"user_id": "user2"})
|
||||
mock_users_col = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.users_collection", mock_users_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=abc123"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user2"}
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
mock_ensure.assert_called_once_with("user2")
|
||||
mock_users_col.update_one.assert_called_once()
|
||||
|
||||
def test_does_not_add_to_shared_for_owner(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"tools": [],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "abc123",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
mock_ensure = Mock()
|
||||
mock_users_col = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.users_collection", mock_users_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=abc123"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "owner1"}
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
mock_ensure.assert_not_called()
|
||||
mock_users_col.update_one.assert_not_called()
|
||||
|
||||
def test_enriches_tool_names(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
tool_id = str(ObjectId())
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"tools": [tool_id],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "tok",
|
||||
}
|
||||
mock_tools_col = Mock()
|
||||
mock_tools_col.find_one.return_value = {
|
||||
"_id": ObjectId(tool_id),
|
||||
"name": "calculator",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.user_tools_collection", mock_tools_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=tok"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json["tools"] == ["calculator"]
|
||||
|
||||
def test_handles_source_dbref(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
source_id = ObjectId()
|
||||
source_ref = DBRef("sources", source_id)
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"source": source_ref,
|
||||
"tools": [],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "tok",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
mock_db.dereference.return_value = {"_id": source_id}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=tok"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json["source"] == str(source_id)
|
||||
|
||||
def test_returns_400_on_exception(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.side_effect = Exception("DB error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=tok"):
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_tool_enrichment_handles_missing_tool(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
tool_id = str(ObjectId())
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"tools": [tool_id],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "tok",
|
||||
}
|
||||
mock_tools_col = Mock()
|
||||
mock_tools_col.find_one.return_value = None
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.user_tools_collection", mock_tools_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=tok"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
# Missing tools are skipped
|
||||
assert response.json["tools"] == []
|
||||
|
||||
def test_image_url_generated_when_present(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "owner1",
|
||||
"name": "Agent",
|
||||
"image": "path/to/img.png",
|
||||
"tools": [],
|
||||
"shared_publicly": True,
|
||||
"shared_token": "tok",
|
||||
}
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_db = Mock()
|
||||
mock_generate = Mock(return_value="http://example.com/img.png")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.db", mock_db
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.generate_image_url", mock_generate
|
||||
):
|
||||
with app.test_request_context("/api/shared_agent?token=tok"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = SharedAgent().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json["image"] == "http://example.com/img.png"
|
||||
mock_generate.assert_called_once_with("path/to/img.png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SharedAgents (GET /shared_agents)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSharedAgents:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_shared_agents_list(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_ensure = Mock(
|
||||
return_value={
|
||||
"user_id": "user1",
|
||||
"agent_preferences": {
|
||||
"shared_with_me": [str(agent_id)],
|
||||
"pinned": [str(agent_id)],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find.return_value = [
|
||||
{
|
||||
"_id": agent_id,
|
||||
"name": "Shared Agent",
|
||||
"description": "desc",
|
||||
"tools": [],
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
"shared_publicly": True,
|
||||
"shared_token": "tok123",
|
||||
}
|
||||
]
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_users_col = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.users_collection", mock_users_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
assert len(data) == 1
|
||||
assert data[0]["name"] == "Shared Agent"
|
||||
assert data[0]["pinned"] is True
|
||||
|
||||
def test_removes_stale_shared_ids(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
stale_id = str(ObjectId())
|
||||
mock_ensure = Mock(
|
||||
return_value={
|
||||
"user_id": "user1",
|
||||
"agent_preferences": {
|
||||
"shared_with_me": [stale_id],
|
||||
"pinned": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find.return_value = [] # None found
|
||||
mock_users_col = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.users_collection", mock_users_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 200
|
||||
mock_users_col.update_one.assert_called_once()
|
||||
call_args = mock_users_col.update_one.call_args
|
||||
assert stale_id in call_args[0][1]["$pullAll"][
|
||||
"agent_preferences.shared_with_me"
|
||||
]
|
||||
|
||||
def test_returns_empty_when_no_shared_ids(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
mock_ensure = Mock(
|
||||
return_value={
|
||||
"user_id": "user1",
|
||||
"agent_preferences": {"shared_with_me": [], "pinned": []},
|
||||
}
|
||||
)
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find.return_value = []
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
):
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json == []
|
||||
|
||||
def test_returns_400_on_exception(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
mock_ensure = Mock(side_effect=Exception("DB error"))
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
):
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_image_url_generated(self, app):
|
||||
from application.api.user.agents.sharing import SharedAgents
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_ensure = Mock(
|
||||
return_value={
|
||||
"user_id": "user1",
|
||||
"agent_preferences": {
|
||||
"shared_with_me": [str(agent_id)],
|
||||
"pinned": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_agents_col = Mock()
|
||||
mock_agents_col.find.return_value = [
|
||||
{
|
||||
"_id": agent_id,
|
||||
"name": "Agent",
|
||||
"image": "path.png",
|
||||
"tools": [],
|
||||
"shared_publicly": True,
|
||||
}
|
||||
]
|
||||
mock_resolve = Mock(return_value=[])
|
||||
mock_generate = Mock(return_value="http://example.com/path.png")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_agents_col
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.generate_image_url", mock_generate
|
||||
), patch(
|
||||
"application.api.user.agents.sharing.users_collection", Mock()
|
||||
):
|
||||
with app.test_request_context("/api/shared_agents"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SharedAgents().get()
|
||||
assert response.status_code == 200
|
||||
assert response.json[0]["image"] == "http://example.com/path.png"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ShareAgent (PUT /share_agent)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareAgent:
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={"id": "abc", "shared": True},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_json_body(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
content_type="application/json",
|
||||
data=b"{}",
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
# Empty JSON object -> no id, no shared -> 400
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 400
|
||||
assert response.json["success"] is False
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={"shared": True},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_400_missing_shared_param(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={"id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_400_invalid_agent_id(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={"id": "invalid-oid", "shared": True},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_404_agent_not_found(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = None
|
||||
agent_id = str(ObjectId())
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={"id": agent_id, "shared": True},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_shares_agent_success(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
}
|
||||
mock_col.update_one.return_value = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={
|
||||
"id": str(agent_id),
|
||||
"shared": True,
|
||||
"username": "TestUser",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
assert data["success"] is True
|
||||
assert data["shared_token"] is not None
|
||||
mock_col.update_one.assert_called_once()
|
||||
|
||||
def test_unshares_agent_success(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
}
|
||||
mock_col.update_one.return_value = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={
|
||||
"id": str(agent_id),
|
||||
"shared": False,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
assert data["success"] is True
|
||||
assert data["shared_token"] is None
|
||||
|
||||
def test_returns_400_on_db_exception(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
}
|
||||
mock_col.update_one.side_effect = Exception("DB error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={
|
||||
"id": str(agent_id),
|
||||
"shared": True,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_share_with_username(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
}
|
||||
mock_col.update_one.return_value = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={
|
||||
"id": str(agent_id),
|
||||
"shared": True,
|
||||
"username": "SharedByUser",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 200
|
||||
# Verify the update call includes shared_metadata with username
|
||||
update_call = mock_col.update_one.call_args[0][1]["$set"]
|
||||
assert update_call["shared_metadata"]["shared_by"] == "SharedByUser"
|
||||
assert update_call["shared_publicly"] is True
|
||||
assert "shared_token" in update_call
|
||||
|
||||
def test_shared_false_explicitly(self, app):
|
||||
from application.api.user.agents.sharing import ShareAgent
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_col.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
}
|
||||
mock_col.update_one.return_value = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.sharing.agents_collection", mock_col
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share_agent",
|
||||
method="PUT",
|
||||
json={
|
||||
"id": str(agent_id),
|
||||
"shared": False,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareAgent().put()
|
||||
assert response.status_code == 200
|
||||
update_call = mock_col.update_one.call_args[0][1]
|
||||
assert update_call["$set"]["shared_publicly"] is False
|
||||
assert update_call["$set"]["shared_token"] is None
|
||||
899
tests/api/user/test_analytics.py
Normal file
899
tests/api/user/test_analytics.py
Normal file
@@ -0,0 +1,899 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetMessageAnalytics:
|
||||
|
||||
def test_returns_message_analytics_last_30_days(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = [
|
||||
{"_id": "2024-06-01", "count": 5},
|
||||
{"_id": "2024-06-02", "count": 3},
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert "messages" in response.json
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_invalid_filter_option(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "invalid_option"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_filters_by_api_key(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"key": "api_key_value",
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_7_days",
|
||||
"api_key_id": str(agent_id),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
pipeline = mock_conversations.aggregate.call_args[0][0]
|
||||
assert pipeline[0]["$match"].get("api_key") == "api_key_value"
|
||||
|
||||
def test_last_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_last_24_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_24_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetTokenAnalytics:
|
||||
|
||||
def test_returns_token_analytics(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.return_value = [
|
||||
{"_id": {"day": "2024-06-01"}, "total_tokens": 1000}
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert "token_usage" in response.json
|
||||
|
||||
def test_returns_400_invalid_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "invalid"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetFeedbackAnalytics:
|
||||
|
||||
def test_returns_feedback_analytics(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = [
|
||||
{"_id": "2024-06-01", "positive": 10, "negative": 2}
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert "feedback" in response.json
|
||||
|
||||
def test_returns_400_invalid_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "bad"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetUserLogs:
|
||||
|
||||
def test_returns_paginated_logs(self, app):
|
||||
from application.api.user.analytics.routes import GetUserLogs
|
||||
|
||||
log_id = ObjectId()
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.sort.return_value.skip.return_value.limit.return_value = [
|
||||
{
|
||||
"_id": log_id,
|
||||
"action": "query",
|
||||
"level": "info",
|
||||
"user": "user1",
|
||||
"question": "test?",
|
||||
"sources": [],
|
||||
"retriever_params": {},
|
||||
"timestamp": datetime.datetime(2024, 6, 1),
|
||||
}
|
||||
]
|
||||
mock_user_logs = Mock()
|
||||
mock_user_logs.find.return_value = mock_cursor
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.user_logs_collection",
|
||||
mock_user_logs,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_user_logs",
|
||||
method="POST",
|
||||
json={"page": 1, "page_size": 10},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetUserLogs().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert response.json["page"] == 1
|
||||
assert len(response.json["logs"]) == 1
|
||||
assert response.json["has_more"] is False
|
||||
|
||||
def test_detects_has_more(self, app):
|
||||
from application.api.user.analytics.routes import GetUserLogs
|
||||
|
||||
items = [
|
||||
{"_id": ObjectId(), "action": f"q{i}", "level": "info"}
|
||||
for i in range(3)
|
||||
]
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.sort.return_value.skip.return_value.limit.return_value = items
|
||||
mock_user_logs = Mock()
|
||||
mock_user_logs.find.return_value = mock_cursor
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.user_logs_collection",
|
||||
mock_user_logs,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_user_logs",
|
||||
method="POST",
|
||||
json={"page": 1, "page_size": 2},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetUserLogs().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["has_more"] is True
|
||||
assert len(response.json["logs"]) == 2
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.analytics.routes import GetUserLogs
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/get_user_logs",
|
||||
method="POST",
|
||||
json={"page": 1},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetUserLogs().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetTokenAnalyticsAdditional:
|
||||
"""Additional tests for GetTokenAnalytics covering missing lines."""
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_last_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.return_value = [
|
||||
{"_id": {"minute": "2024-06-01 12:00:00"}, "total_tokens": 500}
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert "token_usage" in response.json
|
||||
|
||||
def test_last_24_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.return_value = [
|
||||
{"_id": {"hour": "2024-06-01 12:00"}, "total_tokens": 800}
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_24_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
|
||||
def test_filters_by_api_key(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"key": "token_api_key",
|
||||
}
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.return_value = []
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_7_days",
|
||||
"api_key_id": str(agent_id),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
pipeline = mock_token_usage.aggregate.call_args[0][0]
|
||||
assert pipeline[0]["$match"].get("api_key") == "token_api_key"
|
||||
|
||||
def test_api_key_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_30_days",
|
||||
"api_key_id": str(ObjectId()),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_aggregate_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.side_effect = Exception("aggregate error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_last_15_days_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetTokenAnalytics
|
||||
|
||||
mock_token_usage = Mock()
|
||||
mock_token_usage.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.token_usage_collection",
|
||||
mock_token_usage,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_token_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_15_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetTokenAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetFeedbackAnalyticsAdditional:
|
||||
"""Additional tests for GetFeedbackAnalytics covering missing lines."""
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_last_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_last_24_hour_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_24_hour"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_filters_by_api_key(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"key": "fb_api_key",
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_7_days",
|
||||
"api_key_id": str(agent_id),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
pipeline = mock_conversations.aggregate.call_args[0][0]
|
||||
assert pipeline[0]["$match"].get("api_key") == "fb_api_key"
|
||||
|
||||
def test_api_key_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_30_days",
|
||||
"api_key_id": str(ObjectId()),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_aggregate_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetFeedbackAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.side_effect = Exception("aggregate error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_feedback_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetFeedbackAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetMessageAnalyticsAdditional:
|
||||
"""Additional tests for GetMessageAnalytics covering error paths."""
|
||||
|
||||
def test_api_key_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={
|
||||
"filter_option": "last_30_days",
|
||||
"api_key_id": str(ObjectId()),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_aggregate_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.side_effect = Exception("aggregate error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_30_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_last_15_days_filter(self, app):
|
||||
from application.api.user.analytics.routes import GetMessageAnalytics
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.aggregate.return_value = []
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_message_analytics",
|
||||
method="POST",
|
||||
json={"filter_option": "last_15_days"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetMessageAnalytics().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetUserLogsAdditional:
|
||||
"""Additional tests for GetUserLogs covering api_key filtering and errors."""
|
||||
|
||||
def test_filters_by_api_key(self, app):
|
||||
from application.api.user.analytics.routes import GetUserLogs
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"key": "logs_api_key",
|
||||
}
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.sort.return_value.skip.return_value.limit.return_value = []
|
||||
mock_user_logs = Mock()
|
||||
mock_user_logs.find.return_value = mock_cursor
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.user_logs_collection",
|
||||
mock_user_logs,
|
||||
), patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_user_logs",
|
||||
method="POST",
|
||||
json={
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
"api_key_id": str(agent_id),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetUserLogs().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
query_arg = mock_user_logs.find.call_args[0][0]
|
||||
assert query_arg == {"api_key": "logs_api_key"}
|
||||
|
||||
def test_api_key_error_returns_400(self, app):
|
||||
from application.api.user.analytics.routes import GetUserLogs
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.analytics.routes.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/get_user_logs",
|
||||
method="POST",
|
||||
json={
|
||||
"page": 1,
|
||||
"api_key_id": str(ObjectId()),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetUserLogs().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
360
tests/api/user/test_conversations.py
Normal file
360
tests/api/user/test_conversations.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteConversation:
|
||||
|
||||
def test_deletes_conversation(self, app):
|
||||
from application.api.user.conversations.routes import DeleteConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(f"/api/delete_conversation?id={conv_id}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = DeleteConversation().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_collection.delete_one.assert_called_once_with(
|
||||
{"_id": conv_id, "user": "user1"}
|
||||
)
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.conversations.routes import DeleteConversation
|
||||
|
||||
with app.test_request_context("/api/delete_conversation?id=abc"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = DeleteConversation().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.conversations.routes import DeleteConversation
|
||||
|
||||
with app.test_request_context("/api/delete_conversation"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = DeleteConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeleteAllConversations:
|
||||
|
||||
def test_deletes_all_for_user(self, app):
|
||||
from application.api.user.conversations.routes import DeleteAllConversations
|
||||
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/delete_all_conversations"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = DeleteAllConversations().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_collection.delete_many.assert_called_once_with({"user": "user1"})
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.conversations.routes import DeleteAllConversations
|
||||
|
||||
with app.test_request_context("/api/delete_all_conversations"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = DeleteAllConversations().get()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetConversations:
|
||||
|
||||
def test_returns_conversations(self, app):
|
||||
from application.api.user.conversations.routes import GetConversations
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.sort.return_value.limit.return_value = [
|
||||
{
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"agent_id": "agent1",
|
||||
"is_shared_usage": False,
|
||||
}
|
||||
]
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = mock_cursor
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/get_conversations"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetConversations().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == str(conv_id)
|
||||
assert data[0]["name"] == "Test Chat"
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.conversations.routes import GetConversations
|
||||
|
||||
with app.test_request_context("/api/get_conversations"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetConversations().get()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetSingleConversation:
|
||||
|
||||
def test_returns_conversation(self, app):
|
||||
from application.api.user.conversations.routes import GetSingleConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_conv_collection = Mock()
|
||||
mock_conv_collection.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "hi", "response": "hello"}],
|
||||
"agent_id": "agent1",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_conv_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_single_conversation?id={conv_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSingleConversation().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["queries"] == [{"prompt": "hi", "response": "hello"}]
|
||||
assert response.json["agent_id"] == "agent1"
|
||||
|
||||
def test_returns_404_not_found(self, app):
|
||||
from application.api.user.conversations.routes import GetSingleConversation
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_single_conversation?id={ObjectId()}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSingleConversation().get()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.conversations.routes import GetSingleConversation
|
||||
|
||||
with app.test_request_context("/api/get_single_conversation"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSingleConversation().get()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_resolves_attachments(self, app):
|
||||
from application.api.user.conversations.routes import GetSingleConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
att_id = ObjectId()
|
||||
mock_conv_collection = Mock()
|
||||
mock_conv_collection.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [
|
||||
{"prompt": "hi", "response": "hello", "attachments": [str(att_id)]}
|
||||
],
|
||||
}
|
||||
mock_att_collection = Mock()
|
||||
mock_att_collection.find_one.return_value = {
|
||||
"_id": att_id,
|
||||
"filename": "doc.pdf",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_conv_collection,
|
||||
), patch(
|
||||
"application.api.user.conversations.routes.attachments_collection",
|
||||
mock_att_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_single_conversation?id={conv_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSingleConversation().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
attachments = response.json["queries"][0]["attachments"]
|
||||
assert len(attachments) == 1
|
||||
assert attachments[0]["fileName"] == "doc.pdf"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUpdateConversationName:
|
||||
|
||||
def test_updates_name(self, app):
|
||||
from application.api.user.conversations.routes import UpdateConversationName
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_conversation_name",
|
||||
method="POST",
|
||||
json={"id": str(conv_id), "name": "New Name"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = UpdateConversationName().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_collection.update_one.assert_called_once()
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.conversations.routes import UpdateConversationName
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/update_conversation_name",
|
||||
method="POST",
|
||||
json={"id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = UpdateConversationName().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSubmitFeedback:
|
||||
|
||||
def test_submits_positive_feedback(self, app):
|
||||
from application.api.user.conversations.routes import SubmitFeedback
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/feedback",
|
||||
method="POST",
|
||||
json={
|
||||
"feedback": "LIKE",
|
||||
"conversation_id": str(conv_id),
|
||||
"question_index": 0,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SubmitFeedback().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
call_args = mock_collection.update_one.call_args
|
||||
assert "$set" in call_args[0][1]
|
||||
|
||||
def test_removes_feedback_when_null(self, app):
|
||||
from application.api.user.conversations.routes import SubmitFeedback
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.conversations.routes.conversations_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/feedback",
|
||||
method="POST",
|
||||
json={
|
||||
"feedback": None,
|
||||
"conversation_id": str(conv_id),
|
||||
"question_index": 0,
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SubmitFeedback().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
call_args = mock_collection.update_one.call_args
|
||||
assert "$unset" in call_args[0][1]
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.conversations.routes import SubmitFeedback
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/feedback",
|
||||
method="POST",
|
||||
json={"feedback": "LIKE"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = SubmitFeedback().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
622
tests/api/user/test_folders.py
Normal file
622
tests/api/user/test_folders.py
Normal file
@@ -0,0 +1,622 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFoldersGet:
|
||||
|
||||
def test_returns_folders(self, app):
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
now = datetime.datetime(2024, 6, 15, tzinfo=datetime.timezone.utc)
|
||||
folder_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = [
|
||||
{
|
||||
"_id": folder_id,
|
||||
"name": "My Folder",
|
||||
"parent_id": None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/agents/folders/", method="GET"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolders().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
folders = response.json["folders"]
|
||||
assert len(folders) == 1
|
||||
assert folders[0]["id"] == str(folder_id)
|
||||
assert folders[0]["name"] == "My Folder"
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
with app.test_request_context("/api/agents/folders/", method="GET"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AgentFolders().get()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFoldersCreate:
|
||||
|
||||
def test_creates_folder(self, app):
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
inserted_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/",
|
||||
method="POST",
|
||||
json={"name": "New Folder"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolders().post()
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json["id"] == str(inserted_id)
|
||||
assert response.json["name"] == "New Folder"
|
||||
|
||||
def test_returns_400_missing_name(self, app):
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/",
|
||||
method="POST",
|
||||
json={},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolders().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_validates_parent_folder_exists(self, app):
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/",
|
||||
method="POST",
|
||||
json={"name": "Sub", "parent_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolders().post()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFolderGet:
|
||||
|
||||
def test_returns_folder_with_agents_and_subfolders(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
folder_id = ObjectId()
|
||||
agent_id = ObjectId()
|
||||
subfolder_id = ObjectId()
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.return_value = {
|
||||
"_id": folder_id,
|
||||
"name": "Folder",
|
||||
"parent_id": None,
|
||||
}
|
||||
mock_folders.find.return_value = [
|
||||
{"_id": subfolder_id, "name": "Subfolder"}
|
||||
]
|
||||
mock_agents = Mock()
|
||||
mock_agents.find.return_value = [
|
||||
{"_id": agent_id, "name": "Agent 1", "description": "Desc"}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
), patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{folder_id}", method="GET"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().get(str(folder_id))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["name"] == "Folder"
|
||||
assert len(response.json["agents"]) == 1
|
||||
assert len(response.json["subfolders"]) == 1
|
||||
|
||||
def test_returns_404_not_found(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{ObjectId()}", method="GET"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().get(str(ObjectId()))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFolderUpdate:
|
||||
|
||||
def test_updates_folder_name(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
folder_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.update_one.return_value = Mock(matched_count=1)
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{folder_id}",
|
||||
method="PUT",
|
||||
json={"name": "Renamed"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().put(str(folder_id))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
|
||||
def test_prevents_self_parent(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
folder_id = str(ObjectId())
|
||||
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{folder_id}",
|
||||
method="PUT",
|
||||
json={"parent_id": folder_id},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().put(folder_id)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "own parent" in response.json["message"]
|
||||
|
||||
def test_returns_404_when_not_found(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.update_one.return_value = Mock(matched_count=0)
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{ObjectId()}",
|
||||
method="PUT",
|
||||
json={"name": "X"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().put(str(ObjectId()))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFolderDelete:
|
||||
|
||||
def test_deletes_folder_and_unsets_references(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
folder_id = str(ObjectId())
|
||||
mock_folders = Mock()
|
||||
mock_folders.delete_one.return_value = Mock(deleted_count=1)
|
||||
mock_agents = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
), patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{folder_id}", method="DELETE"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().delete(folder_id)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_agents.update_many.assert_called_once()
|
||||
mock_folders.update_many.assert_called_once()
|
||||
mock_folders.delete_one.assert_called_once()
|
||||
|
||||
def test_returns_404_not_found(self, app):
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
mock_folders = Mock()
|
||||
mock_folders.delete_one.return_value = Mock(deleted_count=0)
|
||||
mock_agents = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
), patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agents/folders/{ObjectId()}", method="DELETE"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().delete(str(ObjectId()))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMoveAgentToFolder:
|
||||
|
||||
def test_moves_agent_to_folder(self, app):
|
||||
from application.api.user.agents.folders import MoveAgentToFolder
|
||||
|
||||
agent_id = ObjectId()
|
||||
folder_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {"_id": agent_id, "user": "user1"}
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.return_value = {"_id": folder_id}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/move_agent",
|
||||
method="POST",
|
||||
json={
|
||||
"agent_id": str(agent_id),
|
||||
"folder_id": str(folder_id),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = MoveAgentToFolder().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_agents.update_one.assert_called_once()
|
||||
|
||||
def test_removes_agent_from_folder(self, app):
|
||||
from application.api.user.agents.folders import MoveAgentToFolder
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {"_id": agent_id, "user": "user1"}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/move_agent",
|
||||
method="POST",
|
||||
json={"agent_id": str(agent_id), "folder_id": None},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = MoveAgentToFolder().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
call_args = mock_agents.update_one.call_args
|
||||
assert "$unset" in call_args[0][1]
|
||||
|
||||
def test_returns_404_agent_not_found(self, app):
|
||||
from application.api.user.agents.folders import MoveAgentToFolder
|
||||
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/move_agent",
|
||||
method="POST",
|
||||
json={"agent_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = MoveAgentToFolder().post()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_returns_400_missing_agent_id(self, app):
|
||||
from application.api.user.agents.folders import MoveAgentToFolder
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/move_agent",
|
||||
method="POST",
|
||||
json={},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = MoveAgentToFolder().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBulkMoveAgents:
|
||||
|
||||
def test_bulk_moves_to_folder(self, app):
|
||||
from application.api.user.agents.folders import BulkMoveAgents
|
||||
|
||||
folder_id = ObjectId()
|
||||
agent_ids = [str(ObjectId()), str(ObjectId())]
|
||||
mock_agents = Mock()
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.return_value = {"_id": folder_id}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/bulk_move",
|
||||
method="POST",
|
||||
json={"agent_ids": agent_ids, "folder_id": str(folder_id)},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = BulkMoveAgents().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_agents.update_many.assert_called_once()
|
||||
|
||||
def test_bulk_removes_from_folders(self, app):
|
||||
from application.api.user.agents.folders import BulkMoveAgents
|
||||
|
||||
agent_ids = [str(ObjectId())]
|
||||
mock_agents = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agents_collection",
|
||||
mock_agents,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/bulk_move",
|
||||
method="POST",
|
||||
json={"agent_ids": agent_ids},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = BulkMoveAgents().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
call_args = mock_agents.update_many.call_args
|
||||
assert "$unset" in call_args[0][1]
|
||||
|
||||
def test_returns_400_missing_agent_ids(self, app):
|
||||
from application.api.user.agents.folders import BulkMoveAgents
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/bulk_move",
|
||||
method="POST",
|
||||
json={},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = BulkMoveAgents().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_404_folder_not_found(self, app):
|
||||
from application.api.user.agents.folders import BulkMoveAgents
|
||||
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/bulk_move",
|
||||
method="POST",
|
||||
json={
|
||||
"agent_ids": [str(ObjectId())],
|
||||
"folder_id": str(ObjectId()),
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = BulkMoveAgents().post()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 64, 90-91, 100, 125-126, 132, 136)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentFoldersGaps:
|
||||
|
||||
def test_create_folder_no_auth(self, app):
|
||||
"""Cover line 64: post returns 401 when no decoded_token."""
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/",
|
||||
method="POST",
|
||||
json={"name": "Test"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AgentFolders().post()
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_create_folder_exception(self, app):
|
||||
"""Cover lines 90-91: exception during insert_one returns 400."""
|
||||
from application.api.user.agents.folders import AgentFolders
|
||||
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.return_value = None
|
||||
mock_folders.insert_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/",
|
||||
method="POST",
|
||||
json={"name": "Test"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolders().post()
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_get_folder_no_auth(self, app):
|
||||
"""Cover line 100: get specific folder returns 401 when no auth."""
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/abc",
|
||||
method="GET",
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AgentFolder().get("abc")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_folder_exception(self, app):
|
||||
"""Cover lines 125-126: exception during find returns 400."""
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
mock_folders = Mock()
|
||||
mock_folders.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.folders.agent_folders_collection",
|
||||
mock_folders,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/" + str(ObjectId()),
|
||||
method="GET",
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().get(str(ObjectId()))
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_update_folder_no_auth(self, app):
|
||||
"""Cover line 132: put returns 401 when no decoded_token."""
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/abc",
|
||||
method="PUT",
|
||||
json={"name": "Updated"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AgentFolder().put("abc")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_update_folder_no_data(self, app):
|
||||
"""Cover line 136: put with no data returns 400."""
|
||||
from application.api.user.agents.folders import AgentFolder
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/agents/folders/abc",
|
||||
method="PUT",
|
||||
content_type="application/json",
|
||||
data="null",
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentFolder().put("abc")
|
||||
assert response.status_code == 400
|
||||
70
tests/api/user/test_models.py
Normal file
70
tests/api/user/test_models.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestModelsListResource:
|
||||
|
||||
def test_returns_models(self, app):
|
||||
from application.api.user.models.routes import ModelsListResource
|
||||
|
||||
mock_model = Mock()
|
||||
mock_model.to_dict.return_value = {
|
||||
"id": "gpt-4",
|
||||
"name": "GPT-4",
|
||||
"provider": "openai",
|
||||
}
|
||||
|
||||
mock_registry = Mock()
|
||||
mock_registry.get_enabled_models.return_value = [mock_model]
|
||||
mock_registry.default_model_id = "gpt-4"
|
||||
|
||||
with patch(
|
||||
"application.api.user.models.routes.ModelRegistry.get_instance",
|
||||
return_value=mock_registry,
|
||||
):
|
||||
with app.test_request_context("/api/models"):
|
||||
response = ModelsListResource().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["count"] == 1
|
||||
assert response.json["default_model_id"] == "gpt-4"
|
||||
assert response.json["models"][0]["id"] == "gpt-4"
|
||||
|
||||
def test_returns_empty_models(self, app):
|
||||
from application.api.user.models.routes import ModelsListResource
|
||||
|
||||
mock_registry = Mock()
|
||||
mock_registry.get_enabled_models.return_value = []
|
||||
mock_registry.default_model_id = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.models.routes.ModelRegistry.get_instance",
|
||||
return_value=mock_registry,
|
||||
):
|
||||
with app.test_request_context("/api/models"):
|
||||
response = ModelsListResource().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["count"] == 0
|
||||
assert response.json["models"] == []
|
||||
|
||||
def test_returns_500_on_error(self, app):
|
||||
from application.api.user.models.routes import ModelsListResource
|
||||
|
||||
with patch(
|
||||
"application.api.user.models.routes.ModelRegistry.get_instance",
|
||||
side_effect=Exception("Registry error"),
|
||||
):
|
||||
with app.test_request_context("/api/models"):
|
||||
response = ModelsListResource().get()
|
||||
|
||||
assert response.status_code == 500
|
||||
288
tests/api/user/test_prompts.py
Normal file
288
tests/api/user/test_prompts.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from unittest.mock import Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCreatePrompt:
|
||||
|
||||
def test_creates_prompt(self, app):
|
||||
from application.api.user.prompts.routes import CreatePrompt
|
||||
|
||||
mock_collection = Mock()
|
||||
inserted_id = ObjectId()
|
||||
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/create_prompt",
|
||||
method="POST",
|
||||
json={"name": "My Prompt", "content": "You are helpful."},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = CreatePrompt().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["id"] == str(inserted_id)
|
||||
mock_collection.insert_one.assert_called_once()
|
||||
doc = mock_collection.insert_one.call_args[0][0]
|
||||
assert doc["name"] == "My Prompt"
|
||||
assert doc["user"] == "user1"
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.prompts.routes import CreatePrompt
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/create_prompt",
|
||||
method="POST",
|
||||
json={"name": "P", "content": "C"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = CreatePrompt().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.prompts.routes import CreatePrompt
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/create_prompt",
|
||||
method="POST",
|
||||
json={"name": "P"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = CreatePrompt().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetPrompts:
|
||||
|
||||
def test_returns_prompts_with_defaults(self, app):
|
||||
from application.api.user.prompts.routes import GetPrompts
|
||||
|
||||
user_prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find.return_value = [
|
||||
{"_id": user_prompt_id, "name": "Custom Prompt"}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context("/api/get_prompts"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetPrompts().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json
|
||||
public_names = [p["name"] for p in data if p["type"] == "public"]
|
||||
assert "default" in public_names
|
||||
assert "creative" in public_names
|
||||
assert "strict" in public_names
|
||||
private = [p for p in data if p["type"] == "private"]
|
||||
assert len(private) == 1
|
||||
assert private[0]["name"] == "Custom Prompt"
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.prompts.routes import GetPrompts
|
||||
|
||||
with app.test_request_context("/api/get_prompts"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = GetPrompts().get()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetSinglePrompt:
|
||||
|
||||
def test_returns_default_prompt(self, app):
|
||||
from application.api.user.prompts.routes import GetSinglePrompt
|
||||
|
||||
with patch("builtins.open", mock_open(read_data="Default prompt content")):
|
||||
with app.test_request_context("/api/get_single_prompt?id=default"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSinglePrompt().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["content"] == "Default prompt content"
|
||||
|
||||
def test_returns_creative_prompt(self, app):
|
||||
from application.api.user.prompts.routes import GetSinglePrompt
|
||||
|
||||
with patch("builtins.open", mock_open(read_data="Creative content")):
|
||||
with app.test_request_context("/api/get_single_prompt?id=creative"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSinglePrompt().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["content"] == "Creative content"
|
||||
|
||||
def test_returns_strict_prompt(self, app):
|
||||
from application.api.user.prompts.routes import GetSinglePrompt
|
||||
|
||||
with patch("builtins.open", mock_open(read_data="Strict content")):
|
||||
with app.test_request_context("/api/get_single_prompt?id=strict"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSinglePrompt().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["content"] == "Strict content"
|
||||
|
||||
def test_returns_custom_prompt(self, app):
|
||||
from application.api.user.prompts.routes import GetSinglePrompt
|
||||
|
||||
prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": prompt_id,
|
||||
"content": "Custom content",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/get_single_prompt?id={prompt_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSinglePrompt().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["content"] == "Custom content"
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.prompts.routes import GetSinglePrompt
|
||||
|
||||
with app.test_request_context("/api/get_single_prompt"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = GetSinglePrompt().get()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDeletePrompt:
|
||||
|
||||
def test_deletes_prompt(self, app):
|
||||
from application.api.user.prompts.routes import DeletePrompt
|
||||
|
||||
prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/delete_prompt",
|
||||
method="POST",
|
||||
json={"id": str(prompt_id)},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = DeletePrompt().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_collection.delete_one.assert_called_once_with(
|
||||
{"_id": prompt_id, "user": "user1"}
|
||||
)
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.prompts.routes import DeletePrompt
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/delete_prompt",
|
||||
method="POST",
|
||||
json={},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = DeletePrompt().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUpdatePrompt:
|
||||
|
||||
def test_updates_prompt(self, app):
|
||||
from application.api.user.prompts.routes import UpdatePrompt
|
||||
|
||||
prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_prompt",
|
||||
method="POST",
|
||||
json={
|
||||
"id": str(prompt_id),
|
||||
"name": "Updated",
|
||||
"content": "New content",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = UpdatePrompt().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_collection.update_one.assert_called_once()
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.prompts.routes import UpdatePrompt
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/update_prompt",
|
||||
method="POST",
|
||||
json={"id": str(ObjectId()), "name": "Updated"},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = UpdatePrompt().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
803
tests/api/user/test_sharing.py
Normal file
803
tests/api/user/test_sharing.py
Normal file
@@ -0,0 +1,803 @@
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareConversation:
|
||||
|
||||
def test_shares_non_promptable_conversation(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"queries": [{"prompt": "hi"}],
|
||||
}
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = None
|
||||
mock_shared.insert_one.return_value = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=false",
|
||||
method="POST",
|
||||
json={"conversation_id": str(conv_id)},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json["success"] is True
|
||||
assert "identifier" in response.json
|
||||
mock_shared.insert_one.assert_called_once()
|
||||
|
||||
def test_returns_existing_shared_link(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"queries": [{"prompt": "hi"}],
|
||||
}
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": conv_id,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=false",
|
||||
method="POST",
|
||||
json={"conversation_id": str(conv_id)},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["identifier"] == str(test_uuid)
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=false",
|
||||
method="POST",
|
||||
json={"conversation_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_conversation_id(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=false",
|
||||
method="POST",
|
||||
json={},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_400_missing_isPromptable(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/share",
|
||||
method="POST",
|
||||
json={"conversation_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "isPromptable" in response.json["message"]
|
||||
|
||||
def test_returns_404_conversation_not_found(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=false",
|
||||
method="POST",
|
||||
json={"conversation_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetPubliclySharedConversations:
|
||||
|
||||
def test_returns_shared_conversation(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": conv_id,
|
||||
"first_n_queries": 2,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Shared Chat",
|
||||
"queries": [
|
||||
{"prompt": "q1", "response": "a1"},
|
||||
{"prompt": "q2", "response": "a2"},
|
||||
{"prompt": "q3", "response": "a3"},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert response.json["title"] == "Shared Chat"
|
||||
assert len(response.json["queries"]) == 2
|
||||
|
||||
def test_returns_404_not_found(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_returns_404_conversation_deleted(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": conv_id,
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_includes_api_key_when_promptable(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": conv_id,
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": True,
|
||||
"api_key": "shared_api_key",
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["api_key"] == "shared_api_key"
|
||||
|
||||
def test_handles_dbref_conversation_id(self, app):
|
||||
from bson.dbref import DBRef
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": DBRef("conversations", conv_id),
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_conversations.find_one.assert_called_once_with({"_id": conv_id})
|
||||
|
||||
def test_handles_dict_oid_conversation_id(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": {"$id": {"$oid": str(conv_id)}},
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_handles_dict_id_string_conversation_id(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": {"$id": str(conv_id)},
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_handles_dict_underscore_id_conversation_id(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": {"_id": str(conv_id)},
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_handles_string_conversation_id(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": str(conv_id),
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [{"prompt": "q1", "response": "a1"}],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_resolves_attachments_in_shared(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
conv_id = ObjectId()
|
||||
att_id = ObjectId()
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {
|
||||
"uuid": binary_uuid,
|
||||
"conversation_id": conv_id,
|
||||
"first_n_queries": 1,
|
||||
"isPromptable": False,
|
||||
}
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Chat",
|
||||
"queries": [
|
||||
{"prompt": "q1", "response": "a1", "attachments": [str(att_id)]}
|
||||
],
|
||||
}
|
||||
mock_attachments = Mock()
|
||||
mock_attachments.find_one.return_value = {
|
||||
"_id": att_id,
|
||||
"filename": "file.pdf",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.attachments_collection",
|
||||
mock_attachments,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{test_uuid}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(test_uuid))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["queries"][0]["attachments"][0]["fileName"] == "file.pdf"
|
||||
|
||||
def test_handles_general_exception(self, app):
|
||||
from application.api.user.sharing.routes import (
|
||||
GetPubliclySharedConversations,
|
||||
)
|
||||
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.side_effect = Exception("DB error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/shared_conversation/{uuid.uuid4()}"
|
||||
):
|
||||
response = GetPubliclySharedConversations().get(str(uuid.uuid4()))
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareConversationPromptable:
|
||||
|
||||
def test_promptable_with_existing_api_key_and_existing_share(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
test_uuid = uuid.uuid4()
|
||||
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"queries": [{"prompt": "hi"}],
|
||||
}
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {"key": "existing_api_uuid"}
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = {"uuid": binary_uuid}
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=true",
|
||||
method="POST",
|
||||
json={
|
||||
"conversation_id": str(conv_id),
|
||||
"prompt_id": "default",
|
||||
"chunks": "3",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["identifier"] == str(test_uuid)
|
||||
|
||||
def test_promptable_with_existing_api_key_new_share(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"queries": [{"prompt": "hi"}],
|
||||
}
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = {"key": "existing_api_uuid"}
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=true",
|
||||
method="POST",
|
||||
json={
|
||||
"conversation_id": str(conv_id),
|
||||
"source": str(ObjectId()),
|
||||
"retriever": "classic",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 201
|
||||
mock_shared.insert_one.assert_called_once()
|
||||
|
||||
def test_promptable_creates_new_api_key(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
conv_id = ObjectId()
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": conv_id,
|
||||
"name": "Test Chat",
|
||||
"queries": [{"prompt": "hi"}],
|
||||
}
|
||||
mock_agents = Mock()
|
||||
mock_agents.find_one.return_value = None
|
||||
mock_shared = Mock()
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.agents_collection",
|
||||
mock_agents,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share?isPromptable=true",
|
||||
method="POST",
|
||||
json={
|
||||
"conversation_id": str(conv_id),
|
||||
"source": str(ObjectId()),
|
||||
"retriever": "classic",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 201
|
||||
mock_agents.insert_one.assert_called_once()
|
||||
mock_shared.insert_one.assert_called_once()
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Coverage gap tests (lines 201-205)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareConversationExceptionGap:
|
||||
def test_share_conversation_exception_returns_400(self):
|
||||
"""Cover lines 201-205: exception during sharing returns 400."""
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.side_effect = Exception("db error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share",
|
||||
method="POST",
|
||||
json={
|
||||
"conversation_id": str(ObjectId()),
|
||||
"source": str(ObjectId()),
|
||||
"retriever": "classic",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 201-205
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareConversationErrorPath:
|
||||
|
||||
def test_share_conversation_exception_returns_400(self, app):
|
||||
"""Cover lines 201-205: exception during sharing returns 400."""
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.side_effect = Exception("DB error")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share",
|
||||
method="POST",
|
||||
json={"conversation_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for sharing/routes.py
|
||||
# Lines: 201-205: exception in try block (different entry point)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestShareConversationInsertException:
|
||||
"""Cover lines 201-205: exception during insert_one."""
|
||||
|
||||
def test_insert_one_exception_returns_400(self, app):
|
||||
from application.api.user.sharing.routes import ShareConversation
|
||||
|
||||
mock_conversations = Mock()
|
||||
mock_conversations.find_one.return_value = {
|
||||
"_id": ObjectId(),
|
||||
"user": "user1",
|
||||
"queries": [],
|
||||
}
|
||||
mock_shared = Mock()
|
||||
mock_shared.find_one.return_value = None
|
||||
mock_shared.insert_one.side_effect = Exception("Insert failed")
|
||||
|
||||
with patch(
|
||||
"application.api.user.sharing.routes.conversations_collection",
|
||||
mock_conversations,
|
||||
), patch(
|
||||
"application.api.user.sharing.routes.shared_conversations_collections",
|
||||
mock_shared,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/share",
|
||||
method="POST",
|
||||
json={"conversation_id": str(ObjectId())},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = ShareConversation().post()
|
||||
|
||||
assert response.status_code == 400
|
||||
240
tests/api/user/test_tasks.py
Normal file
240
tests/api/user/test_tasks.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from datetime import timedelta
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIngestTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.ingest_worker")
|
||||
def test_calls_ingest_worker(self, mock_worker):
|
||||
from application.api.user.tasks import ingest
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
||||
file_name_map=None,
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.ingest_worker")
|
||||
def test_passes_file_name_map(self, mock_worker):
|
||||
from application.api.user.tasks import ingest
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
name_map = {"a.pdf": "b.pdf"}
|
||||
|
||||
ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
|
||||
file_name_map=name_map)
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
||||
file_name_map=name_map,
|
||||
)
|
||||
|
||||
|
||||
class TestIngestRemoteTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.remote_worker")
|
||||
def test_calls_remote_worker(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_remote
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, {"url": "http://x"}, "job1", "user1", "web"
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestReingestSourceTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.reingest_source_worker")
|
||||
def test_calls_reingest_worker(self, mock_worker):
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = reingest_source_task("source123", "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "source123", "user1")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestScheduleSyncsTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.sync_worker")
|
||||
def test_calls_sync_worker(self, mock_worker):
|
||||
from application.api.user.tasks import schedule_syncs
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = schedule_syncs("daily")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "daily")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestSyncSourceTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.sync")
|
||||
def test_calls_sync(self, mock_sync):
|
||||
from application.api.user.tasks import sync_source
|
||||
|
||||
mock_sync.return_value = {"status": "ok"}
|
||||
|
||||
result = sync_source(
|
||||
{"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
||||
)
|
||||
|
||||
mock_sync.assert_called_once_with(
|
||||
ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestStoreAttachmentTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.attachment_worker")
|
||||
def test_calls_attachment_worker(self, mock_worker):
|
||||
from application.api.user.tasks import store_attachment
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = store_attachment({"file": "info"}, "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, {"file": "info"}, "user1")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestProcessAgentWebhookTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.agent_webhook_worker")
|
||||
def test_calls_agent_webhook_worker(self, mock_worker):
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = process_agent_webhook("agent123", {"event": "test"})
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "agent123", {"event": "test"})
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestIngestConnectorTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.ingest_connector")
|
||||
def test_calls_ingest_connector_defaults(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_connector_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_connector_task("job1", "user1", "gdrive")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY,
|
||||
"job1",
|
||||
"user1",
|
||||
"gdrive",
|
||||
session_token=None,
|
||||
file_ids=None,
|
||||
folder_ids=None,
|
||||
recursive=True,
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.ingest_connector")
|
||||
def test_calls_ingest_connector_custom(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_connector_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_connector_task(
|
||||
"job1",
|
||||
"user1",
|
||||
"sharepoint",
|
||||
session_token="tok",
|
||||
file_ids=["f1"],
|
||||
folder_ids=["d1"],
|
||||
recursive=False,
|
||||
retriever="duckdb",
|
||||
operation_mode="sync",
|
||||
doc_id="doc1",
|
||||
sync_frequency="daily",
|
||||
)
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY,
|
||||
"job1",
|
||||
"user1",
|
||||
"sharepoint",
|
||||
session_token="tok",
|
||||
file_ids=["f1"],
|
||||
folder_ids=["d1"],
|
||||
recursive=False,
|
||||
retriever="duckdb",
|
||||
operation_mode="sync",
|
||||
doc_id="doc1",
|
||||
sync_frequency="daily",
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestSetupPeriodicTasks:
|
||||
@pytest.mark.unit
|
||||
def test_registers_periodic_tasks(self):
|
||||
from application.api.user.tasks import setup_periodic_tasks
|
||||
|
||||
sender = MagicMock()
|
||||
|
||||
setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 3
|
||||
|
||||
calls = sender.add_periodic_task.call_args_list
|
||||
|
||||
# daily
|
||||
assert calls[0][0][0] == timedelta(days=1)
|
||||
# weekly
|
||||
assert calls[1][0][0] == timedelta(weeks=1)
|
||||
# monthly
|
||||
assert calls[2][0][0] == timedelta(days=30)
|
||||
|
||||
|
||||
class TestMcpOauthTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.mcp_oauth")
|
||||
def test_calls_mcp_oauth(self, mock_worker):
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
|
||||
mock_worker.return_value = {"url": "http://auth"}
|
||||
|
||||
result = mcp_oauth_task({"server": "mcp"}, "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
|
||||
assert result == {"url": "http://auth"}
|
||||
|
||||
|
||||
class TestMcpOauthStatusTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.mcp_oauth_status")
|
||||
def test_calls_mcp_oauth_status(self, mock_worker):
|
||||
from application.api.user.tasks import mcp_oauth_status_task
|
||||
|
||||
mock_worker.return_value = {"status": "authorized"}
|
||||
|
||||
result = mcp_oauth_status_task("task123")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "task123")
|
||||
assert result == {"status": "authorized"}
|
||||
1308
tests/api/user/test_tools_mcp.py
Normal file
1308
tests/api/user/test_tools_mcp.py
Normal file
File diff suppressed because it is too large
Load Diff
1948
tests/api/user/test_tools_routes.py
Normal file
1948
tests/api/user/test_tools_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
411
tests/api/user/test_utils.py
Normal file
411
tests/api/user/test_utils.py
Normal file
@@ -0,0 +1,411 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetUserId:
|
||||
|
||||
def test_returns_user_id_from_decoded_token(self, app):
|
||||
from application.api.user.utils import get_user_id
|
||||
|
||||
with app.test_request_context():
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user_123"}
|
||||
assert get_user_id() == "user_123"
|
||||
|
||||
def test_returns_none_when_no_decoded_token(self, app):
|
||||
from application.api.user.utils import get_user_id
|
||||
|
||||
with app.test_request_context():
|
||||
assert get_user_id() is None
|
||||
|
||||
def test_returns_none_when_decoded_token_has_no_sub(self, app):
|
||||
from application.api.user.utils import get_user_id
|
||||
|
||||
with app.test_request_context():
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {}
|
||||
assert get_user_id() is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequireAuth:
|
||||
|
||||
def test_allows_authenticated_request(self, app):
|
||||
from application.api.user.utils import require_auth
|
||||
|
||||
@require_auth
|
||||
def protected():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user_123"}
|
||||
assert protected() == "ok"
|
||||
|
||||
def test_returns_401_when_unauthenticated(self, app):
|
||||
from application.api.user.utils import require_auth
|
||||
|
||||
@require_auth
|
||||
def protected():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
result = protected()
|
||||
assert result.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSuccessResponse:
|
||||
|
||||
def test_default_success_response(self, app):
|
||||
from application.api.user.utils import success_response
|
||||
|
||||
with app.app_context():
|
||||
resp = success_response()
|
||||
assert resp.status_code == 200
|
||||
assert resp.json["success"] is True
|
||||
|
||||
def test_success_response_with_data(self, app):
|
||||
from application.api.user.utils import success_response
|
||||
|
||||
with app.app_context():
|
||||
resp = success_response({"items": [1, 2], "total": 2})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json["success"] is True
|
||||
assert resp.json["items"] == [1, 2]
|
||||
assert resp.json["total"] == 2
|
||||
|
||||
def test_success_response_custom_status(self, app):
|
||||
from application.api.user.utils import success_response
|
||||
|
||||
with app.app_context():
|
||||
resp = success_response({"id": "new"}, 201)
|
||||
assert resp.status_code == 201
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestErrorResponse:
|
||||
|
||||
def test_default_error_response(self, app):
|
||||
from application.api.user.utils import error_response
|
||||
|
||||
with app.app_context():
|
||||
resp = error_response("Something went wrong")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json["success"] is False
|
||||
assert resp.json["message"] == "Something went wrong"
|
||||
|
||||
def test_error_response_custom_status(self, app):
|
||||
from application.api.user.utils import error_response
|
||||
|
||||
with app.app_context():
|
||||
resp = error_response("Not found", 404)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_error_response_extra_kwargs(self, app):
|
||||
from application.api.user.utils import error_response
|
||||
|
||||
with app.app_context():
|
||||
resp = error_response("Bad", 400, errors=["field1", "field2"])
|
||||
assert resp.json["errors"] == ["field1", "field2"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateObjectId:
|
||||
|
||||
def test_valid_object_id(self, app):
|
||||
from application.api.user.utils import validate_object_id
|
||||
|
||||
with app.app_context():
|
||||
oid = ObjectId()
|
||||
result, error = validate_object_id(str(oid))
|
||||
assert result == oid
|
||||
assert error is None
|
||||
|
||||
def test_invalid_object_id(self, app):
|
||||
from application.api.user.utils import validate_object_id
|
||||
|
||||
with app.app_context():
|
||||
result, error = validate_object_id("not-a-valid-id")
|
||||
assert result is None
|
||||
assert error.status_code == 400
|
||||
assert "Invalid" in error.json["message"]
|
||||
|
||||
def test_custom_resource_name(self, app):
|
||||
from application.api.user.utils import validate_object_id
|
||||
|
||||
with app.app_context():
|
||||
_, error = validate_object_id("bad", "Workflow")
|
||||
assert "Workflow" in error.json["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidatePagination:
|
||||
|
||||
def test_default_pagination(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/?limit=10&skip=0"):
|
||||
limit, skip, error = validate_pagination()
|
||||
assert limit == 10
|
||||
assert skip == 0
|
||||
assert error is None
|
||||
|
||||
def test_uses_defaults_when_no_params(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/"):
|
||||
limit, skip, error = validate_pagination()
|
||||
assert limit == 20
|
||||
assert skip == 0
|
||||
assert error is None
|
||||
|
||||
def test_enforces_max_limit(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/?limit=500"):
|
||||
limit, _, _ = validate_pagination(max_limit=100)
|
||||
assert limit == 100
|
||||
|
||||
def test_invalid_limit(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/?limit=-1"):
|
||||
_, _, error = validate_pagination()
|
||||
assert error is not None
|
||||
assert error.status_code == 400
|
||||
|
||||
def test_invalid_skip(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/?skip=-1"):
|
||||
_, _, error = validate_pagination()
|
||||
assert error is not None
|
||||
|
||||
def test_non_numeric_values(self, app):
|
||||
from application.api.user.utils import validate_pagination
|
||||
|
||||
with app.test_request_context("/?limit=abc"):
|
||||
_, _, error = validate_pagination()
|
||||
assert error is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckResourceOwnership:
|
||||
|
||||
def test_returns_resource_when_owned(self, app):
|
||||
from application.api.user.utils import check_resource_ownership
|
||||
|
||||
with app.app_context():
|
||||
collection = Mock()
|
||||
oid = ObjectId()
|
||||
doc = {"_id": oid, "user": "user1", "name": "test"}
|
||||
collection.find_one.return_value = doc
|
||||
|
||||
resource, error = check_resource_ownership(collection, oid, "user1")
|
||||
assert resource == doc
|
||||
assert error is None
|
||||
|
||||
def test_returns_404_when_not_found(self, app):
|
||||
from application.api.user.utils import check_resource_ownership
|
||||
|
||||
with app.app_context():
|
||||
collection = Mock()
|
||||
collection.find_one.return_value = None
|
||||
|
||||
resource, error = check_resource_ownership(
|
||||
collection, ObjectId(), "user1", "Workflow"
|
||||
)
|
||||
assert resource is None
|
||||
assert error.status_code == 404
|
||||
assert "Workflow" in error.json["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeObjectId:
|
||||
|
||||
def test_converts_id_to_string(self):
|
||||
from application.api.user.utils import serialize_object_id
|
||||
|
||||
oid = ObjectId()
|
||||
obj = {"_id": oid, "name": "test"}
|
||||
result = serialize_object_id(obj)
|
||||
assert result["id"] == str(oid)
|
||||
assert "_id" not in result
|
||||
|
||||
def test_custom_field_names(self):
|
||||
from application.api.user.utils import serialize_object_id
|
||||
|
||||
oid = ObjectId()
|
||||
obj = {"custom_id": oid}
|
||||
result = serialize_object_id(obj, id_field="custom_id", new_field="uid")
|
||||
assert result["uid"] == str(oid)
|
||||
assert "custom_id" not in result
|
||||
|
||||
def test_no_id_field_present(self):
|
||||
from application.api.user.utils import serialize_object_id
|
||||
|
||||
obj = {"name": "test"}
|
||||
result = serialize_object_id(obj)
|
||||
assert "id" not in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSerializeList:
|
||||
|
||||
def test_applies_serializer_to_all_items(self):
|
||||
from application.api.user.utils import serialize_list
|
||||
|
||||
items = [{"_id": ObjectId()}, {"_id": ObjectId()}]
|
||||
|
||||
def serializer(item):
|
||||
return {"id": str(item["_id"])}
|
||||
|
||||
result = serialize_list(items, serializer)
|
||||
assert len(result) == 2
|
||||
assert all("id" in r for r in result)
|
||||
|
||||
def test_empty_list(self):
|
||||
from application.api.user.utils import serialize_list
|
||||
|
||||
assert serialize_list([], lambda x: x) == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequireFields:
|
||||
|
||||
def test_allows_valid_request(self, app):
|
||||
from application.api.user.utils import require_fields
|
||||
|
||||
@require_fields(["name", "email"])
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(
|
||||
"/", method="POST", json={"name": "Alice", "email": "a@b.com"}
|
||||
):
|
||||
assert handler() == "ok"
|
||||
|
||||
def test_rejects_missing_fields(self, app):
|
||||
from application.api.user.utils import require_fields
|
||||
|
||||
@require_fields(["name", "email"])
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context("/", method="POST", json={"name": "Alice"}):
|
||||
result = handler()
|
||||
assert result.status_code == 400
|
||||
assert "email" in result.json["message"]
|
||||
|
||||
def test_rejects_empty_body(self, app):
|
||||
from application.api.user.utils import require_fields
|
||||
|
||||
@require_fields(["name"])
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context(
|
||||
"/", method="POST", json={}
|
||||
):
|
||||
result = handler()
|
||||
assert result.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSafeDbOperation:
|
||||
|
||||
def test_returns_result_on_success(self, app):
|
||||
from application.api.user.utils import safe_db_operation
|
||||
|
||||
with app.app_context():
|
||||
result, error = safe_db_operation(lambda: {"inserted": True})
|
||||
assert result == {"inserted": True}
|
||||
assert error is None
|
||||
|
||||
def test_returns_error_on_exception(self, app):
|
||||
from application.api.user.utils import safe_db_operation
|
||||
|
||||
with app.app_context():
|
||||
result, error = safe_db_operation(
|
||||
lambda: (_ for _ in ()).throw(RuntimeError("db error")),
|
||||
"Operation failed",
|
||||
)
|
||||
assert result is None
|
||||
assert error.status_code == 400
|
||||
assert error.json["message"] == "Operation failed"
|
||||
|
||||
def test_hides_exception_details(self, app):
|
||||
from application.api.user.utils import safe_db_operation
|
||||
|
||||
with app.app_context():
|
||||
_, error = safe_db_operation(
|
||||
lambda: (_ for _ in ()).throw(RuntimeError("secret credentials")),
|
||||
"Failed",
|
||||
)
|
||||
assert "credentials" not in error.json["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateEnum:
|
||||
|
||||
def test_valid_value(self, app):
|
||||
from application.api.user.utils import validate_enum
|
||||
|
||||
with app.app_context():
|
||||
assert validate_enum("draft", ["draft", "published"], "status") is None
|
||||
|
||||
def test_invalid_value(self, app):
|
||||
from application.api.user.utils import validate_enum
|
||||
|
||||
with app.app_context():
|
||||
error = validate_enum("unknown", ["draft", "published"], "status")
|
||||
assert error.status_code == 400
|
||||
assert "status" in error.json["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractSortParams:
|
||||
|
||||
def test_defaults(self, app):
|
||||
from application.api.user.utils import extract_sort_params
|
||||
|
||||
with app.test_request_context("/"):
|
||||
field, order = extract_sort_params()
|
||||
assert field == "created_at"
|
||||
assert order == -1
|
||||
|
||||
def test_custom_params(self, app):
|
||||
from application.api.user.utils import extract_sort_params
|
||||
|
||||
with app.test_request_context("/?sort=name&order=asc"):
|
||||
field, order = extract_sort_params()
|
||||
assert field == "name"
|
||||
assert order == 1
|
||||
|
||||
def test_enforces_allowed_fields(self, app):
|
||||
from application.api.user.utils import extract_sort_params
|
||||
|
||||
with app.test_request_context("/?sort=forbidden_field"):
|
||||
field, _ = extract_sort_params(allowed_fields=["name", "date"])
|
||||
assert field == "created_at"
|
||||
|
||||
def test_desc_order(self, app):
|
||||
from application.api.user.utils import extract_sort_params
|
||||
|
||||
with app.test_request_context("/?order=desc"):
|
||||
_, order = extract_sort_params()
|
||||
assert order == -1
|
||||
225
tests/api/user/test_webhooks.py
Normal file
225
tests/api/user/test_webhooks.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentWebhook:
|
||||
|
||||
def test_returns_existing_webhook_url(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhook
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
"incoming_webhook_token": "existing_token",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.agents_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.agents.webhooks.settings",
|
||||
Mock(API_URL="https://api.example.com"),
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agent_webhook?id={agent_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentWebhook().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
assert "existing_token" in response.json["webhook_url"]
|
||||
mock_collection.update_one.assert_not_called()
|
||||
|
||||
def test_generates_new_webhook_token(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhook
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
"incoming_webhook_token": None,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.agents_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.agents.webhooks.settings",
|
||||
Mock(API_URL="https://api.example.com"),
|
||||
), patch(
|
||||
"application.api.user.agents.webhooks.secrets.token_urlsafe",
|
||||
return_value="new_generated_token",
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agent_webhook?id={agent_id}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentWebhook().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "new_generated_token" in response.json["webhook_url"]
|
||||
mock_collection.update_one.assert_called_once()
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhook
|
||||
|
||||
with app.test_request_context(f"/api/agent_webhook?id={ObjectId()}"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
response = AgentWebhook().get()
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhook
|
||||
|
||||
with app.test_request_context("/api/agent_webhook"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentWebhook().get()
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_404_agent_not_found(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhook
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_collection.find_one.return_value = None
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.agents_collection",
|
||||
mock_collection,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/agent_webhook?id={ObjectId()}"
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AgentWebhook().get()
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentWebhookListenerPost:
|
||||
|
||||
def test_enqueues_task_on_valid_post(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhookListener
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.id = "task_abc"
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.process_agent_webhook"
|
||||
) as mock_process:
|
||||
mock_process.delay.return_value = mock_task
|
||||
with app.test_request_context(
|
||||
"/api/webhooks/agents/tok",
|
||||
method="POST",
|
||||
json={"event": "new_message"},
|
||||
):
|
||||
listener = AgentWebhookListener()
|
||||
response = listener._enqueue_webhook_task(
|
||||
"agent123", {"event": "new_message"}, "POST"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["task_id"] == "task_abc"
|
||||
mock_process.delay.assert_called_once_with(
|
||||
agent_id="agent123", payload={"event": "new_message"}
|
||||
)
|
||||
|
||||
def test_returns_400_on_missing_json(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhookListener
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/webhooks/agents/tok",
|
||||
method="POST",
|
||||
json=None,
|
||||
content_type="application/json",
|
||||
data="",
|
||||
):
|
||||
from flask import request as flask_request
|
||||
|
||||
# Force get_json to return None (simulating empty/missing body)
|
||||
with patch.object(
|
||||
flask_request, "get_json", return_value=None
|
||||
):
|
||||
listener = AgentWebhookListener()
|
||||
response = listener.post(
|
||||
webhook_token="tok",
|
||||
agent={"_id": ObjectId()},
|
||||
agent_id_str="agent123",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_handles_enqueue_error(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhookListener
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.process_agent_webhook"
|
||||
) as mock_process:
|
||||
mock_process.delay.side_effect = Exception("Queue down")
|
||||
with app.test_request_context(
|
||||
"/api/webhooks/agents/tok",
|
||||
method="POST",
|
||||
json={"event": "test"},
|
||||
):
|
||||
listener = AgentWebhookListener()
|
||||
response = listener._enqueue_webhook_task(
|
||||
"agent123", {"event": "test"}, "POST"
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentWebhookListenerGet:
|
||||
|
||||
def test_uses_query_params_as_payload(self, app):
|
||||
from application.api.user.agents.webhooks import AgentWebhookListener
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.id = "task_xyz"
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.webhooks.process_agent_webhook"
|
||||
) as mock_process:
|
||||
mock_process.delay.return_value = mock_task
|
||||
with app.test_request_context(
|
||||
"/api/webhooks/agents/tok?event=ping&source=test",
|
||||
method="GET",
|
||||
):
|
||||
listener = AgentWebhookListener()
|
||||
response = listener.get(
|
||||
webhook_token="tok",
|
||||
agent={"_id": ObjectId()},
|
||||
agent_id_str="agent456",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
call_kwargs = mock_process.delay.call_args[1]
|
||||
assert call_kwargs["payload"]["event"] == "ping"
|
||||
assert call_kwargs["payload"]["source"] == "test"
|
||||
1466
tests/api/user/test_workflows.py
Normal file
1466
tests/api/user/test_workflows.py
Normal file
File diff suppressed because it is too large
Load Diff
0
tests/core/__init__.py
Normal file
0
tests/core/__init__.py
Normal file
803
tests/core/test_model_settings.py
Normal file
803
tests/core/test_model_settings.py
Normal file
@@ -0,0 +1,803 @@
|
||||
"""Tests for application/core/model_settings.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
class TestModelProvider:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_providers_exist(self):
|
||||
assert ModelProvider.OPENAI == "openai"
|
||||
assert ModelProvider.ANTHROPIC == "anthropic"
|
||||
assert ModelProvider.GOOGLE == "google"
|
||||
assert ModelProvider.GROQ == "groq"
|
||||
assert ModelProvider.DOCSGPT == "docsgpt"
|
||||
assert ModelProvider.HUGGINGFACE == "huggingface"
|
||||
assert ModelProvider.NOVITA == "novita"
|
||||
assert ModelProvider.OPENROUTER == "openrouter"
|
||||
assert ModelProvider.SAGEMAKER == "sagemaker"
|
||||
assert ModelProvider.PREMAI == "premai"
|
||||
assert ModelProvider.LLAMA_CPP == "llama.cpp"
|
||||
assert ModelProvider.AZURE_OPENAI == "azure_openai"
|
||||
|
||||
|
||||
class TestModelCapabilities:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
caps = ModelCapabilities()
|
||||
assert caps.supports_tools is False
|
||||
assert caps.supports_structured_output is False
|
||||
assert caps.supports_streaming is True
|
||||
assert caps.supported_attachment_types == []
|
||||
assert caps.context_window == 128000
|
||||
assert caps.input_cost_per_token is None
|
||||
assert caps.output_cost_per_token is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
context_window=32000,
|
||||
input_cost_per_token=0.001,
|
||||
)
|
||||
assert caps.supports_tools is True
|
||||
assert caps.context_window == 32000
|
||||
|
||||
|
||||
class TestAvailableModel:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_basic(self):
|
||||
model = AvailableModel(
|
||||
id="gpt-4",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4",
|
||||
description="OpenAI GPT-4",
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["id"] == "gpt-4"
|
||||
assert d["provider"] == "openai"
|
||||
assert d["display_name"] == "GPT-4"
|
||||
assert d["enabled"] is True
|
||||
assert "base_url" not in d
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_base_url(self):
|
||||
model = AvailableModel(
|
||||
id="local-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Local",
|
||||
base_url="http://localhost:11434",
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["base_url"] == "http://localhost:11434"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_includes_capabilities(self):
|
||||
caps = ModelCapabilities(supports_tools=True, context_window=64000)
|
||||
model = AvailableModel(
|
||||
id="m1",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="M1",
|
||||
capabilities=caps,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["supports_tools"] is True
|
||||
assert d["context_window"] == 64000
|
||||
|
||||
|
||||
class TestModelRegistry:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton(self):
|
||||
"""Reset singleton between tests."""
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
yield
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_singleton(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
r1 = ModelRegistry()
|
||||
r2 = ModelRegistry()
|
||||
assert r1 is r2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_instance(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
r = ModelRegistry.get_instance()
|
||||
assert isinstance(r, ModelRegistry)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_model(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
|
||||
reg.models["test"] = model
|
||||
assert reg.get_model("test") is model
|
||||
assert reg.get_model("nonexistent") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_all_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
|
||||
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
|
||||
assert len(reg.get_all_models()) == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_enabled_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
|
||||
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
|
||||
enabled = reg.get_enabled_models()
|
||||
assert len(enabled) == 1
|
||||
assert enabled[0].id == "m1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_model_exists(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
|
||||
assert reg.model_exists("m1") is True
|
||||
assert reg.model_exists("m2") is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_model_names(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
|
||||
assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
|
||||
assert reg._parse_model_names("single") == ["single"]
|
||||
assert reg._parse_model_names("") == []
|
||||
assert reg._parse_model_names(None) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_docsgpt_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
reg._add_docsgpt_models(mock_settings)
|
||||
assert "docsgpt-local" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_huggingface_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
reg._add_huggingface_models(mock_settings)
|
||||
assert "huggingface-local" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_with_openai_key(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = ""
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_custom_openai_base_url(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = "llama3,gemma"
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert "llama3" in reg.models
|
||||
assert "gemma" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_selection_from_llm_name(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
|
||||
reg.default_model_id = "gpt-4"
|
||||
assert reg.default_model_id == "gpt-4"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = "google-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = "groq-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "or-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = "novita-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_azure_openai_models_specific(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
mock_settings.LLM_NAME = "nonexistent-model"
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
# Falls through to adding all azure models
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_fallback_to_first(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = None
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
# Should have at least docsgpt-local
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_from_provider_fallback(self):
|
||||
"""When LLM_NAME is not set but LLM_PROVIDER and API_KEY are,
|
||||
default should be first model of that provider."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = None
|
||||
mock_settings.API_KEY = "sk-test"
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "groq"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "novita"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_disabled_model(self):
|
||||
model = AvailableModel(
|
||||
id="disabled",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Disabled",
|
||||
enabled=False,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["enabled"] is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_attachment_types(self):
|
||||
caps = ModelCapabilities(
|
||||
supported_attachment_types=["image/png", "application/pdf"],
|
||||
)
|
||||
model = AvailableModel(
|
||||
id="vision",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Vision",
|
||||
capabilities=caps,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["supported_attachment_types"] == ["image/png", "application/pdf"]
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Coverage for _add_* methods with matching LLM_NAME
|
||||
# Lines: 100, 105, 147, 171, 179, 186, 199-201, 204, 210, 213,
|
||||
# 218, 229, 233, 241, 250
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_azure_openai_models_with_matching_name(self):
|
||||
"""Cover line 186: azure model matching LLM_NAME returns early."""
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
if AZURE_OPENAI_MODELS:
|
||||
mock_settings.LLM_NAME = AZURE_OPENAI_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
# Should have added at least one model
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_no_key_no_provider_fallthrough(self):
|
||||
"""Cover lines 199-204: no key, provider set but name not found -> add all."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
mock_settings.LLM_NAME = "nonexistent-model"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
# Falls through to add all anthropic models
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_no_key_matching_name(self):
|
||||
"""Cover lines 213-218: Google fallback with matching name."""
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
if GOOGLE_MODELS:
|
||||
mock_settings.LLM_NAME = GOOGLE_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_no_key_matching_name(self):
|
||||
"""Cover lines 229-233: Groq fallback with matching name."""
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "groq"
|
||||
if GROQ_MODELS:
|
||||
mock_settings.LLM_NAME = GROQ_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_no_key_matching_name(self):
|
||||
"""Cover lines 241-250: OpenRouter fallback with matching name."""
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
if OPENROUTER_MODELS:
|
||||
mock_settings.LLM_NAME = OPENROUTER_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_no_key_matching_name(self):
|
||||
"""Cover novita fallback with matching name."""
|
||||
from application.core.model_configs import NOVITA_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "novita"
|
||||
if NOVITA_MODELS:
|
||||
mock_settings.LLM_NAME = NOVITA_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_default_from_llm_name_exact_match(self):
|
||||
"""Cover line 136/147: exact LLM_NAME match for default model."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
if OPENAI_MODELS:
|
||||
mock_settings.LLM_NAME = OPENAI_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "gpt-4o"
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openai_models_local_endpoint_no_name(self):
|
||||
"""Cover line 171: local endpoint without LLM_NAME adds nothing."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.LLM_NAME = None
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openai_standard_no_api_key(self):
|
||||
"""Cover line 179: standard OpenAI without API key adds nothing."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = None
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 100, 105, 147, 171, 179, 186, 250
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestModelRegistryAdditionalCoverage:
|
||||
|
||||
def test_add_azure_openai_models_specific_name(self):
|
||||
"""Cover line 186: azure_openai with specific LLM_NAME match."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
mock_settings.LLM_NAME = "gpt-4o"
|
||||
|
||||
# Create a fake model that matches
|
||||
fake_model = MagicMock()
|
||||
fake_model.id = "gpt-4o"
|
||||
with patch(
|
||||
"application.core.model_configs.AZURE_OPENAI_MODELS",
|
||||
[fake_model],
|
||||
):
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
assert "gpt-4o" in reg.models
|
||||
|
||||
def test_add_anthropic_models_with_api_key(self):
|
||||
"""Cover line 100: anthropic with API key."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-test"
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_add_google_models_with_api_key(self):
|
||||
"""Cover line 105: google with API key."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = "test-key"
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_default_model_from_provider(self):
|
||||
"""Cover line 147: default model selected from provider."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
|
||||
fake_model = MagicMock()
|
||||
fake_model.provider = MagicMock()
|
||||
fake_model.provider.value = "openai"
|
||||
reg.models["gpt-4o"] = fake_model
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_NAME = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.API_KEY = "key"
|
||||
|
||||
# Simulate the default selection logic
|
||||
if not reg.default_model_id:
|
||||
for model_id, model in reg.models.items():
|
||||
if model.provider.value == mock_settings.LLM_PROVIDER:
|
||||
reg.default_model_id = model_id
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "gpt-4o"
|
||||
|
||||
def test_add_openai_local_endpoint_with_llm_name(self):
|
||||
"""Cover line 171: local endpoint registers custom models from LLM_NAME."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.LLM_NAME = "llama3,phi3"
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert "llama3" in reg.models
|
||||
assert "phi3" in reg.models
|
||||
|
||||
def test_add_openai_standard_with_api_key(self):
|
||||
"""Cover line 179: standard OpenAI with API key adds models."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-real-key"
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_add_openrouter_models(self):
|
||||
"""Cover line 250: openrouter models added."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "or-key"
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for model_settings.py
|
||||
# Lines: 135-136 (backward compat LLM_NAME), 138-143 (provider fallback),
|
||||
# 145-146 (first model as default)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Imports already at the top of the file; no additional imports needed
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionBackwardCompat:
|
||||
"""Cover lines 135-136: backward compat exact match on LLM_NAME."""
|
||||
|
||||
def test_llm_name_exact_match_as_default(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
# Add a model with composite ID
|
||||
model = AvailableModel(
|
||||
id="my-composite-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Composite",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
)
|
||||
reg.models["my-composite-model"] = model
|
||||
|
||||
# Simulate _parse_model_names returning something different
|
||||
# so that the first for-loop doesn't match
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_NAME = "my-composite-model"
|
||||
mock_settings.LLM_PROVIDER = None
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
# Call the logic directly
|
||||
model_names = reg._parse_model_names(mock_settings.LLM_NAME)
|
||||
for mn in model_names:
|
||||
if mn in reg.models:
|
||||
reg.default_model_id = mn
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "my-composite-model"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionByProvider:
|
||||
"""Cover lines 138-143: default model by provider when LLM_NAME doesn't match."""
|
||||
|
||||
def test_default_by_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
model = AvailableModel(
|
||||
id="gpt-4",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
)
|
||||
reg.models["gpt-4"] = model
|
||||
|
||||
# Simulate: LLM_NAME doesn't exist/match, but LLM_PROVIDER + API_KEY set
|
||||
if not reg.default_model_id:
|
||||
for model_id, m in reg.models.items():
|
||||
if m.provider.value == "openai":
|
||||
reg.default_model_id = model_id
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionFirstModel:
|
||||
"""Cover lines 145-146: first model as default when nothing else matches."""
|
||||
|
||||
def test_first_model_as_default(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
model = AvailableModel(
|
||||
id="fallback-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Fallback",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
)
|
||||
reg.models["fallback-model"] = model
|
||||
|
||||
if not reg.default_model_id and reg.models:
|
||||
reg.default_model_id = next(iter(reg.models.keys()))
|
||||
|
||||
assert reg.default_model_id == "fallback-model"
|
||||
382
tests/core/test_model_utils.py
Normal file
382
tests/core/test_model_utils.py
Normal file
@@ -0,0 +1,382 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
"""Reset ModelRegistry singleton between tests."""
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
yield
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
|
||||
|
||||
def _make_model(
|
||||
model_id="test-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Test Model",
|
||||
context_window=128000,
|
||||
supports_tools=True,
|
||||
supports_structured_output=False,
|
||||
supported_attachment_types=None,
|
||||
enabled=True,
|
||||
base_url=None,
|
||||
):
|
||||
return AvailableModel(
|
||||
id=model_id,
|
||||
provider=provider,
|
||||
display_name=display_name,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=supports_tools,
|
||||
supports_structured_output=supports_structured_output,
|
||||
supported_attachment_types=supported_attachment_types or [],
|
||||
context_window=context_window,
|
||||
),
|
||||
enabled=enabled,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
# ── get_api_key_for_provider ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetApiKeyForProvider:
|
||||
"""settings is lazily imported inside the function body, so we patch
|
||||
at application.core.settings.settings (the actual module attribute)."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_openai_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.OPENAI_API_KEY = "sk-openai"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("openai") == "sk-openai"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_anthropic_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-anthropic"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("anthropic") == "sk-anthropic"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_google_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.GOOGLE_API_KEY = "sk-google"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("google") == "sk-google"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_groq_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.GROQ_API_KEY = "sk-groq"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("groq") == "sk-groq"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_openrouter_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "sk-or"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("openrouter") == "sk-or"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_novita_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.NOVITA_API_KEY = "sk-novita"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("novita") == "sk-novita"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_huggingface_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.HUGGINGFACE_API_KEY = "hf-key"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("huggingface") == "hf-key"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docsgpt_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("docsgpt") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_llama_cpp_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("llama.cpp") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_provider_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("unknown_provider") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_azure_openai_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-azure"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("azure_openai") == "sk-azure"
|
||||
|
||||
|
||||
# ── get_all_available_models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetAllAvailableModels:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_returns_enabled_models_as_dict(self, mock_get_instance):
|
||||
model_a = _make_model("model-a", display_name="Model A")
|
||||
model_b = _make_model("model-b", display_name="Model B")
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_enabled_models.return_value = [model_a, model_b]
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_all_available_models
|
||||
|
||||
result = get_all_available_models()
|
||||
|
||||
assert "model-a" in result
|
||||
assert "model-b" in result
|
||||
assert result["model-a"]["display_name"] == "Model A"
|
||||
assert result["model-b"]["display_name"] == "Model B"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_empty_registry(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_enabled_models.return_value = []
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_all_available_models
|
||||
|
||||
assert get_all_available_models() == {}
|
||||
|
||||
|
||||
# ── validate_model_id ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestValidateModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_exists(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.model_exists.return_value = True
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import validate_model_id
|
||||
|
||||
assert validate_model_id("gpt-4") is True
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_not_exists(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.model_exists.return_value = False
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import validate_model_id
|
||||
|
||||
assert validate_model_id("nonexistent") is False
|
||||
|
||||
|
||||
# ── get_model_capabilities ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetModelCapabilities:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model(
|
||||
"gpt-4",
|
||||
context_window=8192,
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=["image/png"],
|
||||
)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
|
||||
caps = get_model_capabilities("gpt-4")
|
||||
|
||||
assert caps is not None
|
||||
assert caps["supported_attachment_types"] == ["image/png"]
|
||||
assert caps["supports_tools"] is True
|
||||
assert caps["supports_structured_output"] is True
|
||||
assert caps["context_window"] == 8192
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
|
||||
assert get_model_capabilities("nonexistent") is None
|
||||
|
||||
|
||||
# ── get_default_model_id ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetDefaultModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_returns_default(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.default_model_id = "gpt-4"
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_default_model_id
|
||||
|
||||
assert get_default_model_id() == "gpt-4"
|
||||
|
||||
|
||||
# ── get_provider_from_model_id ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetProviderFromModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", provider=ModelProvider.OPENAI)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_provider_from_model_id
|
||||
|
||||
assert get_provider_from_model_id("gpt-4") == "openai"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_provider_from_model_id
|
||||
|
||||
assert get_provider_from_model_id("nonexistent") is None
|
||||
|
||||
|
||||
# ── get_token_limit ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetTokenLimit:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", context_window=8192)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
assert get_token_limit("gpt-4") == 8192
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found_returns_default(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.DEFAULT_LLM_TOKEN_LIMIT = 128000
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
assert get_token_limit("nonexistent") == 128000
|
||||
|
||||
|
||||
# ── get_base_url_for_model ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetBaseUrlForModel:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_with_base_url(self, mock_get_instance):
|
||||
model = _make_model("custom-model", base_url="http://localhost:8080")
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("custom-model") == "http://localhost:8080"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_without_base_url(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", base_url=None)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("gpt-4") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("nonexistent") is None
|
||||
@@ -195,3 +195,67 @@ class TestValidateUrlSafe:
|
||||
is_valid, url, error = validate_url_safe("http://192.168.1.1")
|
||||
assert is_valid is False
|
||||
assert "private" in error.lower() or "internal" in error.lower()
|
||||
|
||||
def test_adds_scheme_when_missing(self):
|
||||
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
|
||||
mock_resolve.return_value = "93.184.216.34"
|
||||
is_valid, url, error = validate_url_safe("example.com")
|
||||
assert is_valid is True
|
||||
assert url == "http://example.com"
|
||||
|
||||
|
||||
class TestIsPrivateIPExtended:
|
||||
"""Additional edge cases for IP classification."""
|
||||
|
||||
def test_multicast_ip(self):
|
||||
assert is_private_ip("224.0.0.1") is True
|
||||
|
||||
def test_unspecified_ip(self):
|
||||
assert is_private_ip("0.0.0.0") is True
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert is_private_ip("::1") is True
|
||||
|
||||
def test_ipv6_private(self):
|
||||
assert is_private_ip("fc00::1") is True
|
||||
|
||||
def test_ipv6_public(self):
|
||||
assert is_private_ip("2607:f8b0:4004:800::200e") is False
|
||||
|
||||
def test_reserved_ip(self):
|
||||
# 240.0.0.0/4 is reserved (future use), Python's ipaddress marks it as such
|
||||
assert is_private_ip("240.0.0.1") is True
|
||||
|
||||
|
||||
class TestValidateUrlExtended:
|
||||
"""Additional URL validation tests."""
|
||||
|
||||
def test_blocks_metadata_hostname(self):
|
||||
with pytest.raises(SSRFError):
|
||||
validate_url("http://metadata")
|
||||
|
||||
def test_allows_localhost_with_flag(self):
|
||||
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
|
||||
mock_resolve.return_value = "192.168.1.1"
|
||||
result = validate_url(
|
||||
"http://internal.local", allow_localhost=True
|
||||
)
|
||||
assert result == "http://internal.local"
|
||||
|
||||
def test_blocks_aws_ecs_metadata_ip(self):
|
||||
with pytest.raises(SSRFError, match="metadata"):
|
||||
validate_url("http://169.254.170.2")
|
||||
|
||||
def test_blocks_aws_ipv6_metadata(self):
|
||||
with pytest.raises(SSRFError, match="metadata"):
|
||||
validate_url("http://[fd00:ec2::254]")
|
||||
|
||||
def test_blocks_hostname_resolving_to_loopback(self):
|
||||
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
|
||||
mock_resolve.return_value = "127.0.0.1"
|
||||
with pytest.raises(SSRFError):
|
||||
validate_url("http://sneaky.example.com")
|
||||
|
||||
def test_allows_localhost_ip_with_flag(self):
|
||||
result = validate_url("http://10.0.0.1", allow_localhost=True)
|
||||
assert result == "http://10.0.0.1"
|
||||
|
||||
577
tests/integration/test_workflows.py
Normal file
577
tests/integration/test_workflows.py
Normal file
@@ -0,0 +1,577 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT workflow management endpoints.
|
||||
|
||||
Uses Flask test client with real MongoDB (must be running).
|
||||
|
||||
Endpoints tested:
|
||||
- /api/workflows (POST) - Create workflow
|
||||
- /api/workflows/<id> (GET) - Get workflow
|
||||
- /api/workflows/<id> (PUT) - Update workflow
|
||||
- /api/workflows/<id> (DELETE) - Delete workflow
|
||||
|
||||
Run:
|
||||
pytest tests/integration/test_workflows.py -v
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
"""Create the real Flask app (connects to real MongoDB)."""
|
||||
from application.app import app as flask_app
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(app):
|
||||
"""Flask test client.
|
||||
|
||||
When AUTH_TYPE is set to simple_jwt/session_jwt a Bearer token is
|
||||
injected; otherwise the backend already returns {"sub": "local"}
|
||||
for every request so no token is needed.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
c = app.test_client()
|
||||
if settings.AUTH_TYPE in ("simple_jwt", "session_jwt"):
|
||||
secret = settings.JWT_SECRET_KEY
|
||||
if not secret:
|
||||
pytest.skip("JWT_SECRET_KEY not configured")
|
||||
payload = {"sub": f"test_workflow_integration_{int(time.time())}"}
|
||||
token = jwt.encode(payload, secret, algorithm="HS256")
|
||||
c.environ_base["HTTP_AUTHORIZATION"] = f"Bearer {token}"
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def created_ids():
|
||||
"""Accumulator for workflow IDs to clean up after all tests."""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def cleanup(client, created_ids):
|
||||
"""Delete all test-created workflows after the module finishes."""
|
||||
yield
|
||||
for wf_id in created_ids:
|
||||
try:
|
||||
client.delete(f"/api/workflows/{wf_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Payload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def simple_workflow(suffix=""):
|
||||
"""Start -> End."""
|
||||
return {
|
||||
"name": f"Simple WF {int(time.time())}{suffix}",
|
||||
"description": "integration test",
|
||||
"nodes": [
|
||||
{"id": "start_1", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 0}, "data": {}},
|
||||
{"id": "end_1", "type": "end", "title": "End",
|
||||
"position": {"x": 400, "y": 0}, "data": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "edge_1", "source": "start_1", "target": "end_1"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def linear_workflow(suffix=""):
|
||||
"""Start -> Agent -> End."""
|
||||
return {
|
||||
"name": f"Linear WF {int(time.time())}{suffix}",
|
||||
"description": "integration test",
|
||||
"nodes": [
|
||||
{"id": "start_1", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 0}, "data": {}},
|
||||
{"id": "agent_1", "type": "agent", "title": "Agent",
|
||||
"position": {"x": 200, "y": 0}, "data": {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "You are helpful.",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
}},
|
||||
{"id": "end_1", "type": "end", "title": "End",
|
||||
"position": {"x": 400, "y": 0}, "data": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "edge_1", "source": "start_1", "target": "agent_1"},
|
||||
{"id": "edge_2", "source": "agent_1", "target": "end_1"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def multi_input_end_workflow(suffix=""):
|
||||
"""Condition branches into two agents, both converging on one end node.
|
||||
|
||||
Graph:
|
||||
start -> condition --(case_1)--> agent_a --\
|
||||
--(else)----> agent_b ---+--> end
|
||||
"""
|
||||
return {
|
||||
"name": f"Multi-Input End {int(time.time())}{suffix}",
|
||||
"description": "end node with multiple inputs",
|
||||
"nodes": [
|
||||
{"id": "start_1", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 100}, "data": {}},
|
||||
{"id": "cond_1", "type": "condition", "title": "Branch",
|
||||
"position": {"x": 200, "y": 100}, "data": {
|
||||
"mode": "simple",
|
||||
"cases": [
|
||||
{"name": "Case 1", "expression": "true",
|
||||
"sourceHandle": "case_1"},
|
||||
],
|
||||
}},
|
||||
{"id": "agent_a", "type": "agent", "title": "Agent A",
|
||||
"position": {"x": 400, "y": 0}, "data": {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "Branch A",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
}},
|
||||
{"id": "agent_b", "type": "agent", "title": "Agent B",
|
||||
"position": {"x": 400, "y": 200}, "data": {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "Branch B",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
}},
|
||||
{"id": "end_1", "type": "end", "title": "End",
|
||||
"position": {"x": 600, "y": 100}, "data": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start_1", "target": "cond_1"},
|
||||
{"id": "e2", "source": "cond_1", "target": "agent_a",
|
||||
"sourceHandle": "case_1"},
|
||||
{"id": "e3", "source": "cond_1", "target": "agent_b",
|
||||
"sourceHandle": "else"},
|
||||
# Both agents feed into the SAME end node
|
||||
{"id": "e4", "source": "agent_a", "target": "end_1"},
|
||||
{"id": "e5", "source": "agent_b", "target": "end_1"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_id(resp):
|
||||
"""Pull workflow id from create/update response."""
|
||||
body = resp.get_json()
|
||||
data = body.get("data") or body
|
||||
return data.get("id")
|
||||
|
||||
|
||||
def _get_graph(client, wf_id):
|
||||
"""Fetch workflow and return (nodes, edges)."""
|
||||
resp = client.get(f"/api/workflows/{wf_id}")
|
||||
assert resp.status_code == 200, resp.get_data(as_text=True)
|
||||
body = resp.get_json()
|
||||
data = body.get("data") or body
|
||||
return data.get("nodes", []), data.get("edges", [])
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# CRUD tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWorkflowCRUD:
|
||||
|
||||
def test_create_simple_workflow(self, client, created_ids):
|
||||
resp = client.post("/api/workflows", json=simple_workflow())
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
def test_create_linear_workflow(self, client, created_ids):
|
||||
resp = client.post("/api/workflows", json=linear_workflow())
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
def test_get_workflow_returns_nodes_and_edges(self, client, created_ids):
|
||||
resp = client.post("/api/workflows", json=simple_workflow(" get"))
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
assert len(nodes) == 2
|
||||
assert len(edges) == 1
|
||||
|
||||
def test_update_workflow(self, client, created_ids):
|
||||
resp = client.post("/api/workflows", json=simple_workflow(" upd"))
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
update_resp = client.put(
|
||||
f"/api/workflows/{wf_id}", json=linear_workflow(" updated")
|
||||
)
|
||||
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
assert len(nodes) == 3 # start, agent, end
|
||||
assert len(edges) == 2
|
||||
|
||||
def test_delete_workflow(self, client):
|
||||
resp = client.post("/api/workflows", json=simple_workflow(" del"))
|
||||
wf_id = _extract_id(resp)
|
||||
|
||||
del_resp = client.delete(f"/api/workflows/{wf_id}")
|
||||
assert del_resp.status_code == 200
|
||||
|
||||
get_resp = client.get(f"/api/workflows/{wf_id}")
|
||||
assert get_resp.status_code in (400, 404)
|
||||
|
||||
def test_reject_workflow_without_end_node(self, client):
|
||||
payload = {
|
||||
"name": "No End",
|
||||
"nodes": [
|
||||
{"id": "s", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 0}, "data": {}},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code == 400, resp.get_data(as_text=True)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Multi-input end node tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMultiInputEndNode:
|
||||
"""Verify that an end node can receive edges from multiple source nodes."""
|
||||
|
||||
def test_create_multi_input_end_workflow_accepted(self, client, created_ids):
|
||||
"""Backend must accept a workflow where two edges target the same end node."""
|
||||
resp = client.post("/api/workflows", json=multi_input_end_workflow())
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
def test_multi_input_end_all_edges_persisted(self, client, created_ids):
|
||||
"""After round-trip, both edges into the end node must still be present."""
|
||||
resp = client.post(
|
||||
"/api/workflows", json=multi_input_end_workflow(" persist")
|
||||
)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
|
||||
# Locate end node
|
||||
end_ids = {n["id"] for n in nodes if n["type"] == "end"}
|
||||
assert end_ids, "no end node in response"
|
||||
|
||||
# Count edges targeting any end node
|
||||
edges_to_end = [e for e in edges if e["target"] in end_ids]
|
||||
assert len(edges_to_end) >= 2, (
|
||||
f"Expected >=2 edges to end, got {len(edges_to_end)}: {edges_to_end}"
|
||||
)
|
||||
|
||||
def test_multi_input_end_total_edge_count(self, client, created_ids):
|
||||
"""All 5 edges of the multi-input graph must survive persistence."""
|
||||
resp = client.post(
|
||||
"/api/workflows", json=multi_input_end_workflow(" count")
|
||||
)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
_, edges = _get_graph(client, wf_id)
|
||||
assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"
|
||||
|
||||
def test_update_to_multi_input_end_preserves_edges(self, client, created_ids):
|
||||
"""Updating a simple workflow to multi-input end keeps all edges."""
|
||||
# Create simple
|
||||
resp = client.post("/api/workflows", json=simple_workflow(" pre"))
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update to multi-input end
|
||||
update_resp = client.put(
|
||||
f"/api/workflows/{wf_id}",
|
||||
json=multi_input_end_workflow(" post"),
|
||||
)
|
||||
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
end_ids = {n["id"] for n in nodes if n["type"] == "end"}
|
||||
edges_to_end = [e for e in edges if e["target"] in end_ids]
|
||||
assert len(edges_to_end) >= 2, (
|
||||
f"Expected >=2 edges to end after update, got {len(edges_to_end)}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source-aware payload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def workflow_with_sources(sources, suffix=""):
|
||||
"""Start -> Agent (with sources) -> End."""
|
||||
return {
|
||||
"name": f"Source WF {int(time.time())}{suffix}",
|
||||
"description": "integration test with sources",
|
||||
"nodes": [
|
||||
{"id": "start_1", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 0}, "data": {}},
|
||||
{"id": "agent_1", "type": "agent", "title": "Agent",
|
||||
"position": {"x": 200, "y": 0}, "data": {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "You are helpful.",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
"sources": sources,
|
||||
"tools": [],
|
||||
}},
|
||||
{"id": "end_1", "type": "end", "title": "End",
|
||||
"position": {"x": 400, "y": 0}, "data": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "edge_1", "source": "start_1", "target": "agent_1"},
|
||||
{"id": "edge_2", "source": "agent_1", "target": "end_1"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def workflow_multi_agent_sources(suffix=""):
|
||||
"""Start -> Agent A (sources A) -> Agent B (sources B) -> End."""
|
||||
return {
|
||||
"name": f"Multi-Agent Sources {int(time.time())}{suffix}",
|
||||
"description": "two agents with different sources",
|
||||
"nodes": [
|
||||
{"id": "start_1", "type": "start", "title": "Start",
|
||||
"position": {"x": 0, "y": 0}, "data": {}},
|
||||
{"id": "agent_a", "type": "agent", "title": "Agent A",
|
||||
"position": {"x": 200, "y": 0}, "data": {
|
||||
"agent_type": "agentic",
|
||||
"system_prompt": "Agent A prompt",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
"sources": ["src_alpha", "src_beta"],
|
||||
"tools": [],
|
||||
}},
|
||||
{"id": "agent_b", "type": "agent", "title": "Agent B",
|
||||
"position": {"x": 400, "y": 0}, "data": {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "Agent B prompt",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": True,
|
||||
"sources": ["src_gamma"],
|
||||
"tools": [],
|
||||
}},
|
||||
{"id": "end_1", "type": "end", "title": "End",
|
||||
"position": {"x": 600, "y": 0}, "data": {}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start_1", "target": "agent_a"},
|
||||
{"id": "e2", "source": "agent_a", "target": "agent_b"},
|
||||
{"id": "e3", "source": "agent_b", "target": "end_1"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _find_agent_node(nodes, node_id):
|
||||
"""Find a specific node by id."""
|
||||
return next((n for n in nodes if n["id"] == node_id), None)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Workflow integration tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWorkflowIntegration:
|
||||
"""Verify end-to-end workflow create → get → update → get round-trips."""
|
||||
|
||||
def test_linear_workflow_round_trip(self, client, created_ids):
|
||||
"""Create a linear workflow and verify all nodes/edges survive the round-trip."""
|
||||
payload = linear_workflow(" round-trip")
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
assert len(nodes) == 3
|
||||
assert len(edges) == 2
|
||||
|
||||
# Verify node types
|
||||
types = {n["id"]: n["type"] for n in nodes}
|
||||
assert types["start_1"] == "start"
|
||||
assert types["agent_1"] == "agent"
|
||||
assert types["end_1"] == "end"
|
||||
|
||||
def test_agent_config_persisted(self, client, created_ids):
|
||||
"""Agent node config (type, prompts, stream_to_user) round-trips correctly."""
|
||||
payload = linear_workflow(" config")
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None
|
||||
assert agent["data"]["agent_type"] == "classic"
|
||||
assert agent["data"]["system_prompt"] == "You are helpful."
|
||||
assert agent["data"]["stream_to_user"] is False
|
||||
|
||||
def test_update_workflow_replaces_graph(self, client, created_ids):
|
||||
"""Updating a workflow fully replaces nodes and edges."""
|
||||
resp = client.post("/api/workflows", json=simple_workflow(" replace"))
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
assert len(nodes) == 2
|
||||
|
||||
# Update to linear
|
||||
update_resp = client.put(
|
||||
f"/api/workflows/{wf_id}", json=linear_workflow(" replaced")
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
nodes, edges = _get_graph(client, wf_id)
|
||||
assert len(nodes) == 3
|
||||
assert len(edges) == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Source-specific integration tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWorkflowSources:
|
||||
"""Verify that agent node sources are persisted and retrieved correctly."""
|
||||
|
||||
def test_create_workflow_with_single_source(self, client, created_ids):
|
||||
"""A workflow with one source on an agent node persists it."""
|
||||
payload = workflow_with_sources(["default"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None, "Agent node not found"
|
||||
assert agent["data"].get("sources") == ["default"], (
|
||||
f"Expected sources=['default'], got {agent['data'].get('sources')}"
|
||||
)
|
||||
|
||||
def test_create_workflow_with_multiple_sources(self, client, created_ids):
|
||||
"""Multiple sources on an agent node are all persisted."""
|
||||
sources = ["src_1", "src_2", "src_3"]
|
||||
payload = workflow_with_sources(sources)
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None
|
||||
assert agent["data"].get("sources") == sources
|
||||
|
||||
def test_create_workflow_with_empty_sources(self, client, created_ids):
|
||||
"""An agent node with empty sources list is accepted and persisted."""
|
||||
payload = workflow_with_sources([])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
assert wf_id
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None
|
||||
assert agent["data"].get("sources") == []
|
||||
|
||||
def test_update_workflow_sources(self, client, created_ids):
|
||||
"""Updating a workflow replaces agent sources."""
|
||||
# Create with original sources
|
||||
payload = workflow_with_sources(["old_src"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update with new sources
|
||||
updated_payload = workflow_with_sources(["new_src_1", "new_src_2"], " upd")
|
||||
update_resp = client.put(f"/api/workflows/{wf_id}", json=updated_payload)
|
||||
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None
|
||||
assert agent["data"].get("sources") == ["new_src_1", "new_src_2"]
|
||||
|
||||
def test_multi_agent_independent_sources(self, client, created_ids):
|
||||
"""Each agent node keeps its own distinct sources list."""
|
||||
payload = workflow_multi_agent_sources()
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent_a = _find_agent_node(nodes, "agent_a")
|
||||
agent_b = _find_agent_node(nodes, "agent_b")
|
||||
|
||||
assert agent_a is not None, "Agent A not found"
|
||||
assert agent_b is not None, "Agent B not found"
|
||||
assert agent_a["data"].get("sources") == ["src_alpha", "src_beta"]
|
||||
assert agent_b["data"].get("sources") == ["src_gamma"]
|
||||
|
||||
def test_sources_survive_workflow_update(self, client, created_ids):
|
||||
"""Sources survive when a workflow is updated without changing sources."""
|
||||
payload = workflow_with_sources(["persistent_src"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update keeping same sources
|
||||
update_resp = client.put(f"/api/workflows/{wf_id}", json=payload)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent["data"].get("sources") == ["persistent_src"]
|
||||
|
||||
def test_remove_sources_on_update(self, client, created_ids):
|
||||
"""Clearing sources on update results in empty list."""
|
||||
payload = workflow_with_sources(["will_be_removed"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update with no sources
|
||||
cleared_payload = workflow_with_sources([], " cleared")
|
||||
update_resp = client.put(f"/api/workflows/{wf_id}", json=cleared_payload)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent["data"].get("sources") == []
|
||||
0
tests/llm/__init__.py
Normal file
0
tests/llm/__init__.py
Normal file
@@ -71,6 +71,7 @@ class TestLLMHandlerCreator:
|
||||
expected_handlers = {
|
||||
"openai": OpenAILLMHandler,
|
||||
"google": GoogleLLMHandler,
|
||||
"novita": OpenAILLMHandler,
|
||||
"default": OpenAILLMHandler,
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
323
tests/llm/test_anthropic.py
Normal file
323
tests/llm/test_anthropic.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""Unit tests for application/llm/anthropic.py — AnthropicLLM.
|
||||
|
||||
Extends coverage beyond test_anthropic_llm.py:
|
||||
- Constructor: api_key priority, base_url support
|
||||
- get_supported_attachment_types
|
||||
- prepare_messages_with_attachments: various scenarios
|
||||
- _get_base64_image: error paths
|
||||
- _raw_gen_stream: close called on response
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake anthropic module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeCompletion:
|
||||
def __init__(self, text):
|
||||
self.completion = text
|
||||
|
||||
|
||||
class _FakeCompletions:
|
||||
def __init__(self):
|
||||
self.last_kwargs = None
|
||||
self._stream_items = [_FakeCompletion("s1"), _FakeCompletion("s2")]
|
||||
|
||||
def create(self, **kwargs):
|
||||
self.last_kwargs = kwargs
|
||||
if kwargs.get("stream"):
|
||||
return self._stream_items
|
||||
return _FakeCompletion("final")
|
||||
|
||||
|
||||
class _FakeAnthropic:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.completions = _FakeCompletions()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_anthropic(monkeypatch):
|
||||
fake = types.ModuleType("anthropic")
|
||||
fake.Anthropic = _FakeAnthropic
|
||||
fake.HUMAN_PROMPT = "<HUMAN>"
|
||||
fake.AI_PROMPT = "<AI>"
|
||||
|
||||
modules_to_remove = [key for key in sys.modules if key.startswith("anthropic")]
|
||||
for key in modules_to_remove:
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules["anthropic"] = fake
|
||||
|
||||
if "application.llm.anthropic" in sys.modules:
|
||||
del sys.modules["application.llm.anthropic"]
|
||||
yield
|
||||
sys.modules.pop("anthropic", None)
|
||||
if "application.llm.anthropic" in sys.modules:
|
||||
del sys.modules["application.llm.anthropic"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
|
||||
instance = AnthropicLLM(api_key="test-key")
|
||||
instance.storage = types.SimpleNamespace(
|
||||
get_file=lambda path: _ctx_manager(b"img_bytes"),
|
||||
)
|
||||
return instance
|
||||
|
||||
|
||||
def _ctx_manager(data):
|
||||
"""Create a simple context manager returning an object with .read()."""
|
||||
import contextlib
|
||||
|
||||
@contextlib.contextmanager
|
||||
def cm():
|
||||
yield types.SimpleNamespace(read=lambda: data)
|
||||
|
||||
return cm()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constructor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAnthropicConstructor:
|
||||
|
||||
def test_api_key_set(self):
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
|
||||
instance = AnthropicLLM(api_key="custom-key")
|
||||
assert instance.api_key == "custom-key"
|
||||
|
||||
def test_base_url_passed(self):
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
|
||||
instance = AnthropicLLM(api_key="k", base_url="https://custom.api")
|
||||
assert instance.anthropic.base_url == "https://custom.api"
|
||||
|
||||
def test_no_base_url(self):
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
|
||||
instance = AnthropicLLM(api_key="k")
|
||||
assert instance.anthropic.base_url is None
|
||||
|
||||
def test_human_and_ai_prompts_set(self):
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
|
||||
instance = AnthropicLLM(api_key="k")
|
||||
assert instance.HUMAN_PROMPT == "<HUMAN>"
|
||||
assert instance.AI_PROMPT == "<AI>"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _raw_gen
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRawGen:
|
||||
|
||||
def test_returns_completion(self, llm):
|
||||
msgs = [{"content": "context"}, {"content": "question"}]
|
||||
result = llm._raw_gen(llm, model="claude-2", messages=msgs)
|
||||
assert result == "final"
|
||||
|
||||
def test_prompt_contains_context_and_question(self, llm):
|
||||
msgs = [{"content": "my context"}, {"content": "my question"}]
|
||||
llm._raw_gen(llm, model="claude-2", messages=msgs)
|
||||
prompt = llm.anthropic.completions.last_kwargs["prompt"]
|
||||
assert "my context" in prompt
|
||||
assert "my question" in prompt
|
||||
|
||||
def test_max_tokens_passed(self, llm):
|
||||
msgs = [{"content": "c"}, {"content": "q"}]
|
||||
llm._raw_gen(llm, model="claude-2", messages=msgs, max_tokens=200)
|
||||
assert llm.anthropic.completions.last_kwargs["max_tokens_to_sample"] == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _raw_gen_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRawGenStream:
|
||||
|
||||
def test_yields_all_completions(self, llm):
|
||||
msgs = [{"content": "c"}, {"content": "q"}]
|
||||
chunks = list(
|
||||
llm._raw_gen_stream(llm, model="claude", messages=msgs, max_tokens=10)
|
||||
)
|
||||
assert chunks == ["s1", "s2"]
|
||||
|
||||
def test_calls_close_on_response(self, llm):
|
||||
closed = {"called": False}
|
||||
original = llm.anthropic.completions._stream_items
|
||||
|
||||
class ClosableList(list):
|
||||
def close(self):
|
||||
closed["called"] = True
|
||||
|
||||
closable = ClosableList(original)
|
||||
llm.anthropic.completions._stream_items = closable
|
||||
llm.anthropic.completions.create = lambda **kw: closable
|
||||
|
||||
msgs = [{"content": "c"}, {"content": "q"}]
|
||||
list(llm._raw_gen_stream(llm, model="claude", messages=msgs))
|
||||
assert closed["called"]
|
||||
|
||||
def test_prompt_format(self, llm):
|
||||
msgs = [{"content": "ctx"}, {"content": "q"}]
|
||||
list(llm._raw_gen_stream(llm, model="claude", messages=msgs))
|
||||
prompt = llm.anthropic.completions.last_kwargs["prompt"]
|
||||
assert prompt.startswith("<HUMAN>")
|
||||
assert prompt.endswith("<AI>")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_supported_attachment_types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetSupportedAttachmentTypes:
|
||||
|
||||
def test_returns_image_types(self, llm):
|
||||
result = llm.get_supported_attachment_types()
|
||||
assert "image/png" in result
|
||||
assert "image/jpeg" in result
|
||||
assert "image/webp" in result
|
||||
assert "image/gif" in result
|
||||
|
||||
def test_no_pdf_support(self, llm):
|
||||
result = llm.get_supported_attachment_types()
|
||||
assert "application/pdf" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_messages_with_attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPrepareMessagesWithAttachments:
|
||||
|
||||
def test_no_attachments_returns_same(self, llm):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs)
|
||||
assert result == msgs
|
||||
|
||||
def test_empty_attachments_returns_same(self, llm):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, [])
|
||||
assert result == msgs
|
||||
|
||||
def test_image_with_preconverted_data(self, llm):
|
||||
msgs = [{"role": "user", "content": "look"}]
|
||||
attachments = [{"mime_type": "image/png", "data": "AABBCC"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
img_part = next(
|
||||
p for p in user_msg["content"] if p.get("type") == "image"
|
||||
)
|
||||
assert img_part["source"]["data"] == "AABBCC"
|
||||
assert img_part["source"]["type"] == "base64"
|
||||
assert img_part["source"]["media_type"] == "image/png"
|
||||
|
||||
def test_image_from_storage(self, llm):
|
||||
llm.storage = types.SimpleNamespace(
|
||||
get_file=lambda p: _ctx_manager(b"raw_image_bytes"),
|
||||
)
|
||||
msgs = [{"role": "user", "content": "look"}]
|
||||
attachments = [{"mime_type": "image/jpeg", "path": "/tmp/img.jpg"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
img_part = next(
|
||||
p for p in user_msg["content"] if p.get("type") == "image"
|
||||
)
|
||||
assert img_part["source"]["media_type"] == "image/jpeg"
|
||||
assert len(img_part["source"]["data"]) > 0
|
||||
|
||||
def test_no_user_message_creates_one(self, llm):
|
||||
msgs = [{"role": "system", "content": "sys"}]
|
||||
attachments = [{"mime_type": "image/png", "data": "AAA"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msgs = [m for m in result if m["role"] == "user"]
|
||||
assert len(user_msgs) == 1
|
||||
|
||||
def test_image_error_adds_text_fallback(self, llm):
|
||||
def bad_storage(path):
|
||||
raise Exception("storage error")
|
||||
|
||||
llm.storage = types.SimpleNamespace(get_file=bad_storage)
|
||||
msgs = [{"role": "user", "content": "look"}]
|
||||
attachments = [
|
||||
{"mime_type": "image/png", "path": "/bad.png", "content": "fb"},
|
||||
]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
text_parts = [
|
||||
p for p in user_msg["content"]
|
||||
if p.get("type") == "text" and "could not" in p.get("text", "").lower()
|
||||
]
|
||||
assert len(text_parts) == 1
|
||||
|
||||
def test_non_image_attachment_ignored(self, llm):
|
||||
msgs = [{"role": "user", "content": "look"}]
|
||||
attachments = [{"mime_type": "application/pdf"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
# content becomes list with just original text
|
||||
assert isinstance(user_msg["content"], list)
|
||||
assert len(user_msg["content"]) == 1
|
||||
|
||||
def test_content_not_list_becomes_empty(self, llm):
|
||||
msgs = [{"role": "user", "content": 999}]
|
||||
attachments = [{"mime_type": "image/png", "data": "AAA"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
assert isinstance(user_msg["content"], list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_base64_image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetBase64Image:
|
||||
|
||||
def test_raises_for_no_path(self, llm):
|
||||
with pytest.raises(ValueError, match="No file path"):
|
||||
llm._get_base64_image({})
|
||||
|
||||
def test_raises_for_file_not_found(self, llm):
|
||||
import contextlib
|
||||
|
||||
@contextlib.contextmanager
|
||||
def bad_file(path):
|
||||
raise FileNotFoundError("not found")
|
||||
|
||||
llm.storage = types.SimpleNamespace(get_file=bad_file)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
llm._get_base64_image({"path": "/nonexistent"})
|
||||
|
||||
def test_returns_base64_encoded(self, llm):
|
||||
import base64
|
||||
|
||||
llm.storage = types.SimpleNamespace(
|
||||
get_file=lambda p: _ctx_manager(b"test_data"),
|
||||
)
|
||||
result = llm._get_base64_image({"path": "/tmp/img.png"})
|
||||
decoded = base64.b64decode(result)
|
||||
assert decoded == b"test_data"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user