Compare commits

...

33 Commits

Author SHA1 Message Date
Alex
727495c553 fix: mongo in unit tests 2026-03-31 00:34:49 +01:00
Alex
a3b08a5b44 More tests 2026-03-31 00:07:19 +01:00
Alex
d5c0322e2a chore: more tests 2026-03-30 16:13:08 +01:00
Alex
dc6db847ca Merge pull request #2339 from arc53/test-handlers
chore: handlers tests
2026-03-30 13:06:53 +01:00
Alex
ed0063aada chore: handlers tests 2026-03-30 12:53:50 +01:00
Alex
3f6d6f15ea Merge pull request #2338 from arc53/tests-utils
chore: utils tests
2026-03-29 11:54:59 +01:00
Alex
126fa01b14 chore: utils tests 2026-03-29 11:49:35 +01:00
Alex
e06debad5f chore: connector tests 2026-03-29 10:32:48 +01:00
Alex
6492852f7d fix: lint on seeder test 2026-03-29 09:48:46 +01:00
Alex
00a621f33a tests: retriever and seeder 2026-03-29 09:46:07 +01:00
Alex
e92ffc6fdc Merge pull request #2337 from arc53/tests-api
chore: api and tool tests
2026-03-28 22:55:10 +00:00
Alex
fe185e5b8d chore: api and tool tests 2026-03-28 21:51:47 +00:00
Alex
9f3d9ab860 docs: agent types 2026-03-28 18:50:50 +00:00
Alex
1c0adde380 chore: docs update 2026-03-28 17:04:06 +00:00
Alex
3c56bd0d0b docs: fix conflicts 2026-03-28 16:16:15 +00:00
Alex
86664ebda2 Merge pull request #2320 from Alex-wuhu/novita-integration
feat: complete Novita AI provider integration
2026-03-28 13:22:02 +00:00
Alex
db18b743d1 fix tests 2 2026-03-28 13:10:41 +00:00
Alex
9e85cc9065 fix: test errors 2026-03-28 13:04:21 +00:00
Alex
aaaa6f002d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-03-28 12:03:27 +00:00
Alex
47dcbcb74b fix: tests and sources on workflow agent 2026-03-28 12:03:16 +00:00
Alex
ddbfd94193 Adjust demo GIF height in README
Update the height of the demo GIF in README.
2026-03-28 11:09:05 +00:00
Alex
8dec60ab8b Update demo GIF in README
Replaced demo video GIF in README with a new example.
2026-03-28 11:08:10 +00:00
Alex
84b2e4bab4 fix: end node multi input 2026-03-28 10:00:01 +00:00
Alex
2afdd7f026 fix: enable tools in workflow agents 2026-03-27 13:43:09 +00:00
Alex
f364475f64 Merge pull request #2335 from arc53/tests-worker
tests: worker coverage
2026-03-27 12:14:56 +00:00
Alex
b254de6ed6 tests: worker coverage 2026-03-27 12:03:55 +00:00
Alex
08dedcaf95 Merge pull request #2334 from arc53/tests-vectors
tests: vectors
2026-03-26 19:11:20 +00:00
Alex
c726eb8ebd fix: ruff 2026-03-26 18:48:53 +00:00
Alex
5f0d39e5f1 tests: vectors 2026-03-26 18:42:59 +00:00
Alex
8c82fc5495 Merge pull request #2333 from arc53/chore/bump-npm-v0.6.3
chore: bump npm libraries to v0.6.3
2026-03-26 14:17:09 +00:00
github-actions[bot]
6d81a15e97 chore: bump npm libraries to v0.6.3 2026-03-26 14:16:32 +00:00
Alex
5478e4234c chore: bump npm again 2026-03-26 14:06:05 +00:00
Alex-wuhu
eaf39bb15b feat: add Novita AI as LLM provider
Add Novita AI (https://novita.ai) as a new LLM provider option.
Novita offers OpenAI-compatible API endpoints with competitive pricing.
2026-03-23 10:52:26 +08:00
162 changed files with 66006 additions and 274 deletions

View File

@@ -3,6 +3,14 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=

View File

@@ -29,7 +29,7 @@
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
</div>
<h3 align="left">
<strong>Key Features:</strong>

View File

@@ -185,7 +185,10 @@ class ToolExecutor:
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(tool_data, tool_id, action_name)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
@@ -238,7 +241,10 @@ class ToolExecutor:
return result, call_id
def _get_or_load_tool(self, tool_data: Dict, tool_id: str, action_name: str):
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
@@ -251,8 +257,8 @@ class ToolExecutor:
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": {},
"query_params": {},
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(

View File

@@ -27,6 +27,8 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
@@ -193,6 +195,46 @@ OPENROUTER_MODELS = [
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",

View File

@@ -114,6 +114,10 @@ class ModelRegistry:
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
@@ -245,6 +249,21 @@ class ModelRegistry:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(

View File

@@ -10,6 +10,7 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -5,9 +5,7 @@ from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Settings(BaseSettings):
@@ -15,15 +13,11 @@ class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -45,9 +39,7 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
@@ -55,12 +47,8 @@ class Settings(BaseSettings):
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
@@ -72,7 +60,7 @@ class Settings(BaseSettings):
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -90,16 +78,13 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
NOVITA_API_KEY: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -141,9 +126,7 @@ class Settings(BaseSettings):
# LanceDB vectorstore config
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -180,6 +163,7 @@ class Settings(BaseSettings):
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"NOVITA_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",

View File

@@ -7,6 +7,7 @@ class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
"default": OpenAILLMHandler,
}

View File

@@ -1,13 +1,13 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
NOVITA_BASE_URL = "https://api.novita.ai/openai"
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.API_KEY,
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,

View File

@@ -44,36 +44,40 @@ The main set of instructions or system [prompt](/Guides/Customising-prompts) tha
## Understanding Agent Types
DocsGPT allows for different "types" of agents, each with a distinct way of processing information and generating responses. The code for these agent types can be found in the `application/agents/` directory.
DocsGPT supports several agent types, each with a distinct way of processing information. The code for these can be found in the `application/agents/` directory.
### 1. Classic Agent (`classic_agent.py`)
### 1. Classic Agent
**How it works:** The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach.
1. **Retrieve:** When a query is made, it first searches the selected Source documents for relevant information.
2. **Augment:** This retrieved data is then added to the context, along with the main Prompt and the user's query.
3. **Generate:** The LLM generates a response based on this augmented context. It can also utilize any configured tools if the LLM decides they are necessary.
The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach: it retrieves relevant document chunks, augments the prompt context with them, and generates a response. It can also use configured tools if the LLM decides they are necessary.
**Best for:**
* Direct question-answering over a specific set of documents.
* Tasks where the primary goal is to extract and synthesize information from the provided sources.
* Simpler tool integrations where the decision to use a tool is straightforward.
**Best for:** Direct question-answering over a specific set of documents and straightforward tool use.
### 2. ReAct Agent (`react_agent.py`)
### 2. Agentic Agent
**How it works:** The ReAct Agent employs a more sophisticated "Reason and Act" framework. This involves a multi-step process:
1. **Plan (Thought):** Based on the query, its prompt, and available tools/sources, the LLM first generates a plan or a sequence of thoughts on how to approach the problem. You might see this output as a "thought" process during generation.
2. **Act:** The agent then executes actions based on this plan. This might involve querying its sources, using a tool, or performing internal reasoning.
3. **Observe:** It gathers observations from the results of its actions (e.g., data from a tool, snippets from documents).
4. **Repeat (if necessary):** Steps 2 and 3 can be repeated as the agent refines its approach or gathers more information.
5. **Conclude:** Finally, it generates the final answer based on the initial query and all accumulated observations.
Unlike Classic which pre-fetches documents into the prompt, the Agentic Agent gives the LLM an `internal_search` tool so it can decide **when, what, and whether** to search. This means the LLM controls its own retrieval — it can search multiple times, refine queries, or skip retrieval entirely if the question doesn't need it.
**Best for:**
* More complex tasks that require multi-step reasoning or problem-solving.
* Scenarios where the agent needs to dynamically decide which tools to use and in what order, based on intermediate results.
* Interactive tasks where the agent needs to "think" through a problem.
**Best for:** Tasks where the agent needs to dynamically decide how to gather information, use multiple tools in sequence, or combine retrieval with external tool calls.
### 3. Research Agent
A multi-phase agent designed for in-depth research tasks:
1. **Clarification** — Determines if the question needs clarification before proceeding.
2. **Planning** — Decomposes the question into research steps with adaptive depth based on complexity.
3. **Research** — Executes each step, calling tools and refining queries as needed.
4. **Synthesis** — Compiles findings into a final cited report.
Includes budget controls for max steps, timeout, and token limits to keep research bounded.
**Best for:** Complex questions that require multi-step investigation, gathering information from multiple sources, and producing structured reports with citations.
### 4. Workflow Agent
Executes predefined workflows composed of connected nodes (AI Agent, Set State, Condition). See the [Workflow Nodes](/Agents/nodes) page for details on building workflows.
**Best for:** Structured, multi-step processes with branching logic and shared state between steps.
<Callout type="info">
Developers looking to introduce new agent architectures can explore the `application/agents/` directory. `classic_agent.py` and `react_agent.py` serve as excellent starting points, demonstrating how to inherit from `BaseAgent` and structure agent logic.
The legacy "ReAct" agent type is still accepted for backwards compatibility but maps to the Classic Agent internally. New agents should use Classic, Agentic, or Research instead.
</Callout>
## Navigating and Managing Agents in DocsGPT

View File

@@ -70,9 +70,9 @@ Inside the DocsGPT folder create a `.env` file and copy the contents of `.env_sa
Make sure your `.env` file looks like this:
```
OPENAI_API_KEY=(Your OpenAI API key)
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```
To save the file, press CTRL+X, then Y, and then ENTER.

View File

@@ -104,7 +104,7 @@ DocsGPT can transcribe audio in two places:
- Voice input in the chat.
- Audio file ingestion. Uploaded `.wav`, `.mp3`, `.m4a`, `.ogg`, and `.webm` files are transcribed first and then passed through the normal parser, chunking, embedding, and indexing pipeline.
For an end-to-end walkthrough, see the [Speech and Audio Guide](/Guides/speech-and-audio).
The settings below control speech-to-text behaviour for both voice input and audio file ingestion.
| Setting | Purpose | Typical values |
| --- | --- | --- |
@@ -214,6 +214,31 @@ If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt
}}
/>
## S3 Storage Backend
By default DocsGPT stores files locally. Set `STORAGE_TYPE=s3` to use Amazon S3 instead.
| Setting | Description | Default |
| --- | --- | --- |
| `STORAGE_TYPE` | `local` or `s3` | `local` |
| `S3_BUCKET_NAME` | S3 bucket name | `docsgpt-test-bucket` |
| `SAGEMAKER_ACCESS_KEY` | AWS access key ID | — |
| `SAGEMAKER_SECRET_KEY` | AWS secret access key | — |
| `SAGEMAKER_REGION` | AWS region | — |
| `URL_STRATEGY` | `backend` (proxy through API) or `s3` (direct S3 URLs) | `backend` |
The S3 credentials use `SAGEMAKER_*` variable names because they are shared with the SageMaker integration.
```env
STORAGE_TYPE=s3
S3_BUCKET_NAME=your-bucket-name
SAGEMAKER_ACCESS_KEY=your-aws-access-key-id
SAGEMAKER_SECRET_KEY=your-aws-secret-access-key
SAGEMAKER_REGION=us-east-1
```
Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.
## Exploring More Settings
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:

View File

@@ -86,13 +86,9 @@ Make sure your `.env` file looks like this:
```
OPENAI_API_KEY=(Your OpenAI API key)
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```

View File

@@ -11,18 +11,18 @@ DocsGPT API keys are essential for developers and users who wish to integrate th
After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the Agents option, and press 'Create New' to generate a new agent (which includes an API key).
- **API Call:** Alternatively, you can use the `/api/create_agent` endpoint to create a new agent. An API key is automatically generated for each agent. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
## Understanding Key Variables
Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
Upon creating your agent, you will encounter several key variables. Each serves a specific purpose:
- **Name:** Assign a name to your API key for easy identification.
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
- **Name:** Assign a name to your agent for easy identification.
- **Source:** Indicates the source document(s) linked to your agent, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your agent. You can view this by making a call to `/api/get_agents`.
- **Key:** The API key for the agent, which will be used in your application to authenticate API requests.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the agent, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
Congratulations on taking the first step towards enhancing your applications with DocsGPT!

View File

@@ -64,7 +64,7 @@ flowchart LR
* **Technology:** Supports multiple vector databases.
* **Responsibility:** Vector Stores are used to store and retrieve vector embeddings of document chunks. This enables semantic search and retrieval of relevant document snippets in response to user queries.
* **Key Features:**
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, and LanceDB.
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, MongoDB Atlas Vector Search, and pgvector.
* Provides storage and indexing of high-dimensional vector embeddings.
* Enables editing and updating of vector indexes including specific chunks.

View File

@@ -16,7 +16,7 @@ Training on other documentation sources can greatly enhance the versatility and
Make sure you have the document on which you want to train on ready with you on the device which you are using .You can also use links to the documentation to train on.
<Callout type="warning" emoji="⚠️">
Note: The document should be either of the given file formats .pdf, .txt, .rst, .docx, .md, .zip and limited to 25mb.You can also train using the link of the documentation.
Note: Supported file formats include .pdf, .txt, .rst, .docx, .md, .mdx, .csv, .epub, .html, .json, .xlsx, .pptx, .png, .jpg, .jpeg, and audio files (.wav, .mp3, .m4a, .ogg, .webm). You can also train using the link of the documentation.
</Callout>

View File

@@ -35,8 +35,34 @@ Choose the LLM of your choice.
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
### Step 2
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
</Steps>
## Fallback LLM
DocsGPT can automatically switch to a fallback LLM when the primary model fails, including mid-stream. This works with both streaming and non-streaming requests.
**Fallback order:**
1. Per-agent backup models (other models configured on the same agent)
2. Global fallback (`FALLBACK_LLM_*` env vars below)
3. Error returned if all fail
| Setting | Description | Default |
| --- | --- | --- |
| `FALLBACK_LLM_PROVIDER` | Provider name (e.g., `openai`, `anthropic`, `google`) | — |
| `FALLBACK_LLM_NAME` | Model name (e.g., `gpt-4o`, `claude-sonnet-4-20250514`) | — |
| `FALLBACK_LLM_API_KEY` | API key for the fallback provider | Falls back to `API_KEY` |
All three (`FALLBACK_LLM_PROVIDER`, `FALLBACK_LLM_NAME`, and an API key) must resolve for the global fallback to activate.
```env
FALLBACK_LLM_PROVIDER=anthropic
FALLBACK_LLM_NAME=claude-sonnet-4-20250514
FALLBACK_LLM_API_KEY=sk-ant-your-anthropic-key
```
<Callout type="info">
For maximum resilience, use a fallback provider from a different cloud than your primary. Each agent can also have multiple models configured — the other models are tried first before the global fallback.
</Callout>

View File

@@ -2,5 +2,13 @@ export default {
"google-drive-connector": {
"title": "🔗 Google Drive",
"href": "/Guides/Integrations/google-drive-connector"
},
"sharepoint-connector": {
"title": "🔗 SharePoint / OneDrive",
"href": "/Guides/Integrations/sharepoint-connector"
},
"mcp-tool-integration": {
"title": "🔗 MCP Tools",
"href": "/Guides/Integrations/mcp-tool-integration"
}
}

View File

@@ -0,0 +1,66 @@
---
title: MCP Tool Integration
description: Connect external tools to DocsGPT agents using the Model Context Protocol (MCP) standard.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# MCP Tool Integration
The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) integration lets you connect external tool servers to DocsGPT. Your agents can then discover and call tools provided by those servers during conversations — for example, querying a CRM, running code, or accessing a database.
## Setup
<Steps>
### Step 1: Configure Environment Variables (Optional)
Only needed if your MCP servers use OAuth authentication:
```env
MCP_OAUTH_REDIRECT_URI=https://yourdomain.com/api/mcp_server/callback
```
If not set, falls back to `API_URL/api/mcp_server/callback`.
### Step 2: Add an MCP Server
Go to **Settings** > **Tools** > **Add Tool** > **MCP Server**. Enter the server URL, select an auth type, and click **Test Connection** to verify, then **Save**.
### Step 3: Enable for Your Agent
In your agent configuration, enable the MCP tools you want the agent to use.
</Steps>
## Authentication Types
| Auth Type | Config Fields |
|-----------|---------------|
| **None** | — |
| **Bearer** | `bearer_token` |
| **API Key** | `api_key`, `api_key_header` (default: `X-API-Key`) |
| **Basic** | `username`, `password` |
| **OAuth** | `oauth_scopes` (optional) |
<Callout type="warning">
For OAuth in production, `MCP_OAUTH_REDIRECT_URI` must be a publicly accessible URL pointing to your DocsGPT backend.
</Callout>
## API Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/api/mcp_server/test` | POST | Test a connection without saving |
| `/api/mcp_server/save` | POST | Save or update a server configuration |
| `/api/mcp_server/callback` | GET | OAuth callback handler |
| `/api/mcp_server/oauth_status/<task_id>` | GET | Poll OAuth flow status |
| `/api/mcp_server/auth_status` | GET | Batch check auth status for all MCP tools |
## Troubleshooting
- **Connection refused** — Verify the URL and that the server is reachable from your backend.
- **403 Forbidden** — Check credentials and permissions.
- **Timed out** — Default is 30s; increase timeout in tool config (max 300s).
- **OAuth "needs_auth" persists** — Verify `MCP_OAUTH_REDIRECT_URI` is correct and Redis is running.

View File

@@ -0,0 +1,63 @@
---
title: SharePoint / OneDrive Connector
description: Connect your Microsoft SharePoint or OneDrive as an external knowledge base to upload and process files directly.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# SharePoint / OneDrive Connector
Connect your SharePoint or OneDrive account to upload and process files directly as an external knowledge base. Supports Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Microsoft Entra ID (Azure AD) with automatic token refresh.
## Setup
<Steps>
### Step 1: Create an App Registration in Azure
1. Go to the [Azure Portal](https://portal.azure.com/) > **Microsoft Entra ID** > **App registrations** > **New registration**
2. Set **Redirect URI** (Web) to:
- Local: `http://localhost:7091/api/connectors/callback?provider=share_point`
- Production: `https://yourdomain.com/api/connectors/callback?provider=share_point`
### Step 2: Configure API Permissions
In your App Registration, go to **API permissions** > **Add a permission** > **Microsoft Graph** > **Delegated permissions** and add: `Files.Read`, `Files.Read.All`, `Sites.Read.All`. Grant admin consent if possible.
### Step 3: Create a Client Secret
Go to **Certificates & secrets** > **New client secret**. Copy the secret value immediately (it won't be shown again).
### Step 4: Configure Environment Variables
Add to your `.env` file:
```env
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
```
| Variable | Description | Required | Default |
|----------|-------------|----------|---------|
| `MICROSOFT_CLIENT_ID` | Application (client) ID from App Registration overview | Yes | — |
| `MICROSOFT_CLIENT_SECRET` | Client secret value | Yes | — |
| `MICROSOFT_TENANT_ID` | Directory (tenant) ID | No | `common` |
| `MICROSOFT_AUTHORITY` | Login endpoint override | No | Auto-constructed |
<Callout type="warning">
`MICROSOFT_TENANT_ID=common` (the default) allows any Microsoft account to authenticate. Set this to your specific tenant ID in production.
</Callout>
### Step 5: Restart and Use
Restart your application, then go to the upload section in DocsGPT and select **SharePoint / OneDrive** as the source. You'll be redirected to Microsoft to sign in, then can browse and select files to process.
</Steps>
## Troubleshooting
- **Option not appearing** — Verify `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` are set, then restart.
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=share_point`.
- **Permission denied** — Ensure admin consent is granted and the user has access to the target files.

View File

@@ -7,20 +7,10 @@ description:
If your AI uses external knowledge and is not explicit enough, it is ok, because we try to make DocsGPT friendly.
But if you want to adjust it, here is a simple way:-
- Got to `application/prompts/chat_combine_prompt.txt`
- And change it to
But if you want to adjust it, prompts are now managed through the UI and API using a template-based system. See the [Customising Prompts](/Guides/Customising-prompts) guide for details.
To make the AI stricter about staying on-topic, edit your active prompt template (via **Sidebar → Settings → Active Prompt**) to include instructions like:
```
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples, if possible.
Write an answer for the question below based on the provided context.
If the context provides insufficient information, reply "I cannot answer".
You have access to chat history and can use it to help answer the question.
----------------
{summaries}
```

View File

@@ -29,7 +29,7 @@ export default {
"title": "OCR",
"href": "/Guides/ocr"
},
"Integrations": {
"Integrations": {
"title": "🔗 Integrations"
}
}

View File

@@ -70,7 +70,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
To stop DocsGPT, simply open a new terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
docker compose -f deployment/docker-compose-hub.yaml down
```
(or the specific `docker compose` command shown at the end of the `setup.sh` execution, which may include optional compose files depending on your choices).

View File

@@ -1,12 +1,12 @@
{
"name": "docsgpt",
"version": "0.5.1",
"version": "0.6.3",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "docsgpt",
"version": "0.5.1",
"version": "0.6.3",
"license": "Apache-2.0",
"dependencies": {
"@babel/plugin-transform-flow-strip-types": "^7.23.3",

View File

@@ -1,6 +1,6 @@
{
"name": "docsgpt",
"version": "0.6.1",
"version": "0.6.3",
"private": false,
"description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.",
"source": "./src/index.html",

View File

@@ -1,6 +1,6 @@
aiohttp>=3,<4
certifi==2024.7.4
h11==0.16.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7

View File

@@ -40,9 +40,12 @@ import {
} from '@/components/ui/select';
import { Sheet, SheetContent } from '@/components/ui/sheet';
import { useSelector } from 'react-redux';
import modelService from '../api/services/modelService';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import { selectToken } from '../preferences/preferenceSlice';
import { WorkflowNode } from './types/workflow';
import {
AgentNode,
@@ -77,6 +80,7 @@ interface UserTool {
function WorkflowBuilderInner() {
const navigate = useNavigate();
const token = useSelector(selectToken);
const { agentId } = useParams<{ agentId?: string }>();
const [searchParams] = useSearchParams();
const folderId = searchParams.get('folder_id');
@@ -304,7 +308,7 @@ function WorkflowBuilderInner() {
setAvailableModels(modelService.transformModels(modelsData.models));
}
const toolsResponse = await userService.getUserTools(null);
const toolsResponse = await userService.getUserTools(token);
if (toolsResponse.ok) {
const toolsData = await toolsResponse.json();
setAvailableTools(toolsData.tools);

View File

@@ -54,7 +54,10 @@ import { FileUpload } from '../../components/FileUpload';
import AgentDetailsModal from '../../modals/AgentDetailsModal';
import ConfirmationModal from '../../modals/ConfirmationModal';
import { ActiveState } from '../../models/misc';
import { selectToken } from '../../preferences/preferenceSlice';
import {
selectSourceDocs,
selectToken,
} from '../../preferences/preferenceSlice';
import { getToolDisplayName } from '../../utils/toolUtils';
import { Agent } from '../types';
import { ConditionCase, WorkflowNode } from '../types/workflow';
@@ -300,6 +303,7 @@ function createWorkflowPayload(
function WorkflowBuilderInner() {
const navigate = useNavigate();
const token = useSelector(selectToken);
const sourceDocs = useSelector(selectSourceDocs);
const { agentId } = useParams<{ agentId?: string }>();
const [searchParams] = useSearchParams();
const folderId = searchParams.get('folder_id');
@@ -341,6 +345,14 @@ function WorkflowBuilderInner() {
const [availableModels, setAvailableModels] = useState<Model[]>([]);
const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
const sourceOptions = useMemo(
() =>
(sourceDocs ?? []).map((doc) => ({
value: doc.id ?? 'default',
label: doc.name,
})),
[sourceDocs],
);
const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
Record<string, string>
>({});
@@ -387,31 +399,39 @@ function WorkflowBuilderInner() {
[],
);
const onConnect = useCallback((params: Connection) => {
setEdges((eds) => {
const exists = eds.some(
(e) =>
e.source === params.source &&
e.sourceHandle === params.sourceHandle &&
e.target === params.target &&
e.targetHandle === params.targetHandle,
);
if (exists) return eds;
const filtered = eds.filter(
(e) =>
!(
const onConnect = useCallback(
(params: Connection) => {
setEdges((eds) => {
const exists = eds.some(
(e) =>
e.source === params.source &&
e.sourceHandle === (params.sourceHandle ?? null)
) &&
!(
e.sourceHandle === params.sourceHandle &&
e.target === params.target &&
e.targetHandle === (params.targetHandle ?? null)
),
);
return addEdge(params, filtered);
});
}, []);
e.targetHandle === params.targetHandle,
);
if (exists) return eds;
const targetNode = nodes.find((n) => n.id === params.target);
const isEndNode = targetNode?.type === 'end';
const filtered = eds.filter(
(e) =>
!(
e.source === params.source &&
e.sourceHandle === (params.sourceHandle ?? null)
) &&
// End nodes accept multiple incoming edges
(isEndNode ||
!(
e.target === params.target &&
e.targetHandle === (params.targetHandle ?? null)
)),
);
return addEdge(params, filtered);
});
},
[nodes],
);
const onEdgeClick = useCallback((_event: React.MouseEvent, edge: Edge) => {
setEdges((eds) => eds.filter((e) => e.id !== edge.id));
@@ -701,7 +721,7 @@ function WorkflowBuilderInner() {
setDefaultAgentModelId(preferredDefaultModel);
}
const toolsResponse = await userService.getUserTools(null);
const toolsResponse = await userService.getUserTools(token);
if (toolsResponse.ok) {
const toolsData = await toolsResponse.json();
setAvailableTools(toolsData.tools);
@@ -1271,8 +1291,8 @@ function WorkflowBuilderInner() {
const handlePrimaryAction = useCallback(() => {
if (isPrimaryActionDisabled) return;
void persistWorkflow(!canManageAgent);
}, [isPrimaryActionDisabled, persistWorkflow, canManageAgent]);
void persistWorkflow(false);
}, [isPrimaryActionDisabled, persistWorkflow]);
const agentForDetails = useMemo<Agent>(
() => ({
@@ -1910,6 +1930,28 @@ function WorkflowBuilderInner() {
emptyText="No tools available"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
Sources
</label>
<MultiSelect
options={sourceOptions}
selected={
selectedNode.data.config?.sources || []
}
onChange={(newSources) =>
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
sources: newSources,
},
})
}
placeholder="Select sources..."
searchPlaceholder="Search sources..."
emptyText="No sources available"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
Structured Output (JSON Schema)

View File

@@ -70,12 +70,12 @@ export function MultiSelect({
role="combobox"
aria-expanded={open}
className={cn(
'w-full justify-between border-[#E5E5E5] bg-white hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
'h-auto min-h-[2.5rem] w-full justify-between border-[#E5E5E5] bg-white py-1.5 hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
!selected.length && 'text-gray-500 dark:text-gray-400',
className,
)}
>
<div className="flex flex-wrap gap-1">
<div className="flex min-w-0 flex-wrap gap-1">
{selected.length === 0 ? (
placeholder
) : (
@@ -85,9 +85,9 @@ export function MultiSelect({
return (
<span
key={option?.value || label}
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex max-w-[calc(100%-1rem)] items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
>
{label}
<span className="truncate">{label}</span>
<span
role="button"
tabIndex={0}

View File

@@ -3,7 +3,7 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
addopts =
-v
--strict-markers
--tb=short
@@ -11,6 +11,7 @@ addopts =
--cov-report=html
--cov-report=term-missing
--cov-report=xml
--ignore=tests/integration
markers =
unit: Unit tests
integration: Integration tests

View File

@@ -977,8 +977,8 @@ function Connect-CloudAPIProvider {
}
"7" { # Novita
$script:provider_name = "Novita"
$script:llm_name = "novita"
$script:model_name = "deepseek/deepseek-r1"
$script:llm_provider = "novita"
$script:model_name = "moonshotai/kimi-k2.5"
Get-APIKey
break
}

View File

@@ -704,7 +704,7 @@ connect_cloud_api_provider() {
7) # Novita
provider_name="Novita"
llm_provider="novita"
model_name="deepseek/deepseek-r1"
model_name="moonshotai/kimi-k2.5"
get_api_key
break ;;
b|B) clear; return 1 ;; # Clear screen and Back to Main Menu

View File

@@ -0,0 +1,202 @@
"""Tests for application/agents/tools/api_body_serializer.py"""
import json
import pytest
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
@pytest.mark.unit
class TestContentTypeEnum:
def test_json_value(self):
assert ContentType.JSON == "application/json"
def test_form_urlencoded_value(self):
assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"
def test_multipart_value(self):
assert ContentType.MULTIPART_FORM_DATA == "multipart/form-data"
def test_text_plain_value(self):
assert ContentType.TEXT_PLAIN == "text/plain"
def test_xml_value(self):
assert ContentType.XML == "application/xml"
def test_octet_stream_value(self):
assert ContentType.OCTET_STREAM == "application/octet-stream"
@pytest.mark.unit
class TestSerializeJson:
def test_basic_json(self):
body, headers = RequestBodySerializer.serialize(
{"key": "value"}, ContentType.JSON
)
assert json.loads(body) == {"key": "value"}
assert headers["Content-Type"] == "application/json"
def test_nested_json(self):
data = {"user": {"name": "Alice", "age": 30}}
body, headers = RequestBodySerializer.serialize(data, ContentType.JSON)
assert json.loads(body) == data
def test_empty_body_returns_none(self):
body, headers = RequestBodySerializer.serialize({}, ContentType.JSON)
assert body is None
assert headers == {}
def test_none_body(self):
body, headers = RequestBodySerializer.serialize(None, ContentType.JSON)
assert body is None
def test_unknown_content_type_falls_back_to_json(self):
body, headers = RequestBodySerializer.serialize(
{"k": "v"}, "application/vnd.custom+json"
)
assert json.loads(body) == {"k": "v"}
def test_content_type_with_charset_suffix(self):
body, headers = RequestBodySerializer.serialize(
{"k": "v"}, "application/json; charset=utf-8"
)
assert json.loads(body) == {"k": "v"}
@pytest.mark.unit
class TestSerializeFormUrlencoded:
def test_basic_form(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
)
assert "name=Alice" in body
assert "age=30" in body
assert headers["Content-Type"] == "application/x-www-form-urlencoded"
def test_none_values_skipped(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
)
assert "name=Alice" in body
assert "skip" not in body
def test_list_explode_true(self):
body, headers = RequestBodySerializer.serialize(
{"tags": ["a", "b"]},
ContentType.FORM_URLENCODED,
encoding_rules={"tags": {"style": "form", "explode": True}},
)
assert "tags=a" in body
assert "tags=b" in body
def test_list_explode_false(self):
body, headers = RequestBodySerializer.serialize(
{"tags": ["a", "b"]},
ContentType.FORM_URLENCODED,
encoding_rules={"tags": {"style": "form", "explode": False}},
)
# Value is percent-encoded by _serialize_form_value then urlencoded again
assert "tags=" in body
assert "a" in body and "b" in body
def test_dict_value_json_content_type(self):
body, headers = RequestBodySerializer.serialize(
{"metadata": {"key": "val"}},
ContentType.FORM_URLENCODED,
encoding_rules={"metadata": {"contentType": "application/json"}},
)
assert "metadata" in body
@pytest.mark.unit
class TestSerializeTextPlain:
def test_single_value(self):
body, headers = RequestBodySerializer.serialize(
{"message": "hello"}, ContentType.TEXT_PLAIN
)
assert body == "hello"
assert headers["Content-Type"] == "text/plain"
def test_multiple_values(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
)
assert "name: Alice" in body
assert "age: 30" in body
@pytest.mark.unit
class TestSerializeXml:
def test_basic_xml(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice"}, ContentType.XML
)
assert '<?xml version="1.0"' in body
assert "<name>Alice</name>" in body
assert headers["Content-Type"] == "application/xml"
def test_nested_xml(self):
body, headers = RequestBodySerializer.serialize(
{"user": {"name": "Alice"}}, ContentType.XML
)
assert "<user>" in body
assert "<name>Alice</name>" in body
def test_xml_escapes_special_chars(self):
body, headers = RequestBodySerializer.serialize(
{"data": "<script>alert('xss')</script>"}, ContentType.XML
)
assert "&lt;script&gt;" in body
@pytest.mark.unit
class TestSerializeOctetStream:
def test_dict_body(self):
body, headers = RequestBodySerializer.serialize(
{"key": "val"}, ContentType.OCTET_STREAM
)
assert isinstance(body, bytes)
assert headers["Content-Type"] == "application/octet-stream"
@pytest.mark.unit
class TestSerializeMultipartFormData:
def test_basic_multipart(self):
body, headers = RequestBodySerializer.serialize(
{"field": "value"}, ContentType.MULTIPART_FORM_DATA
)
assert isinstance(body, bytes)
assert "multipart/form-data" in headers["Content-Type"]
assert "boundary=" in headers["Content-Type"]
def test_none_values_skipped(self):
body, headers = RequestBodySerializer.serialize(
{"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
)
body_str = body.decode("utf-8", errors="replace")
assert "field" in body_str
assert "empty" not in body_str
@pytest.mark.unit
class TestHelpers:
def test_percent_encode(self):
assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
assert RequestBodySerializer._percent_encode("safe", safe_chars="/") == "safe"
def test_escape_xml(self):
assert "&amp;" in RequestBodySerializer._escape_xml("&")
assert "&lt;" in RequestBodySerializer._escape_xml("<")
assert "&gt;" in RequestBodySerializer._escape_xml(">")
assert "&quot;" in RequestBodySerializer._escape_xml('"')
assert "&apos;" in RequestBodySerializer._escape_xml("'")
def test_dict_to_xml_list(self):
xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
assert "<item>1</item>" in xml
assert "<item>2</item>" in xml

View File

@@ -0,0 +1,282 @@
"""Tests for application/agents/tools/api_tool.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.agents.tools.api_tool import APITool
@pytest.fixture
def tool():
return APITool(
config={
"url": "https://api.example.com/data",
"method": "GET",
"headers": {"Accept": "application/json"},
"query_params": {},
}
)
@pytest.fixture
def post_tool():
return APITool(
config={
"url": "https://api.example.com/items",
"method": "POST",
"headers": {},
"query_params": {},
}
)
@pytest.mark.unit
class TestAPIToolInit:
def test_default_values(self):
tool = APITool(config={})
assert tool.url == ""
assert tool.method == "GET"
assert tool.headers == {}
assert tool.query_params == {}
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
result = tool.execute_action("any_action")
assert result["status_code"] == 200
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked(self, mock_validate, tool):
from application.core.url_validation import SSRFError
mock_validate.side_effect = SSRFError("blocked")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.Timeout()
result = tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error(self, mock_get, mock_validate, tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
result = tool.execute_action("any")
assert result["status_code"] == 404
assert "HTTP Error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
tool = APITool(
config={"url": "https://example.com", "method": "CUSTOM"}
)
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Unsupported" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
tool = APITool(
config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
"method": "GET",
"query_params": {"user_id": "42", "post_id": "7"},
}
)
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
@pytest.mark.unit
class TestParseResponse:
def test_json_response(self, tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"key": "val"}
mock_resp.content = b'{"key":"val"}'
result = tool._parse_response(mock_resp)
assert result == {"key": "val"}
def test_text_response(self, tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.text = "plain text"
mock_resp.content = b"plain text"
result = tool._parse_response(mock_resp)
assert result == "plain text"
def test_xml_response(self, tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/xml"}
mock_resp.text = "<root><item>1</item></root>"
mock_resp.content = b"<root><item>1</item></root>"
result = tool._parse_response(mock_resp)
assert "<root>" in result
def test_empty_content(self, tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.content = b""
result = tool._parse_response(mock_resp)
assert result is None
def test_html_response(self, tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.text = "<html><body>Hi</body></html>"
mock_resp.content = b"<html><body>Hi</body></html>"
result = tool._parse_response(mock_resp)
assert "<html>" in result
@pytest.mark.unit
class TestAPIToolMetadata:
def test_actions_metadata_empty(self, tool):
assert tool.get_actions_metadata() == []
def test_config_requirements_empty(self, tool):
assert tool.get_config_requirements() == {}

View File

@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from application.agents.classic_agent import ClassicAgent
@@ -56,6 +56,73 @@ class TestBaseAgentInitialization:
agent = ClassicAgent(**agent_base_params)
assert agent.user == "user123"
def test_dependency_injection_llm(self, agent_base_params, mock_llm_handler_creator):
"""When llm is provided, LLMCreator.create_llm is NOT called."""
injected_llm = Mock()
agent_base_params["llm"] = injected_llm
agent = ClassicAgent(**agent_base_params)
assert agent.llm is injected_llm
def test_dependency_injection_llm_handler(self, agent_base_params, mock_llm_creator):
"""When llm_handler is provided, LLMHandlerCreator is NOT called."""
injected_handler = Mock()
agent_base_params["llm_handler"] = injected_handler
agent = ClassicAgent(**agent_base_params)
assert agent.llm_handler is injected_handler
def test_dependency_injection_tool_executor(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
"""When tool_executor is provided, a new one is NOT created."""
injected_executor = Mock()
injected_executor.tool_calls = []
agent_base_params["tool_executor"] = injected_executor
agent = ClassicAgent(**agent_base_params)
assert agent.tool_executor is injected_executor
def test_json_schema_normalized(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["json_schema"] = {"type": "object"}
agent = ClassicAgent(**agent_base_params)
assert agent.json_schema == {"type": "object"}
def test_json_schema_wrapped(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["json_schema"] = {"schema": {"type": "string"}}
agent = ClassicAgent(**agent_base_params)
assert agent.json_schema == {"type": "string"}
def test_json_schema_invalid_ignored(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["json_schema"] = {"bad": "no type"}
agent = ClassicAgent(**agent_base_params)
assert agent.json_schema is None
def test_retrieved_docs_defaults_to_empty(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert agent.retrieved_docs == []
def test_attachments_defaults_to_empty(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["attachments"] = None
agent = ClassicAgent(**agent_base_params)
assert agent.attachments == []
def test_limited_token_mode_defaults(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
assert agent.limited_token_mode is False
assert agent.limited_request_mode is False
assert agent.current_token_count == 0
assert agent.context_limit_reached is False
@pytest.mark.unit
class TestBaseAgentBuildMessages:
@@ -602,3 +669,656 @@ class TestBaseAgentHandleResponse:
assert len(results) == 2
assert results[0]["type"] == "tool_call"
assert results[1]["answer"] == "Final answer"
def test_handle_response_dict_event_passthrough(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
"""Dict events with 'type' key pass through without wrapping."""
def mock_process(*args):
yield {"type": "info", "data": {"message": "processing"}}
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert results == [{"type": "info", "data": {"message": "processing"}}]
def test_handle_response_message_object_from_handler(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
"""Response objects with .message.content from handler are unwrapped."""
event = Mock()
event.message = Mock()
event.message.content = "from handler"
def mock_process(*args):
yield event
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["answer"] == "from handler"
# ---------------------------------------------------------------------------
# gen() — the @log_activity decorated entry point
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseAgentGen:
def test_gen_delegates_to_gen_inner(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
# ClassicAgent._gen_inner is abstract — we patch it
with patch.object(agent, "_gen_inner") as mock_inner:
mock_inner.return_value = iter([{"answer": "ok"}])
results = list(agent.gen("hello"))
assert any(r.get("answer") == "ok" for r in results)
# ---------------------------------------------------------------------------
# tool_calls property
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseAgentToolCallsProperty:
def test_getter(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
agent.tool_executor.tool_calls = ["a", "b"]
assert agent.tool_calls == ["a", "b"]
def test_setter(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
agent.tool_calls = ["x"]
assert agent.tool_executor.tool_calls == ["x"]
# ---------------------------------------------------------------------------
# _calculate_current_context_tokens
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCalculateContextTokens:
def test_delegates_to_token_counter(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "hello"}]
with patch(
"application.api.answer.services.compression.token_counter.TokenCounter"
) as MockTC:
MockTC.count_message_tokens.return_value = 42
result = agent._calculate_current_context_tokens(messages)
assert result == 42
MockTC.count_message_tokens.assert_called_once_with(messages)
# ---------------------------------------------------------------------------
# _check_context_limit
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckContextLimit:
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
return ClassicAgent(**agent_base_params)
def test_below_threshold_returns_false(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
messages = [{"role": "user", "content": "hi"}]
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
with patch(
"application.core.model_utils.get_token_limit", return_value=10000
):
result = agent._check_context_limit(messages)
assert result is False
assert agent.current_token_count == 100
def test_at_threshold_returns_true(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
messages = [{"role": "user", "content": "hi"}]
# threshold = 10000 * 0.8 = 8000; tokens = 8001 → True
with patch.object(agent, "_calculate_current_context_tokens", return_value=8001):
with patch(
"application.core.model_utils.get_token_limit", return_value=10000
):
result = agent._check_context_limit(messages)
assert result is True
def test_error_returns_false(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
with patch.object(
agent,
"_calculate_current_context_tokens",
side_effect=RuntimeError("boom"),
):
result = agent._check_context_limit([])
assert result is False
# ---------------------------------------------------------------------------
# _validate_context_size
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestValidateContextSize:
def test_at_limit_logs_warning(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
with patch.object(agent, "_calculate_current_context_tokens", return_value=10000):
with patch(
"application.core.model_utils.get_token_limit", return_value=10000
):
# Should not raise
agent._validate_context_size([{"role": "user", "content": "x"}])
assert agent.current_token_count == 10000
def test_below_threshold_no_warning(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
with patch(
"application.core.model_utils.get_token_limit", return_value=10000
):
agent._validate_context_size([])
assert agent.current_token_count == 100
def test_approaching_threshold(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
# 8500 / 10000 = 85% → above 80% threshold but below 100%
with patch.object(agent, "_calculate_current_context_tokens", return_value=8500):
with patch(
"application.core.model_utils.get_token_limit", return_value=10000
):
agent._validate_context_size([])
assert agent.current_token_count == 8500
# ---------------------------------------------------------------------------
# _truncate_text_middle
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTruncateTextMiddle:
def test_short_text_unchanged(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
with patch("application.utils.num_tokens_from_string", return_value=5):
result = agent._truncate_text_middle("short", max_tokens=100)
assert result == "short"
def test_long_text_truncated(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
long_text = "A" * 1000
def fake_tokens(text):
return len(text) // 4
with patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
result = agent._truncate_text_middle(long_text, max_tokens=50)
assert "[... content truncated to fit context limit ...]" in result
assert len(result) < len(long_text)
def test_zero_target_returns_empty(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
with patch("application.utils.num_tokens_from_string", return_value=100):
result = agent._truncate_text_middle("some text", max_tokens=0)
assert result == ""
# ---------------------------------------------------------------------------
# _truncate_history_to_fit
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTruncateHistoryToFit:
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
return ClassicAgent(**agent_base_params)
def test_empty_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
assert agent._truncate_history_to_fit([], 100) == []
def test_zero_budget(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
history = [{"prompt": "a", "response": "b"}]
assert agent._truncate_history_to_fit(history, 0) == []
def test_fits_all(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
history = [
{"prompt": "q1", "response": "a1"},
{"prompt": "q2", "response": "a2"},
]
with patch("application.utils.num_tokens_from_string", return_value=5):
result = agent._truncate_history_to_fit(history, 10000)
assert len(result) == 2
def test_partial_fit_keeps_most_recent(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
history = [
{"prompt": "old", "response": "old_ans"},
{"prompt": "new", "response": "new_ans"},
]
# Each message = 10 tokens (prompt + response), budget = 15 → only 1 fits
with patch("application.utils.num_tokens_from_string", return_value=5):
result = agent._truncate_history_to_fit(history, 15)
assert len(result) == 1
assert result[0]["prompt"] == "new"
def test_history_with_tool_calls(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator
)
history = [
{
"prompt": "q",
"response": "a",
"tool_calls": [
{
"tool_name": "t",
"action_name": "act",
"arguments": "{}",
"result": "ok",
}
],
}
]
with patch("application.utils.num_tokens_from_string", return_value=3):
result = agent._truncate_history_to_fit(history, 100)
assert len(result) == 1
# ---------------------------------------------------------------------------
# _build_messages — compressed_summary and query truncation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBuildMessagesAdvanced:
def test_compressed_summary_appended(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["compressed_summary"] = "Previous conversation summary"
agent = ClassicAgent(**agent_base_params)
with patch(
"application.core.model_utils.get_token_limit", return_value=100000
), patch("application.utils.num_tokens_from_string", return_value=10):
messages = agent._build_messages("System prompt", "query")
system_content = messages[0]["content"]
assert "Previous conversation summary" in system_content
def test_query_truncated_when_too_large(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ClassicAgent(**agent_base_params)
call_count = {"n": 0}
def fake_tokens(text):
call_count["n"] += 1
return len(text)
with patch(
"application.core.model_utils.get_token_limit", return_value=200
), patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
with patch.object(agent, "_truncate_text_middle", return_value="truncated"):
with patch.object(agent, "_truncate_history_to_fit", return_value=[]):
messages = agent._build_messages("sys", "A" * 500)
# The method should have been called for truncation
assert messages[-1]["role"] == "user"
def test_build_messages_with_tool_call_missing_call_id(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
"""Tool calls without call_id get a generated UUID."""
history = [
{
"tool_calls": [
{
"action_name": "search",
"arguments": "{}",
"result": "found",
}
]
}
]
agent_base_params["chat_history"] = history
agent = ClassicAgent(**agent_base_params)
with patch(
"application.core.model_utils.get_token_limit", return_value=100000
), patch("application.utils.num_tokens_from_string", return_value=5):
messages = agent._build_messages("sys", "q")
tool_msgs = [m for m in messages if m["role"] == "tool"]
assert len(tool_msgs) == 1
# ---------------------------------------------------------------------------
# _llm_gen — edge cases
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestLLMGenAdvanced:
def test_llm_gen_with_attachments(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
):
agent_base_params["attachments"] = [{"id": "att1", "mime_type": "image/png"}]
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages)
call_kwargs = mock_llm.gen_stream.call_args[1]
assert "_usage_attachments" in call_kwargs
def test_llm_gen_without_log_context(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
):
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
# Should not raise even without log_context
agent._llm_gen(messages, log_context=None)
mock_llm.gen_stream.assert_called_once()
def test_llm_gen_google_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm._supports_structured_output = Mock(return_value=True)
mock_llm.prepare_structured_output_format = Mock(
return_value={"schema": "test"}
)
agent_base_params["json_schema"] = {"type": "object"}
agent_base_params["llm_name"] = "google"
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages, log_context)
call_kwargs = mock_llm.gen_stream.call_args[1]
assert "response_schema" in call_kwargs
def test_llm_gen_no_tools_when_unsupported(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_llm._supports_tools = False
agent = ClassicAgent(**agent_base_params)
agent.tools = [{"type": "function", "function": {"name": "test"}}]
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages)
call_kwargs = mock_llm.gen_stream.call_args[1]
assert "tools" not in call_kwargs
def test_llm_gen_no_structured_output_when_unsupported(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_llm._supports_structured_output = Mock(return_value=False)
agent_base_params["json_schema"] = {"type": "object"}
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages)
call_kwargs = mock_llm.gen_stream.call_args[1]
assert "response_format" not in call_kwargs
assert "response_schema" not in call_kwargs
def test_llm_gen_no_format_when_prepare_returns_none(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_llm._supports_structured_output = Mock(return_value=True)
mock_llm.prepare_structured_output_format = Mock(return_value=None)
agent_base_params["json_schema"] = {"type": "object"}
agent_base_params["llm_name"] = "openai"
agent = ClassicAgent(**agent_base_params)
messages = [{"role": "user", "content": "test"}]
agent._llm_gen(messages)
call_kwargs = mock_llm.gen_stream.call_args[1]
assert "response_format" not in call_kwargs
# ---------------------------------------------------------------------------
# _llm_handler
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestLLMHandlerMethod:
def test_delegates_to_handler(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
mock_llm_handler.process_message_flow = Mock(return_value="result")
agent = ClassicAgent(**agent_base_params)
resp = Mock()
result = agent._llm_handler(resp, {}, [], log_context)
mock_llm_handler.process_message_flow.assert_called_once()
assert result == "result"
assert len(log_context.stacks) == 1
assert log_context.stacks[0]["component"] == "llm_handler"
def test_without_log_context(
self,
agent_base_params,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
):
mock_llm_handler.process_message_flow = Mock(return_value="r")
agent = ClassicAgent(**agent_base_params)
result = agent._llm_handler(Mock(), {}, [], log_context=None)
assert result == "r"
# ---------------------------------------------------------------------------
# _handle_response — structured output on all code paths
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestHandleResponseStructuredAllPaths:
def test_message_response_with_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
"""Structured output on the message.content early-return path."""
mock_llm._supports_structured_output = Mock(return_value=True)
agent_base_params["json_schema"] = {"type": "object"}
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = Mock()
response.message.content = "structured msg"
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["structured"] is True
assert results[0]["schema"] == {"type": "object"}
assert results[0]["answer"] == "structured msg"
def test_handler_string_event_with_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
"""Structured output on string events from the handler."""
mock_llm._supports_structured_output = Mock(return_value=True)
agent_base_params["json_schema"] = {"type": "array"}
def mock_process(*args):
yield "handler string"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["structured"] is True
assert results[0]["schema"] == {"type": "array"}
def test_handler_message_event_with_structured_output(
self,
agent_base_params,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
log_context,
):
"""Structured output on message-object events from the handler."""
mock_llm._supports_structured_output = Mock(return_value=True)
agent_base_params["json_schema"] = {"type": "number"}
event = Mock()
event.message = Mock()
event.message.content = "from handler msg"
def mock_process(*args):
yield event
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
agent = ClassicAgent(**agent_base_params)
response = Mock()
response.message = None
results = list(agent._handle_response(response, {}, [], log_context))
assert results[0]["structured"] is True
assert results[0]["schema"] == {"type": "number"}
assert results[0]["answer"] == "from handler msg"

View File

@@ -0,0 +1,132 @@
"""Tests for application/agents/tools/brave.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.brave import BraveSearchTool
@pytest.fixture
def tool():
return BraveSearchTool(config={"token": "test_api_key"})
@pytest.mark.unit
class TestBraveExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid")
@patch("application.agents.tools.brave.requests.get")
def test_web_search_success(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"web": {"results": [{"title": "Result"}]}}
mock_get.return_value = mock_resp
result = tool.execute_action("brave_web_search", query="python")
assert result["status_code"] == 200
assert "results" in result
assert "successfully" in result["message"]
call_kwargs = mock_get.call_args
assert call_kwargs[1]["headers"]["X-Subscription-Token"] == "test_api_key"
@patch("application.agents.tools.brave.requests.get")
def test_web_search_failure(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 429
mock_get.return_value = mock_resp
result = tool.execute_action("brave_web_search", query="test")
assert result["status_code"] == 429
assert "failed" in result["message"].lower()
@patch("application.agents.tools.brave.requests.get")
def test_image_search_success(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"results": [{"url": "https://img.com/1.jpg"}]}
mock_get.return_value = mock_resp
result = tool.execute_action("brave_image_search", query="cats")
assert result["status_code"] == 200
assert "results" in result
@patch("application.agents.tools.brave.requests.get")
def test_image_search_failure(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 500
mock_get.return_value = mock_resp
result = tool.execute_action("brave_image_search", query="cats")
assert result["status_code"] == 500
@patch("application.agents.tools.brave.requests.get")
def test_count_capped_at_20(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {}
mock_get.return_value = mock_resp
tool.execute_action("brave_web_search", query="test", count=100)
params = mock_get.call_args[1]["params"]
assert params["count"] == 20
@patch("application.agents.tools.brave.requests.get")
def test_image_count_capped_at_100(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {}
mock_get.return_value = mock_resp
tool.execute_action("brave_image_search", query="test", count=500)
params = mock_get.call_args[1]["params"]
assert params["count"] == 100
@patch("application.agents.tools.brave.requests.get")
def test_freshness_param(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {}
mock_get.return_value = mock_resp
tool.execute_action("brave_web_search", query="news", freshness="pd")
params = mock_get.call_args[1]["params"]
assert params["freshness"] == "pd"
@patch("application.agents.tools.brave.requests.get")
def test_offset_capped(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {}
mock_get.return_value = mock_resp
tool.execute_action("brave_web_search", query="test", offset=100)
params = mock_get.call_args[1]["params"]
assert params["offset"] == 9
@pytest.mark.unit
class TestBraveMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 2
names = {a["name"] for a in meta}
assert "brave_web_search" in names
assert "brave_image_search" in names
def test_config_requirements(self, tool):
reqs = tool.get_config_requirements()
assert "token" in reqs
assert reqs["token"]["secret"] is True
assert reqs["token"]["required"] is True

View File

@@ -0,0 +1,169 @@
"""Tests for application/agents/workflows/cel_evaluator.py"""
import pytest
from application.agents.workflows.cel_evaluator import (
CelEvaluationError,
_convert_value,
build_activation,
cel_to_python,
evaluate_cel,
)
import celpy.celtypes
class TestConvertValue:
@pytest.mark.unit
def test_bool_true(self):
result = _convert_value(True)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is True
@pytest.mark.unit
def test_bool_false(self):
result = _convert_value(False)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is False
@pytest.mark.unit
def test_int(self):
result = _convert_value(42)
assert isinstance(result, celpy.celtypes.IntType)
assert int(result) == 42
@pytest.mark.unit
def test_float(self):
result = _convert_value(3.14)
assert isinstance(result, celpy.celtypes.DoubleType)
assert float(result) == pytest.approx(3.14)
@pytest.mark.unit
def test_string(self):
result = _convert_value("hello")
assert isinstance(result, celpy.celtypes.StringType)
assert str(result) == "hello"
@pytest.mark.unit
def test_list(self):
result = _convert_value([1, "two", 3.0])
assert isinstance(result, celpy.celtypes.ListType)
@pytest.mark.unit
def test_dict(self):
result = _convert_value({"key": "value"})
assert isinstance(result, celpy.celtypes.MapType)
@pytest.mark.unit
def test_none(self):
result = _convert_value(None)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is False
@pytest.mark.unit
def test_other_type_converts_to_string(self):
result = _convert_value(object())
assert isinstance(result, celpy.celtypes.StringType)
class TestBuildActivation:
@pytest.mark.unit
def test_converts_dict_values(self):
state = {"name": "Alice", "age": 30, "active": True}
result = build_activation(state)
assert "name" in result
assert "age" in result
assert "active" in result
@pytest.mark.unit
def test_empty_state(self):
assert build_activation({}) == {}
class TestEvaluateCel:
@pytest.mark.unit
def test_simple_comparison(self):
assert evaluate_cel("x > 5", {"x": 10}) is True
assert evaluate_cel("x > 5", {"x": 3}) is False
@pytest.mark.unit
def test_string_comparison(self):
assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False
@pytest.mark.unit
def test_arithmetic(self):
assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7
@pytest.mark.unit
def test_boolean_logic(self):
assert evaluate_cel("a && b", {"a": True, "b": True}) is True
assert evaluate_cel("a && b", {"a": True, "b": False}) is False
assert evaluate_cel("a || b", {"a": False, "b": True}) is True
@pytest.mark.unit
def test_empty_expression_raises(self):
with pytest.raises(CelEvaluationError, match="Empty expression"):
evaluate_cel("", {})
@pytest.mark.unit
def test_whitespace_expression_raises(self):
with pytest.raises(CelEvaluationError, match="Empty expression"):
evaluate_cel(" ", {})
@pytest.mark.unit
def test_invalid_expression_raises(self):
with pytest.raises(CelEvaluationError):
evaluate_cel("invalid!!!", {})
@pytest.mark.unit
def test_missing_variable_raises(self):
with pytest.raises(CelEvaluationError):
evaluate_cel("undefined_var > 5", {})
class TestCelToPython:
@pytest.mark.unit
def test_bool(self):
result = cel_to_python(celpy.celtypes.BoolType(True))
assert result is True
@pytest.mark.unit
def test_int(self):
result = cel_to_python(celpy.celtypes.IntType(42))
assert result == 42
@pytest.mark.unit
def test_double(self):
result = cel_to_python(celpy.celtypes.DoubleType(3.14))
assert result == pytest.approx(3.14)
@pytest.mark.unit
def test_string(self):
result = cel_to_python(celpy.celtypes.StringType("hello"))
assert result == "hello"
@pytest.mark.unit
def test_list(self):
cel_list = celpy.celtypes.ListType([
celpy.celtypes.IntType(1),
celpy.celtypes.IntType(2),
])
result = cel_to_python(cel_list)
assert result == [1, 2]
@pytest.mark.unit
def test_map(self):
cel_map = celpy.celtypes.MapType({
celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
})
result = cel_to_python(cel_map)
assert result == {"key": "value"}
@pytest.mark.unit
def test_unknown_type_passthrough(self):
result = cel_to_python("raw_value")
assert result == "raw_value"

View File

@@ -0,0 +1,85 @@
"""Tests for application/agents/tools/cryptoprice.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.cryptoprice import CryptoPriceTool
@pytest.fixture
def tool():
return CryptoPriceTool(config={})
@pytest.mark.unit
class TestCryptoPriceExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid_action")
@patch("application.agents.tools.cryptoprice.requests.get")
def test_successful_price_fetch(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"USD": 65000}
mock_get.return_value = mock_resp
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
assert result["status_code"] == 200
assert result["price"] == 65000
assert "successfully" in result["message"]
@patch("application.agents.tools.cryptoprice.requests.get")
def test_currency_not_found(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"EUR": 60000}
mock_get.return_value = mock_resp
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
assert result["status_code"] == 200
assert "Couldn't find" in result["message"]
assert "price" not in result
@patch("application.agents.tools.cryptoprice.requests.get")
def test_api_failure(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 500
mock_get.return_value = mock_resp
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
assert result["status_code"] == 500
assert "Failed" in result["message"]
@patch("application.agents.tools.cryptoprice.requests.get")
def test_symbol_case_insensitive(self, mock_get, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"USD": 100}
mock_get.return_value = mock_resp
tool.execute_action("cryptoprice_get", symbol="btc", currency="usd")
called_url = mock_get.call_args[0][0]
assert "fsym=BTC" in called_url
assert "tsyms=USD" in called_url
@pytest.mark.unit
class TestCryptoPriceMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 1
assert meta[0]["name"] == "cryptoprice_get"
params = meta[0]["parameters"]
assert "symbol" in params["properties"]
assert "currency" in params["properties"]
assert "symbol" in params["required"]
assert "currency" in params["required"]
def test_config_requirements(self, tool):
assert tool.get_config_requirements() == {}

View File

@@ -0,0 +1,145 @@
"""Tests for application/agents/tools/duckduckgo.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.duckduckgo import DuckDuckGoSearchTool
@pytest.fixture
def tool():
return DuckDuckGoSearchTool(config={})
@pytest.mark.unit
class TestDuckDuckGoExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid")
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_web_search_success(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.text.return_value = [
{"title": "Result 1", "href": "https://example.com", "body": "snippet"}
]
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_web_search", query="python")
assert result["status_code"] == 200
assert len(result["results"]) == 1
assert "successfully" in result["message"]
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_image_search_success(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.images.return_value = [{"image": "https://img.com/1.jpg"}]
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_image_search", query="cats")
assert result["status_code"] == 200
assert len(result["results"]) == 1
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_news_search_success(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.news.return_value = [{"title": "News"}]
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_news_search", query="tech")
assert result["status_code"] == 200
assert len(result["results"]) == 1
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_search_error_returns_500(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.text.side_effect = Exception("Network error")
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_web_search", query="test")
assert result["status_code"] == 500
assert "failed" in result["message"].lower()
assert result["results"] == []
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_max_results_capped_at_20(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.text.return_value = []
mock_client_factory.return_value = mock_client
tool.execute_action("ddg_web_search", query="test", max_results=100)
call_kwargs = mock_client.text.call_args[1]
assert call_kwargs["max_results"] == 20
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_image_max_results_capped_at_50(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.images.return_value = []
mock_client_factory.return_value = mock_client
tool.execute_action("ddg_image_search", query="test", max_results=200)
call_kwargs = mock_client.images.call_args[1]
assert call_kwargs["max_results"] == 50
@patch("application.agents.tools.duckduckgo.time.sleep")
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_rate_limit_retries(self, mock_client_factory, mock_sleep, tool):
mock_client = MagicMock()
mock_client.text.side_effect = [
Exception("RateLimit exceeded"),
[{"title": "Result"}],
]
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_web_search", query="test")
assert result["status_code"] == 200
assert len(result["results"]) == 1
mock_sleep.assert_called_once()
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_empty_results(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.text.return_value = []
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_web_search", query="obscure query")
assert result["status_code"] == 200
assert result["results"] == []
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
def test_none_results(self, mock_client_factory, tool):
mock_client = MagicMock()
mock_client.text.return_value = None
mock_client_factory.return_value = mock_client
result = tool.execute_action("ddg_web_search", query="test")
assert result["status_code"] == 200
assert result["results"] == []
@pytest.mark.unit
class TestDuckDuckGoMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 3
names = {a["name"] for a in meta}
assert "ddg_web_search" in names
assert "ddg_image_search" in names
assert "ddg_news_search" in names
def test_config_requirements(self, tool):
assert tool.get_config_requirements() == {}
def test_custom_timeout(self):
tool = DuckDuckGoSearchTool(config={"timeout": 30})
assert tool.timeout == 30

View File

@@ -1,6 +1,6 @@
"""Tests for InternalSearchTool and its helper functions."""
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from application.agents.tools.internal_search import (
@@ -248,3 +248,452 @@ class TestBuildHelpers:
tools_dict = {}
add_internal_search_tool(tools_dict, {})
assert INTERNAL_TOOL_ID not in tools_dict
@pytest.mark.unit
class TestInternalSearchToolGetRetriever:
"""Cover line 32: _get_retriever creates retriever lazily."""
def test_get_retriever_creates_retriever(self):
tool = InternalSearchTool({
"source": {},
"retriever_name": "classic",
"chunks": 2,
})
assert tool._retriever is None
mock_retriever = Mock()
with patch(
"application.agents.tools.internal_search.RetrieverCreator"
) as mock_rc:
mock_rc.create_retriever.return_value = mock_retriever
result = tool._get_retriever()
assert result is mock_retriever
assert tool._retriever is mock_retriever
def test_get_retriever_cached(self):
"""Cover line 32: second call returns cached retriever."""
tool = InternalSearchTool({"source": {}, "retriever_name": "classic"})
mock_retriever = Mock()
tool._retriever = mock_retriever
result = tool._get_retriever()
assert result is mock_retriever
@pytest.mark.unit
class TestGetDirectoryStructure:
"""Cover lines 61: _get_directory_structure loads from MongoDB."""
def test_no_active_docs_returns_none(self):
"""Cover line 56-57: no active_docs returns None."""
tool = InternalSearchTool({"source": {}})
result = tool._get_directory_structure()
assert result is None
assert tool._dir_structure_loaded is True
def test_loads_structure_from_mongo(self):
"""Cover line 61+: loads directory structure from MongoDB."""
from bson.objectid import ObjectId
doc_id = str(ObjectId())
tool = InternalSearchTool({
"source": {"active_docs": [doc_id]},
})
mock_source_doc = {
"_id": ObjectId(doc_id),
"name": "test_source",
"directory_structure": {"root": {"file.txt": {"type": "text"}}},
}
with patch(
"application.core.mongo_db.MongoDB"
) as mock_mongo:
mock_db = Mock()
mock_collection = Mock()
mock_collection.find_one.return_value = mock_source_doc
mock_db.__getitem__ = Mock(return_value=mock_collection)
mock_mongo.get_client.return_value = Mock(
__getitem__=Mock(return_value=mock_db)
)
result = tool._get_directory_structure()
assert result == {"root": {"file.txt": {"type": "text"}}}
def test_loads_string_structure_from_mongo(self):
"""Cover line 80-81: directory_structure stored as JSON string."""
from bson.objectid import ObjectId
doc_id = str(ObjectId())
tool = InternalSearchTool({
"source": {"active_docs": [doc_id]},
})
mock_source_doc = {
"_id": ObjectId(doc_id),
"name": "test_source",
"directory_structure": '{"root": {"file.txt": {}}}',
}
with patch(
"application.core.mongo_db.MongoDB"
) as mock_mongo:
mock_db = Mock()
mock_collection = Mock()
mock_collection.find_one.return_value = mock_source_doc
mock_db.__getitem__ = Mock(return_value=mock_collection)
mock_mongo.get_client.return_value = Mock(
__getitem__=Mock(return_value=mock_db)
)
result = tool._get_directory_structure()
assert result == {"root": {"file.txt": {}}}
def test_multiple_active_docs_merged(self):
"""Cover line 83-84: multiple docs merge under source names."""
from bson.objectid import ObjectId
doc_id1 = str(ObjectId())
doc_id2 = str(ObjectId())
tool = InternalSearchTool({
"source": {"active_docs": [doc_id1, doc_id2]},
})
docs = {
doc_id1: {
"_id": ObjectId(doc_id1),
"name": "source1",
"directory_structure": {"file1.txt": {}},
},
doc_id2: {
"_id": ObjectId(doc_id2),
"name": "source2",
"directory_structure": {"file2.txt": {}},
},
}
with patch(
"application.core.mongo_db.MongoDB"
) as mock_mongo:
mock_db = Mock()
mock_collection = Mock()
mock_collection.find_one.side_effect = lambda q: docs.get(
str(q["_id"])
)
mock_db.__getitem__ = Mock(return_value=mock_collection)
mock_mongo.get_client.return_value = Mock(
__getitem__=Mock(return_value=mock_db)
)
result = tool._get_directory_structure()
assert "source1" in result
assert "source2" in result
@pytest.mark.unit
class TestFormatStructureAdditional:
"""Cover lines 186, 193, 200, 221: format structure branches."""
def test_format_structure_non_dict_node(self):
"""Cover line 173: non-dict node returns file message."""
tool = InternalSearchTool({"source": {}})
result = tool._format_structure("a string node", "/path")
assert "is a file" in result
def test_format_structure_file_with_type_metadata(self):
"""Cover lines 186-193: file with type and token_count metadata."""
tool = InternalSearchTool({"source": {}})
node = {
"readme.md": {"type": "markdown", "token_count": 500},
"data.json": {"size_bytes": 1024},
}
result = tool._format_structure(node, "/root")
assert "readme.md" in result
assert "500 tokens" in result
def test_format_structure_empty_directory(self):
"""Cover lines 206-208: empty directory."""
tool = InternalSearchTool({"source": {}})
result = tool._format_structure({}, "/empty")
assert "(empty)" in result
def test_format_structure_plain_file_entry(self):
"""Cover line 198: plain file entry (non-dict value)."""
tool = InternalSearchTool({"source": {}})
node = {"file.txt": "some_value"}
result = tool._format_structure(node, "/root")
assert "file.txt" in result
def test_count_files_nested(self):
"""Cover line 221: _count_files counts nested files."""
tool = InternalSearchTool({"source": {}})
node = {
"sub": {"file1.txt": {"type": "text"}},
"file2.txt": "plain",
}
count = tool._count_files(node)
assert count == 2
@pytest.mark.unit
class TestSourcesHaveDirectoryStructure:
"""Cover line 240, 254, 298: sources_have_directory_structure helper."""
def test_no_active_docs_returns_false(self):
from application.agents.tools.internal_search import (
sources_have_directory_structure,
)
assert sources_have_directory_structure({}) is False
assert sources_have_directory_structure({"active_docs": []}) is False
def test_with_structure_returns_true(self):
from bson.objectid import ObjectId
from application.agents.tools.internal_search import (
sources_have_directory_structure,
)
doc_id = str(ObjectId())
mock_source_doc = {
"_id": ObjectId(doc_id),
"directory_structure": {"root": {}},
}
with patch(
"application.core.mongo_db.MongoDB"
) as mock_mongo:
mock_db = Mock()
mock_collection = Mock()
mock_collection.find_one.return_value = mock_source_doc
mock_db.__getitem__ = Mock(return_value=mock_collection)
mock_mongo.get_client.return_value = Mock(
__getitem__=Mock(return_value=mock_db)
)
result = sources_have_directory_structure({"active_docs": [doc_id]})
assert result is True
def test_string_active_docs_converted_to_list(self):
"""Cover line 298: active_docs as string is converted to list."""
from bson.objectid import ObjectId
from application.agents.tools.internal_search import (
sources_have_directory_structure,
)
doc_id = str(ObjectId())
mock_source_doc = {
"_id": ObjectId(doc_id),
"directory_structure": {"root": {}},
}
with patch(
"application.core.mongo_db.MongoDB"
) as mock_mongo:
mock_db = Mock()
mock_collection = Mock()
mock_collection.find_one.return_value = mock_source_doc
mock_db.__getitem__ = Mock(return_value=mock_collection)
mock_mongo.get_client.return_value = Mock(
__getitem__=Mock(return_value=mock_db)
)
result = sources_have_directory_structure({"active_docs": doc_id})
assert result is True
def test_get_config_requirements(self):
"""Cover line 280: get_config_requirements."""
tool = InternalSearchTool({"source": {}})
assert tool.get_config_requirements() == {}
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 77, 135, 186, 200, 221, 240, 254, 298
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestInternalSearchToolAdditionalCoverage:
def test_get_directory_structure_returns_cached(self):
"""Cover line 77: source_doc not found in DB returns None."""
tool = InternalSearchTool({"source": {"active_docs": ["nonexistent"]}})
tool._dir_structure_loaded = True
tool._directory_structure = {"cached": True}
result = tool._get_directory_structure()
assert result == {"cached": True}
def test_execute_search_appends_to_retrieved_docs(self):
"""Cover line 135: doc appended to retrieved_docs."""
tool = InternalSearchTool({"source": {}})
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"title": "Doc1", "text": "content", "source": "src"},
]
tool._retriever = mock_retriever
tool._execute_search(query="test")
assert len(tool.retrieved_docs) == 1
def test_format_structure_file_metadata(self):
"""Cover line 186: file with metadata (type, token_count)."""
tool = InternalSearchTool({"source": {}})
node = {
"readme.md": {"type": "markdown", "token_count": 100},
"subfolder": {"nested_file.py": {}},
}
result = tool._format_structure(node, "/")
assert "readme.md" in result
assert "markdown" in result
assert "100 tokens" in result
def test_format_structure_folders_and_files(self):
"""Cover line 200: folders and files sections in output."""
tool = InternalSearchTool({"source": {}})
node = {
"src": {"main.py": {}},
"README.md": "file",
}
result = tool._format_structure(node, "/")
assert "Folders:" in result
assert "Files:" in result
def test_count_files_recursive(self):
"""Cover line 221: _count_files counts nested files."""
tool = InternalSearchTool({"source": {}})
node = {
"a.py": "file",
"subdir": {
"b.py": {"type": "python", "token_count": 50},
},
}
count = tool._count_files(node)
assert count == 2
def test_get_actions_metadata_with_directory_structure(self):
"""Cover line 240+: actions include path_filter and list_files."""
tool = InternalSearchTool({"source": {}, "has_directory_structure": True})
actions = tool.get_actions_metadata()
action_names = [a["name"] for a in actions]
assert "search" in action_names
assert "list_files" in action_names
# Check path_filter is in search params
search_action = next(a for a in actions if a["name"] == "search")
assert "path_filter" in search_action["parameters"]["properties"]
def test_get_actions_metadata_without_directory_structure(self):
"""Cover line 254: actions without directory structure."""
tool = InternalSearchTool({"source": {}, "has_directory_structure": False})
actions = tool.get_actions_metadata()
action_names = [a["name"] for a in actions]
assert "search" in action_names
assert "list_files" not in action_names
def test_build_internal_tool_entry_with_directory_structure(self):
"""Cover line 298: build_internal_tool_entry with has_directory_structure."""
entry = build_internal_tool_entry(has_directory_structure=True)
action_names = [a["name"] for a in entry["actions"]]
assert "list_files" in action_names
search_action = next(a for a in entry["actions"] if a["name"] == "search")
assert "path_filter" in search_action["parameters"]["properties"]
# ---------------------------------------------------------------------------
# Additional coverage for internal_search.py
# Lines: 101 (unknown action), 108 (empty query), 114-115 (search exception),
# 117-118 (no docs), 130-131 (path filter no match),
# 154-155 (no dir structure), 165-166 (path not found)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestInternalSearchUnknownAction:
"""Cover line 101: unknown action returns error string."""
def test_unknown_action(self):
tool = InternalSearchTool({"source": {}})
result = tool.execute_action("unknown_action")
assert "Unknown action" in result
@pytest.mark.unit
class TestInternalSearchEmptyQuery:
"""Cover line 108: empty query returns error."""
def test_empty_query(self):
tool = InternalSearchTool({"source": {}})
result = tool.execute_action("search", query="")
assert "required" in result.lower()
@pytest.mark.unit
class TestInternalSearchException:
"""Cover lines 114-115: search exception returns error."""
def test_search_raises(self):
tool = InternalSearchTool({"source": {}})
mock_retriever = MagicMock()
mock_retriever.search.side_effect = RuntimeError("DB down")
tool._get_retriever = MagicMock(return_value=mock_retriever)
result = tool.execute_action("search", query="hello")
assert "internal error" in result.lower()
@pytest.mark.unit
class TestInternalSearchNoDocs:
"""Cover lines 117-118: no docs found."""
def test_no_docs(self):
tool = InternalSearchTool({"source": {}})
mock_retriever = MagicMock()
mock_retriever.search.return_value = []
tool._get_retriever = MagicMock(return_value=mock_retriever)
result = tool.execute_action("search", query="hello")
assert "No documents found" in result
@pytest.mark.unit
class TestInternalSearchPathFilterNoMatch:
"""Cover lines 130-131: path filter with no matching docs."""
def test_path_filter_no_match(self):
tool = InternalSearchTool({"source": {}})
mock_retriever = MagicMock()
mock_retriever.search.return_value = [
{"source": "other.txt", "text": "data", "title": "Other"}
]
tool._get_retriever = MagicMock(return_value=mock_retriever)
result = tool.execute_action(
"search", query="hello", path_filter="nonexistent"
)
assert "No documents found" in result
assert "nonexistent" in result
@pytest.mark.unit
class TestInternalSearchListFilesNoDirStructure:
"""Cover lines 154-155: no directory structure."""
def test_no_dir_structure(self):
tool = InternalSearchTool({"source": {}})
tool._get_directory_structure = MagicMock(return_value=None)
result = tool.execute_action("list_files")
assert "No file structure" in result
@pytest.mark.unit
class TestInternalSearchListFilesPathNotFound:
"""Cover lines 165-166: path not found."""
def test_path_not_found(self):
tool = InternalSearchTool({"source": {}})
tool._get_directory_structure = MagicMock(
return_value={"folder": {"file.txt": {}}}
)
result = tool.execute_action("list_files", path="missing_dir")
assert "not found" in result.lower()

View File

@@ -0,0 +1,519 @@
"""Tests for application/agents/tools/mcp_tool.py"""
from unittest.mock import MagicMock, patch
import pytest
# mcp_tool has a circular import at module level (mcp_tool -> tasks -> user -> mcp.py -> mcp_tool).
# We must patch the dependencies BEFORE the module is first imported.
@pytest.fixture(autouse=True)
def _patch_mcp_globals(monkeypatch):
"""Patch module-level MongoDB and cache to avoid real connections."""
import sys
# If the module is already loaded, just patch attributes directly
if "application.agents.tools.mcp_tool" in sys.modules:
mcp_mod = sys.modules["application.agents.tools.mcp_tool"]
else:
# Break the circular import by pre-populating the tasks import
# with a mock before mcp_tool tries to import it
mock_tasks = MagicMock()
monkeypatch.setitem(sys.modules, "application.api.user.tasks", mock_tasks)
import application.agents.tools.mcp_tool as mcp_mod
mock_mongo = MagicMock()
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=MagicMock())
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
monkeypatch.setattr(mcp_mod, "db", mock_db)
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
@pytest.fixture
def mcp_config():
return {
"server_url": "https://mcp.example.com/api",
"transport_type": "http",
"auth_type": "none",
"timeout": 10,
}
@pytest.fixture
def bearer_config():
return {
"server_url": "https://mcp.example.com/api",
"transport_type": "http",
"auth_type": "bearer",
"auth_credentials": {"bearer_token": "tok_123"},
"timeout": 10,
}
@pytest.mark.unit
class TestMCPToolInit:
def test_basic_init(self, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
with patch.object(MCPTool, "_setup_client"):
tool = MCPTool(mcp_config)
assert tool.server_url == "https://mcp.example.com/api"
assert tool.transport_type == "http"
assert tool.auth_type == "none"
assert tool.timeout == 10
def test_bearer_auth_credentials(self, bearer_config):
from application.agents.tools.mcp_tool import MCPTool
with patch.object(MCPTool, "_setup_client"):
tool = MCPTool(bearer_config)
assert tool.auth_credentials["bearer_token"] == "tok_123"
def test_no_server_url_skips_setup(self):
from application.agents.tools.mcp_tool import MCPTool
with patch.object(MCPTool, "_setup_client") as mock_setup:
MCPTool({"server_url": "", "auth_type": "none"})
mock_setup.assert_not_called()
def test_oauth_skips_setup(self):
from application.agents.tools.mcp_tool import MCPTool
with patch.object(MCPTool, "_setup_client") as mock_setup:
MCPTool(
{
"server_url": "https://mcp.example.com",
"auth_type": "oauth",
}
)
mock_setup.assert_not_called()
@pytest.mark.unit
class TestGenerateCacheKey:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_none_auth(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
assert "none" in tool._cache_key
assert "mcp.example.com" in tool._cache_key
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_bearer_auth(self, mock_setup, bearer_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(bearer_config)
assert "bearer:" in tool._cache_key
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_api_key_auth(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com",
"auth_type": "api_key",
"auth_credentials": {"api_key": "sk-test12345678"},
}
)
assert "apikey:" in tool._cache_key
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_basic_auth(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com",
"auth_type": "basic",
"auth_credentials": {"username": "user1", "password": "pass"},
}
)
assert "basic:user1" in tool._cache_key
@pytest.mark.unit
class TestCreateTransport:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_http_transport(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
transport = tool._create_transport()
assert transport is not None
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_sse_transport(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com/sse",
"transport_type": "sse",
"auth_type": "none",
}
)
transport = tool._create_transport()
assert transport is not None
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_auto_detects_sse(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com/sse",
"transport_type": "auto",
"auth_type": "none",
}
)
transport = tool._create_transport()
assert transport is not None
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_stdio_transport_disabled(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com",
"transport_type": "stdio",
"auth_type": "none",
}
)
with pytest.raises(ValueError, match="STDIO transport is disabled"):
tool._create_transport()
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_api_key_header_injected(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com",
"transport_type": "http",
"auth_type": "api_key",
"auth_credentials": {
"api_key": "sk-test",
"api_key_header": "X-Custom-Key",
},
}
)
# _create_transport will be called; verify it doesn't raise
transport = tool._create_transport()
assert transport is not None
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_basic_auth_header_injected(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{
"server_url": "https://mcp.example.com",
"transport_type": "http",
"auth_type": "basic",
"auth_credentials": {"username": "user", "password": "pass"},
}
)
transport = tool._create_transport()
assert transport is not None
@pytest.mark.unit
class TestFormatTools:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_list_of_dicts(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
result = tool._format_tools([{"name": "tool1", "description": "desc"}])
assert len(result) == 1
assert result[0]["name"] == "tool1"
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_tools_with_name_attribute(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
mock_tool = MagicMock()
mock_tool.name = "my_tool"
mock_tool.description = "A tool"
mock_tool.inputSchema = {"type": "object", "properties": {}}
result = tool._format_tools([mock_tool])
assert len(result) == 1
assert result[0]["name"] == "my_tool"
assert result[0]["inputSchema"] == {"type": "object", "properties": {}}
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_tools_response_object(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
resp = MagicMock()
resp.tools = [{"name": "t1", "description": "d1"}]
result = tool._format_tools(resp)
assert len(result) == 1
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_empty(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
assert tool._format_tools([]) == []
assert tool._format_tools("unexpected") == []
@pytest.mark.unit
class TestFormatResult:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_result_with_content(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
mock_result = MagicMock()
text_item = MagicMock()
text_item.text = "Hello"
del text_item.data
mock_result.content = [text_item]
mock_result.isError = False
result = tool._format_result(mock_result)
assert result["content"][0]["type"] == "text"
assert result["content"][0]["text"] == "Hello"
assert result["isError"] is False
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_format_raw_result(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
raw = {"key": "value"}
assert tool._format_result(raw) == raw
@pytest.mark.unit
class TestExecuteAction:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_no_server_raises(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool({"server_url": "", "auth_type": "none"})
with pytest.raises(Exception, match="No MCP server configured"):
tool.execute_action("test_action")
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
def test_successful_execute(self, mock_run, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
tool._client = MagicMock()
mock_run.return_value = {"key": "value"}
result = tool.execute_action("test_action", param1="val1")
mock_run.assert_called_once_with("call_tool", "test_action", param1="val1")
assert result == {"key": "value"}
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
def test_empty_kwargs_cleaned(self, mock_run, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
tool._client = MagicMock()
mock_run.return_value = {}
tool.execute_action("test", param1="", param2=None, param3="real")
call_kwargs = mock_run.call_args[1]
assert "param1" not in call_kwargs
assert "param2" not in call_kwargs
assert call_kwargs["param3"] == "real"
@pytest.mark.unit
class TestTestConnection:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_no_server_url(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool({"server_url": "", "auth_type": "none"})
result = tool.test_connection()
assert result["success"] is False
assert "No server URL" in result["message"]
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_invalid_scheme(self, mock_setup):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(
{"server_url": "ftp://bad.com", "auth_type": "none"}
)
result = tool.test_connection()
assert result["success"] is False
assert "Invalid URL scheme" in result["message"]
@pytest.mark.unit
class TestMapError:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_timeout_error(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
import concurrent.futures
tool = MCPTool(mcp_config)
err = tool._map_error("test", concurrent.futures.TimeoutError())
assert "Timed out" in str(err)
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_connection_refused(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
err = tool._map_error("test", ConnectionRefusedError())
assert "Connection refused" in str(err)
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_403_forbidden(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
err = tool._map_error("test", Exception("403 Forbidden"))
assert "Access denied" in str(err)
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_ssl_error(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
err = tool._map_error("test", Exception("SSL certificate verify failed"))
assert "SSL" in str(err)
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_unknown_error_passthrough(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
original = RuntimeError("something weird")
err = tool._map_error("test", original)
assert err is original
@pytest.mark.unit
class TestGetActionsMetadata:
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_empty_tools(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
tool.available_tools = []
assert tool.get_actions_metadata() == []
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_tools_with_input_schema(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
tool.available_tools = [
{
"name": "search",
"description": "Search things",
"inputSchema": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
}
]
meta = tool.get_actions_metadata()
assert len(meta) == 1
assert meta[0]["name"] == "search"
assert "query" in meta[0]["parameters"]["properties"]
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_tools_without_schema(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
tool.available_tools = [{"name": "ping", "description": "Ping"}]
meta = tool.get_actions_metadata()
assert meta[0]["parameters"]["properties"] == {}
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
def test_config_requirements(self, mock_setup, mcp_config):
from application.agents.tools.mcp_tool import MCPTool
tool = MCPTool(mcp_config)
reqs = tool.get_config_requirements()
assert "server_url" in reqs
assert "auth_type" in reqs
assert reqs["server_url"]["required"] is True
@pytest.mark.unit
class TestMCPOAuthManager:
def test_handle_callback_success(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
mock_redis = MagicMock()
manager = MCPOAuthManager(mock_redis)
result = manager.handle_oauth_callback(state="abc123", code="auth_code")
assert result is True
mock_redis.setex.assert_called()
def test_handle_callback_no_redis(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
manager = MCPOAuthManager(None)
result = manager.handle_oauth_callback(state="abc", code="code")
assert result is False
def test_handle_callback_error(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
mock_redis = MagicMock()
manager = MCPOAuthManager(mock_redis)
result = manager.handle_oauth_callback(
state="abc", code="", error="access_denied"
)
assert result is False
def test_get_oauth_status_no_task(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("")
assert result["status"] == "not_started"
@pytest.mark.unit
class TestDBTokenStorage:
def test_get_base_url(self):
from application.agents.tools.mcp_tool import DBTokenStorage
assert (
DBTokenStorage.get_base_url("https://mcp.example.com/api/v1")
== "https://mcp.example.com"
)
def test_get_db_key(self):
from application.agents.tools.mcp_tool import DBTokenStorage
mock_db = MagicMock()
storage = DBTokenStorage(
server_url="https://mcp.example.com/api",
user_id="user1",
db_client=mock_db,
)
key = storage.get_db_key()
assert key["server_url"] == "https://mcp.example.com"
assert key["user_id"] == "user1"

View File

@@ -0,0 +1,140 @@
import pytest
@pytest.mark.unit
class TestToolFilterMixin:
def test_get_user_tools_filters_by_allowed_ids(self):
from application.agents.workflows.node_agent import ToolFilterMixin
class FakeBase:
def _get_user_tools(self, user="local"):
return {
"t1": {"_id": "id1", "name": "tool1"},
"t2": {"_id": "id2", "name": "tool2"},
"t3": {"_id": "id3", "name": "tool3"},
}
class TestClass(ToolFilterMixin, FakeBase):
pass
obj = TestClass()
obj._allowed_tool_ids = ["id1", "id3"]
result = obj._get_user_tools("user1")
assert "t1" in result
assert "t3" in result
assert "t2" not in result
def test_get_user_tools_returns_empty_when_no_allowed(self):
from application.agents.workflows.node_agent import ToolFilterMixin
class FakeBase:
def _get_user_tools(self, user="local"):
return {"t1": {"_id": "id1"}}
class TestClass(ToolFilterMixin, FakeBase):
pass
obj = TestClass()
obj._allowed_tool_ids = []
result = obj._get_user_tools()
assert result == {}
def test_get_tools_filters_by_allowed_ids(self):
from application.agents.workflows.node_agent import ToolFilterMixin
class FakeBase:
def _get_tools(self, api_key=None):
return {
"t1": {"_id": "id1"},
"t2": {"_id": "id2"},
}
class TestClass(ToolFilterMixin, FakeBase):
pass
obj = TestClass()
obj._allowed_tool_ids = ["id2"]
result = obj._get_tools("key")
assert "t2" in result
assert "t1" not in result
def test_get_tools_returns_empty_when_no_allowed(self):
from application.agents.workflows.node_agent import ToolFilterMixin
class FakeBase:
def _get_tools(self, api_key=None):
return {"t1": {"_id": "id1"}}
class TestClass(ToolFilterMixin, FakeBase):
pass
obj = TestClass()
obj._allowed_tool_ids = []
result = obj._get_tools()
assert result == {}
@pytest.mark.unit
class TestWorkflowNodeAgentFactory:
def test_raises_on_unsupported_type(self):
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
with pytest.raises(ValueError, match="Unsupported agent type"):
WorkflowNodeAgentFactory.create(
agent_type="nonexistent",
endpoint="http://example.com",
llm_name="openai",
model_id="gpt-4",
api_key="key",
)
# =====================================================================
# Coverage gap tests (lines 52-59: _WorkflowNodeMixin.__init__)
# =====================================================================
@pytest.mark.unit
class TestWorkflowNodeMixinInit:
def test_mixin_init_sets_allowed_tool_ids(self):
"""Cover lines 52-59: _WorkflowNodeMixin.__init__ stores tool_ids."""
from application.agents.workflows.node_agent import _WorkflowNodeMixin
class FakeBase:
def __init__(self, *args, **kwargs):
pass
class TestMixin(_WorkflowNodeMixin, FakeBase):
pass
obj = TestMixin(
endpoint="http://example.com",
llm_name="openai",
model_id="gpt-4",
api_key="key",
tool_ids=["tool1", "tool2"],
)
assert obj._allowed_tool_ids == ["tool1", "tool2"]
def test_mixin_init_defaults_empty_tool_ids(self):
"""Cover: _WorkflowNodeMixin defaults to empty list."""
from application.agents.workflows.node_agent import _WorkflowNodeMixin
class FakeBase:
def __init__(self, *args, **kwargs):
pass
class TestMixin(_WorkflowNodeMixin, FakeBase):
pass
obj = TestMixin(
endpoint="http://example.com",
llm_name="openai",
model_id="gpt-4",
api_key="key",
)
assert obj._allowed_tool_ids == []

View File

@@ -0,0 +1,146 @@
"""Tests for application/agents/tools/ntfy.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.ntfy import NtfyTool
@pytest.fixture
def tool():
return NtfyTool(config={"token": "test_token"})
@pytest.fixture
def tool_no_token():
return NtfyTool(config={})
@pytest.mark.unit
class TestNtfyExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("bad_action")
@patch("application.agents.tools.ntfy.requests.post")
def test_send_message_basic(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
result = tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Hello",
topic="test",
)
assert result["status_code"] == 200
assert result["message"] == "Message sent"
mock_post.assert_called_once()
call_args = mock_post.call_args
assert call_args[0][0] == "https://ntfy.sh/test"
assert call_args[1]["data"] == b"Hello"
@patch("application.agents.tools.ntfy.requests.post")
def test_send_with_title_and_priority(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Alert",
topic="urgent",
title="Warning",
priority=5,
)
headers = mock_post.call_args[1]["headers"]
assert headers["X-Title"] == "Warning"
assert headers["X-Priority"] == "5"
@patch("application.agents.tools.ntfy.requests.post")
def test_auth_header_with_token(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Hi",
topic="t",
)
headers = mock_post.call_args[1]["headers"]
assert headers["Authorization"] == "Basic test_token"
@patch("application.agents.tools.ntfy.requests.post")
def test_no_auth_without_token(self, mock_post, tool_no_token):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
tool_no_token.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Hi",
topic="t",
)
headers = mock_post.call_args[1]["headers"]
assert "Authorization" not in headers
def test_invalid_priority_raises(self, tool):
with pytest.raises(ValueError, match="between 1 and 5"):
tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Hi",
topic="t",
priority=10,
)
def test_non_numeric_priority_raises(self, tool):
with pytest.raises(ValueError, match="convertible to an integer"):
tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh",
message="Hi",
topic="t",
priority="abc",
)
@patch("application.agents.tools.ntfy.requests.post")
def test_trailing_slash_stripped(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
tool.execute_action(
"ntfy_send_message",
server_url="https://ntfy.sh/",
message="Hi",
topic="test",
)
assert mock_post.call_args[0][0] == "https://ntfy.sh/test"
@pytest.mark.unit
class TestNtfyMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 1
assert meta[0]["name"] == "ntfy_send_message"
assert "server_url" in meta[0]["parameters"]["properties"]
assert "message" in meta[0]["parameters"]["properties"]
assert "topic" in meta[0]["parameters"]["properties"]
def test_config_requirements(self, tool):
reqs = tool.get_config_requirements()
assert "token" in reqs
assert reqs["token"]["secret"] is True

View File

@@ -0,0 +1,146 @@
"""Tests for application/agents/tools/postgres.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.postgres import PostgresTool
@pytest.fixture
def tool():
return PostgresTool(config={"token": "postgresql://user:pass@localhost/testdb"})
@pytest.mark.unit
class TestPostgresExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid")
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_select_query(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.description = [("id",), ("name",)]
mock_cur.fetchall.return_value = [(1, "Alice"), (2, "Bob")]
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
result = tool.execute_action(
"postgres_execute_sql", sql_query="SELECT id, name FROM users"
)
assert result["status_code"] == 200
assert result["response_data"]["column_names"] == ["id", "name"]
assert len(result["response_data"]["data"]) == 2
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_insert_query(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.rowcount = 1
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
result = tool.execute_action(
"postgres_execute_sql",
sql_query="INSERT INTO users (name) VALUES ('Alice')",
)
assert result["status_code"] == 200
assert "1 rows affected" in result["response_data"]["message"]
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_db_error(self, mock_connect, tool):
import psycopg2
mock_connect.side_effect = psycopg2.Error("connection refused")
result = tool.execute_action(
"postgres_execute_sql", sql_query="SELECT 1"
)
assert result["status_code"] == 500
assert "Database error" in result["error"]
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_get_schema(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = [
("users", "id", "integer", "nextval(...)", "NO"),
("users", "name", "varchar", None, "YES"),
("posts", "id", "integer", "nextval(...)", "NO"),
]
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
result = tool.execute_action("postgres_get_schema", db_name="testdb")
assert result["status_code"] == 200
assert "users" in result["schema"]
assert "posts" in result["schema"]
assert len(result["schema"]["users"]) == 2
assert result["schema"]["users"][0]["column_name"] == "id"
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_get_schema_db_error(self, mock_connect, tool):
import psycopg2
mock_connect.side_effect = psycopg2.Error("auth failed")
result = tool.execute_action("postgres_get_schema", db_name="testdb")
assert result["status_code"] == 500
assert "Database error" in result["error"]
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_connection_closed_on_error(self, mock_connect, tool):
import psycopg2
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
tool.execute_action("postgres_execute_sql", sql_query="BAD SQL")
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
def test_select_with_no_description(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.description = None
mock_cur.fetchall.return_value = []
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
result = tool.execute_action(
"postgres_execute_sql", sql_query="SELECT 1 WHERE false"
)
assert result["status_code"] == 200
assert result["response_data"]["column_names"] == []
@pytest.mark.unit
class TestPostgresMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 2
names = {a["name"] for a in meta}
assert "postgres_execute_sql" in names
assert "postgres_get_schema" in names
def test_config_requirements(self, tool):
reqs = tool.get_config_requirements()
assert "token" in reqs
assert reqs["token"]["secret"] is True

View File

@@ -0,0 +1,86 @@
"""Tests for application/agents/tools/read_webpage.py"""
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.agents.tools.read_webpage import ReadWebpageTool
@pytest.fixture
def tool():
return ReadWebpageTool()
@pytest.mark.unit
class TestReadWebpageExecuteAction:
def test_unknown_action(self, tool):
result = tool.execute_action("unknown_action")
assert "Error" in result
assert "Unknown action" in result
def test_missing_url(self, tool):
result = tool.execute_action("read_webpage")
assert "Error" in result
assert "URL parameter is missing" in result
@patch("application.agents.tools.read_webpage.validate_url")
@patch("application.agents.tools.read_webpage.requests.get")
def test_successful_fetch(self, mock_get, mock_validate, tool):
mock_validate.return_value = "https://example.com"
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.text = "<html><body><h1>Title</h1><p>Content</p></body></html>"
mock_get.return_value = mock_resp
result = tool.execute_action("read_webpage", url="https://example.com")
assert "Title" in result
assert "Content" in result
@patch("application.agents.tools.read_webpage.validate_url")
@patch("application.agents.tools.read_webpage.requests.get")
def test_request_error(self, mock_get, mock_validate, tool):
mock_validate.return_value = "https://example.com"
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
result = tool.execute_action("read_webpage", url="https://example.com")
assert "Error fetching URL" in result
@patch("application.agents.tools.read_webpage.validate_url")
def test_ssrf_blocked(self, mock_validate, tool):
from application.core.url_validation import SSRFError
mock_validate.side_effect = SSRFError("blocked")
result = tool.execute_action("read_webpage", url="http://169.254.169.254/")
assert "Error" in result
assert "validation failed" in result
@patch("application.agents.tools.read_webpage.validate_url")
@patch("application.agents.tools.read_webpage.requests.get")
def test_http_error(self, mock_get, mock_validate, tool):
mock_validate.return_value = "https://example.com/404"
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError("404")
mock_get.return_value = mock_resp
result = tool.execute_action("read_webpage", url="https://example.com/404")
assert "Error fetching URL" in result
@pytest.mark.unit
class TestReadWebpageMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 1
assert meta[0]["name"] == "read_webpage"
assert "url" in meta[0]["parameters"]["properties"]
assert "url" in meta[0]["parameters"]["required"]
def test_config_requirements(self, tool):
assert tool.get_config_requirements() == {}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,692 @@
"""Tests for application/agents/tools/spec_parser.py"""
import json
import pytest
from application.agents.tools.spec_parser import (
_extract_metadata,
_generate_action_name,
_get_base_url,
_load_spec,
_param_to_property,
_resolve_ref,
_validate_spec,
parse_spec,
)
MINIMAL_OPENAPI = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"servers": [{"url": "https://api.example.com"}],
"paths": {
"/users": {
"get": {
"operationId": "listUsers",
"summary": "List all users",
"parameters": [
{
"name": "limit",
"in": "query",
"schema": {"type": "integer"},
}
],
"responses": {"200": {"description": "OK"}},
}
}
},
}
)
MINIMAL_SWAGGER = json.dumps(
{
"swagger": "2.0",
"info": {"title": "Swagger API", "version": "2.0.0"},
"host": "api.example.com",
"basePath": "/v1",
"schemes": ["https"],
"paths": {
"/pets": {
"get": {
"operationId": "listPets",
"summary": "List pets",
"parameters": [],
"responses": {"200": {"description": "OK"}},
}
}
},
}
)
@pytest.mark.unit
class TestLoadSpec:
def test_load_json(self):
spec = _load_spec('{"openapi": "3.0.0"}')
assert spec["openapi"] == "3.0.0"
def test_load_yaml(self):
spec = _load_spec("openapi: '3.0.0'\ninfo:\n title: Test")
assert spec["openapi"] == "3.0.0"
def test_empty_raises(self):
with pytest.raises(ValueError, match="Empty"):
_load_spec("")
def test_whitespace_only_raises(self):
with pytest.raises(ValueError, match="Empty"):
_load_spec(" \n ")
def test_invalid_json_raises(self):
with pytest.raises(ValueError, match="Invalid"):
_load_spec("{invalid json}")
@pytest.mark.unit
class TestValidateSpec:
def test_valid_openapi(self):
spec = json.loads(MINIMAL_OPENAPI)
_validate_spec(spec) # should not raise
def test_valid_swagger(self):
spec = json.loads(MINIMAL_SWAGGER)
_validate_spec(spec)
def test_missing_version_raises(self):
with pytest.raises(ValueError, match="Unsupported"):
_validate_spec({"paths": {"/a": {}}})
def test_no_paths_raises(self):
with pytest.raises(ValueError, match="No API paths"):
_validate_spec({"openapi": "3.0.0"})
def test_empty_paths_raises(self):
with pytest.raises(ValueError, match="No API paths"):
_validate_spec({"openapi": "3.0.0", "paths": {}})
def test_non_dict_raises(self):
with pytest.raises(ValueError, match="valid object"):
_validate_spec("not a dict")
@pytest.mark.unit
class TestExtractMetadata:
def test_openapi_metadata(self):
spec = json.loads(MINIMAL_OPENAPI)
meta = _extract_metadata(spec, is_swagger=False)
assert meta["title"] == "Test API"
assert meta["version"] == "1.0.0"
assert meta["base_url"] == "https://api.example.com"
def test_swagger_metadata(self):
spec = json.loads(MINIMAL_SWAGGER)
meta = _extract_metadata(spec, is_swagger=True)
assert meta["title"] == "Swagger API"
assert meta["base_url"] == "https://api.example.com/v1"
def test_missing_info(self):
spec = {"openapi": "3.0.0", "paths": {"/a": {}}}
meta = _extract_metadata(spec, is_swagger=False)
assert meta["title"] == "Untitled API"
def test_description_truncated(self):
spec = {
"openapi": "3.0.0",
"info": {"title": "T", "description": "x" * 1000},
"paths": {"/a": {}},
}
meta = _extract_metadata(spec, is_swagger=False)
assert len(meta["description"]) <= 500
@pytest.mark.unit
class TestGetBaseUrl:
def test_openapi_servers(self):
spec = {"servers": [{"url": "https://api.example.com/v2/"}]}
assert _get_base_url(spec, is_swagger=False) == "https://api.example.com/v2"
def test_openapi_no_servers(self):
assert _get_base_url({}, is_swagger=False) == ""
def test_swagger_with_host(self):
spec = {"host": "api.test.com", "basePath": "/v1", "schemes": ["https"]}
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com/v1"
def test_swagger_no_host(self):
assert _get_base_url({}, is_swagger=True) == ""
def test_swagger_default_scheme(self):
spec = {"host": "api.test.com", "basePath": ""}
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com"
@pytest.mark.unit
class TestGenerateActionName:
def test_uses_operation_id(self):
assert _generate_action_name({"operationId": "getUser"}, "get", "/users") == "getUser"
def test_fallback_to_method_path(self):
name = _generate_action_name({}, "get", "/users/{id}")
assert name.startswith("get_")
assert "users" in name
def test_truncates_long_names(self):
name = _generate_action_name({}, "get", "/" + "a" * 200)
assert len(name) <= 64
def test_sanitizes_special_chars(self):
name = _generate_action_name({"operationId": "get.user@v2"}, "get", "/")
assert "." not in name
assert "@" not in name
@pytest.mark.unit
class TestParamToProperty:
def test_string_param(self):
param = {"name": "q", "in": "query", "schema": {"type": "string"}}
prop = _param_to_property(param)
assert prop["type"] == "string"
assert prop["required"] is False
def test_integer_param(self):
param = {
"name": "limit",
"in": "query",
"required": True,
"schema": {"type": "integer"},
}
prop = _param_to_property(param)
assert prop["type"] == "integer"
assert prop["required"] is True
assert prop["filled_by_llm"] is True
def test_number_maps_to_integer(self):
param = {"name": "score", "in": "query", "schema": {"type": "number"}}
prop = _param_to_property(param)
assert prop["type"] == "integer"
@pytest.mark.unit
class TestResolveRef:
def test_no_ref(self):
obj = {"type": "string"}
assert _resolve_ref(obj, {}, {}) == obj
def test_components_ref(self):
components = {"schemas": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}}
obj = {"$ref": "#/components/schemas/User"}
result = _resolve_ref(obj, components, {})
assert result["type"] == "object"
def test_definitions_ref(self):
definitions = {"Pet": {"type": "object"}}
obj = {"$ref": "#/definitions/Pet"}
result = _resolve_ref(obj, {}, definitions)
assert result["type"] == "object"
def test_unsupported_ref(self):
obj = {"$ref": "#/external/something"}
assert _resolve_ref(obj, {}, {}) is None
def test_non_dict_returns_none(self):
assert _resolve_ref("string", {}, {}) is None
assert _resolve_ref(42, {}, {}) is None
@pytest.mark.unit
class TestParseSpec:
def test_openapi_full_parse(self):
metadata, actions = parse_spec(MINIMAL_OPENAPI)
assert metadata["title"] == "Test API"
assert len(actions) == 1
assert actions[0]["name"] == "listUsers"
assert actions[0]["method"] == "GET"
assert actions[0]["url"] == "https://api.example.com/users"
assert "limit" in actions[0]["query_params"]["properties"]
def test_swagger_full_parse(self):
metadata, actions = parse_spec(MINIMAL_SWAGGER)
assert metadata["title"] == "Swagger API"
assert len(actions) == 1
assert actions[0]["name"] == "listPets"
assert actions[0]["method"] == "GET"
def test_multiple_methods(self):
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/items": {
"get": {"operationId": "listItems", "responses": {}},
"post": {
"operationId": "createItem",
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"],
}
}
}
},
"responses": {},
},
}
},
}
)
metadata, actions = parse_spec(spec)
assert len(actions) == 2
names = {a["name"] for a in actions}
assert "listItems" in names
assert "createItem" in names
create = next(a for a in actions if a["name"] == "createItem")
assert "name" in create["body"]["properties"]
def test_header_params(self):
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/data": {
"get": {
"operationId": "getData",
"parameters": [
{"name": "X-API-Key", "in": "header", "schema": {"type": "string"}}
],
"responses": {},
}
}
},
}
)
_, actions = parse_spec(spec)
assert "X-API-Key" in actions[0]["headers"]["properties"]
def test_invalid_spec_raises(self):
with pytest.raises(ValueError):
parse_spec("")
def test_yaml_spec(self):
yaml_spec = """
openapi: "3.0.0"
info:
title: YAML API
version: "1.0"
paths:
/health:
get:
operationId: healthCheck
responses:
"200":
description: OK
"""
metadata, actions = parse_spec(yaml_spec)
assert metadata["title"] == "YAML API"
assert actions[0]["name"] == "healthCheck"
def test_non_dict_path_item_skipped(self):
"""Cover line 117: non-dict path item is skipped."""
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/valid": {
"get": {
"operationId": "validOp",
"responses": {"200": {"description": "OK"}},
}
},
"/invalid": "not_a_dict",
},
}
)
_, actions = parse_spec(spec)
assert len(actions) == 1
assert actions[0]["name"] == "validOp"
def test_non_dict_operation_skipped(self):
"""Cover line 122: non-dict operation for a method is skipped."""
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/items": {
"get": "not_a_dict",
"post": {
"operationId": "createItem",
"responses": {},
},
}
},
}
)
_, actions = parse_spec(spec)
assert len(actions) == 1
assert actions[0]["name"] == "createItem"
def test_operation_parse_failure_logged(self):
"""Cover lines 137: exception parsing operation is caught."""
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/items": {
"get": {
"operationId": "getItems",
"responses": {},
},
"post": {
"operationId": "createItem",
"requestBody": {
"$ref": "#/components/schemas/Missing"
},
"responses": {},
},
}
},
}
)
_, actions = parse_spec(spec)
# At least the GET should succeed
assert any(a["name"] == "getItems" for a in actions)
def test_path_level_params_merged(self):
"""Cover lines 129-130, 148, 159: path-level parameters merged."""
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/items/{id}": {
"parameters": [
{
"name": "id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
"get": {
"operationId": "getItem",
"responses": {},
},
}
},
}
)
_, actions = parse_spec(spec)
assert "id" in actions[0]["query_params"]["properties"]
def test_swagger_body_param_extraction(self):
"""Cover lines 145, 148, 152-153: Swagger 2.0 body parameter extraction."""
spec = json.dumps(
{
"swagger": "2.0",
"info": {"title": "T", "version": "1"},
"host": "api.test.com",
"basePath": "/v1",
"schemes": ["https"],
"paths": {
"/items": {
"post": {
"operationId": "createItem",
"consumes": ["application/json"],
"parameters": [
{
"name": "body",
"in": "body",
"schema": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Item name",
}
},
"required": ["name"],
},
}
],
"responses": {"201": {"description": "Created"}},
}
}
},
}
)
_, actions = parse_spec(spec)
assert len(actions) == 1
assert "name" in actions[0]["body"]["properties"]
assert actions[0]["body_content_type"] == "application/json"
def test_traverse_path_key_error(self):
"""Cover lines 173-176: _traverse_path returns None on KeyError."""
from application.agents.tools.spec_parser import _traverse_path
result = _traverse_path({"a": {"b": 1}}, ["a", "c"])
assert result is None
def test_traverse_path_non_dict_result(self):
"""Cover line 175-176: _traverse_path returns None for non-dict result."""
from application.agents.tools.spec_parser import _traverse_path
result = _traverse_path({"a": "string_value"}, ["a"])
assert result is None
def test_openapi_request_body_form_urlencoded(self):
"""Cover lines 152-153: OpenAPI 3.x request body with form-urlencoded."""
spec = json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "T", "version": "1"},
"paths": {
"/login": {
"post": {
"operationId": "login",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"type": "object",
"properties": {
"username": {"type": "string"},
"password": {"type": "string"},
},
}
}
}
},
"responses": {},
}
}
},
}
)
_, actions = parse_spec(spec)
assert actions[0]["body_content_type"] == "application/x-www-form-urlencoded"
assert "username" in actions[0]["body"]["properties"]
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 205, 209, 213, 216-217, 222, 228
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSpecParserAdditionalCoverage:
def test_categorize_params_query_and_header(self):
"""Cover lines 205, 209, 213: parameters categorized into query and header."""
from application.agents.tools.spec_parser import _categorize_parameters
parameters = [
{"name": "q", "in": "query", "required": True, "description": "Query param"},
{"name": "X-Auth", "in": "header", "required": False, "description": "Auth header"},
{"name": "id", "in": "path", "required": True, "description": "Path param"},
]
query_params, headers = _categorize_parameters(parameters, {}, {})
assert "q" in query_params
assert "X-Auth" in headers
assert "id" in query_params # path params go to query_params
def test_categorize_params_skips_no_name(self):
"""Cover line 205: parameters without name are skipped."""
from application.agents.tools.spec_parser import _categorize_parameters
parameters = [
{"in": "query"}, # no name
]
query_params, headers = _categorize_parameters(parameters, {}, {})
assert len(query_params) == 0
assert len(headers) == 0
def test_param_to_property_integer_type(self):
"""Cover lines 216-217, 222, 228: _param_to_property with integer type."""
from application.agents.tools.spec_parser import _param_to_property
param = {
"name": "count",
"schema": {"type": "integer"},
"description": "Count of items",
"required": True,
}
prop = _param_to_property(param)
assert prop["type"] == "integer"
assert prop["required"] is True
assert prop["filled_by_llm"] is True
def test_param_to_property_number_type(self):
"""Cover line 222: number type mapped to integer."""
from application.agents.tools.spec_parser import _param_to_property
param = {
"schema": {"type": "number"},
"description": "A number",
"required": False,
}
prop = _param_to_property(param)
assert prop["type"] == "integer"
def test_param_to_property_string_default(self):
"""Cover line 222: unknown type defaults to string."""
from application.agents.tools.spec_parser import _param_to_property
param = {"description": "Desc", "required": False}
prop = _param_to_property(param)
assert prop["type"] == "string"
def test_param_to_property_description_truncated(self):
"""Cover line 228: description truncated to 200 chars."""
from application.agents.tools.spec_parser import _param_to_property
param = {"description": "x" * 300, "required": False}
prop = _param_to_property(param)
assert len(prop["description"]) == 200
# ---------------------------------------------------------------------------
# Additional coverage for spec_parser.py
# Lines: 57-59 (YAML error), 116-117 (non-dict path_item), 136-140
# (action parse exception), 156 (full_url with no base_url),
# 184-190 (generate_action_name from path), 99 (swagger base URL)
# ---------------------------------------------------------------------------
from application.agents.tools.spec_parser import _extract_actions # noqa: E402
@pytest.mark.unit
class TestLoadSpecYAMLError:
"""Cover lines 58-59: YAML parse error."""
def test_invalid_yaml_raises(self):
with pytest.raises(ValueError, match="Invalid YAML"):
_load_spec(" \ttabs: [invalid: yaml: {{")
@pytest.mark.unit
class TestExtractActionsNonDictPathItem:
"""Cover lines 116-117: non-dict path_item is skipped."""
def test_non_dict_path_skipped(self):
spec = {
"openapi": "3.0.0",
"paths": {
"/valid": {"get": {"operationId": "getValid"}},
"/invalid": "not-a-dict",
},
}
actions = _extract_actions(spec, False)
assert len(actions) == 1
assert actions[0]["name"] == "getValid"
@pytest.mark.unit
class TestExtractActionsParseException:
"""Cover lines 136-140: exception in _build_action is caught."""
def test_bad_operation_skipped(self):
spec = {
"openapi": "3.0.0",
"paths": {
"/test": {
"get": {
"operationId": "good",
},
"post": {
"operationId": "bad",
"parameters": [{"$ref": "#/invalid/ref"}],
},
},
},
}
# Should not raise, bad operation is skipped
actions = _extract_actions(spec, False)
assert len(actions) >= 1
@pytest.mark.unit
class TestGenerateActionNameFromPath:
"""Cover lines 184-190: operationId missing, generate from path."""
def test_name_from_path(self):
name = _generate_action_name({}, "get", "/users/{id}/posts")
assert name.startswith("get_")
assert "users" in name
assert "{" not in name
def test_name_truncated(self):
long_path = "/a" * 100
name = _generate_action_name({}, "post", long_path)
assert len(name) <= 64
@pytest.mark.unit
class TestGetBaseUrlSwagger:
"""Cover line 99: swagger base URL with host and scheme."""
def test_swagger_base_url(self):
spec = {
"swagger": "2.0",
"host": "api.example.com",
"basePath": "/v2",
"schemes": ["https"],
}
url = _get_base_url(spec, True)
assert url == "https://api.example.com/v2"
def test_swagger_no_host(self):
spec = {"swagger": "2.0"}
url = _get_base_url(spec, True)
assert url == ""

View File

@@ -0,0 +1,79 @@
"""Tests for application/agents/tools/telegram.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.telegram import TelegramTool
@pytest.fixture
def tool():
return TelegramTool(config={"token": "bot123:ABC"})
@pytest.mark.unit
class TestTelegramExecuteAction:
def test_unknown_action_raises(self, tool):
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid")
@patch("application.agents.tools.telegram.requests.post")
def test_send_message(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
result = tool.execute_action(
"telegram_send_message", text="Hello", chat_id="12345"
)
assert result["status_code"] == 200
assert result["message"] == "Message sent"
call_args = mock_post.call_args
assert "bot123:ABC/sendMessage" in call_args[0][0]
assert call_args[1]["data"]["text"] == "Hello"
assert call_args[1]["data"]["chat_id"] == "12345"
@patch("application.agents.tools.telegram.requests.post")
def test_send_image(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_post.return_value = mock_resp
result = tool.execute_action(
"telegram_send_image", image_url="https://img.com/cat.jpg", chat_id="12345"
)
assert result["status_code"] == 200
assert result["message"] == "Image sent"
call_args = mock_post.call_args
assert "bot123:ABC/sendPhoto" in call_args[0][0]
assert call_args[1]["data"]["photo"] == "https://img.com/cat.jpg"
@patch("application.agents.tools.telegram.requests.post")
def test_api_error_status(self, mock_post, tool):
mock_resp = MagicMock()
mock_resp.status_code = 403
mock_post.return_value = mock_resp
result = tool.execute_action(
"telegram_send_message", text="Hi", chat_id="999"
)
assert result["status_code"] == 403
@pytest.mark.unit
class TestTelegramMetadata:
def test_actions_metadata(self, tool):
meta = tool.get_actions_metadata()
assert len(meta) == 2
names = {a["name"] for a in meta}
assert "telegram_send_message" in names
assert "telegram_send_image" in names
def test_config_requirements(self, tool):
reqs = tool.get_config_requirements()
assert "token" in reqs
assert reqs["token"]["secret"] is True

View File

@@ -277,3 +277,279 @@ class TestToolExecutorExecute:
# load_tool called only once due to cache
assert mock_tool_manager.load_tool.call_count == 1
def test_execute_api_tool(self, mock_tool_manager, monkeypatch):
"""Cover lines 199-202, 256-267: api_tool execution path."""
executor = ToolExecutor(user="test_user")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
parse_args=Mock(return_value=("t1", "get_users", {"body_param": "val"}))
),
)
tools_dict = {
"t1": {
"name": "api_tool",
"config": {
"actions": {
"get_users": {
"name": "get_users",
"description": "Get users",
"url": "https://api.example.com/users",
"method": "GET",
"query_params": {"properties": {}},
"headers": {"properties": {}},
"body": {"properties": {}},
"active": True,
}
}
},
}
}
call = self._make_call(name="get_users_t1", call_id="c2")
gen = executor.execute(tools_dict, call, "MockLLM")
events = []
result = None
while True:
try:
events.append(next(gen))
except StopIteration as e:
result = e.value
break
assert result is not None
statuses = [e["data"]["status"] for e in events]
assert "pending" in statuses
def test_execute_with_prefilled_param_values(self, mock_tool_manager, monkeypatch):
"""Cover line 179: params not in call_args use default value."""
executor = ToolExecutor(user="test_user")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
parse_args=Mock(return_value=("t1", "act", {}))
),
)
tools_dict = {
"t1": {
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{
"name": "act",
"description": "Test",
"parameters": {
"properties": {
"hidden_param": {
"type": "string",
"value": "default_val",
"filled_by_llm": False,
}
}
},
}
],
}
}
call = self._make_call(name="act_t1")
gen = executor.execute(tools_dict, call, "MockLLM")
while True:
try:
next(gen)
except StopIteration as e:
result = e.value
break
assert result[0] == "Tool result"
def test_execute_tool_with_artifact_id(self, mock_tool_manager, monkeypatch):
"""Cover lines 217-218: tool with get_artifact_id."""
executor = ToolExecutor(user="test_user")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
parse_args=Mock(return_value=("t1", "act", {"q": "v"}))
),
)
mock_tool = mock_tool_manager.load_tool.return_value
mock_tool.get_artifact_id = Mock(return_value="artifact-123")
tools_dict = {
"t1": {
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{
"name": "act",
"description": "Test",
"parameters": {"properties": {}},
}
],
}
}
call = self._make_call(name="act_t1")
gen = executor.execute(tools_dict, call, "MockLLM")
events = []
while True:
try:
events.append(next(gen))
except StopIteration:
break
completed_events = [
e for e in events if e["data"].get("status") == "completed"
]
assert any(
"artifact_id" in e.get("data", {}) for e in completed_events
)
def test_get_or_load_tool_encrypted_credentials(self, monkeypatch):
"""Cover lines 273-278: encrypted credentials path."""
executor = ToolExecutor(user="test_user")
mock_tm = Mock()
mock_tool = Mock()
mock_tm.load_tool.return_value = mock_tool
monkeypatch.setattr(
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
)
monkeypatch.setattr(
"application.agents.tool_executor.decrypt_credentials",
lambda creds, user: {"api_key": "decrypted_key"},
)
tool_data = {
"name": "custom_tool",
"config": {"encrypted_credentials": "encrypted_blob"},
}
result = executor._get_or_load_tool(tool_data, "t1", "act")
assert result is mock_tool
call_kwargs = mock_tm.load_tool.call_args
tool_config = call_kwargs[1]["tool_config"] if "tool_config" in call_kwargs[1] else call_kwargs[0][1]
assert "api_key" in tool_config.get("auth_credentials", tool_config)
def test_get_or_load_tool_mcp_tool(self, monkeypatch):
"""Cover lines 281-283: mcp_tool path sets query_mode."""
executor = ToolExecutor(user="test_user")
executor.conversation_id = "conv-123"
mock_tm = Mock()
mock_tool = Mock()
mock_tm.load_tool.return_value = mock_tool
monkeypatch.setattr(
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
)
tool_data = {
"name": "mcp_tool",
"config": {},
}
result = executor._get_or_load_tool(tool_data, "t1", "act")
assert result is mock_tool
call_kwargs = mock_tm.load_tool.call_args
tool_config = call_kwargs[1].get("tool_config", call_kwargs[0][1] if len(call_kwargs[0]) > 1 else {})
assert tool_config.get("query_mode") is True
assert tool_config.get("conversation_id") == "conv-123"
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 217-218, 256-267
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestToolExecutorAdditionalCoverage:
def test_get_artifact_id_exception_handled(self, monkeypatch):
"""Cover lines 217-218: get_artifact_id raises exception."""
from types import SimpleNamespace
executor = ToolExecutor(user="user1")
mock_tool = Mock()
mock_tool.execute_action.return_value = "result"
mock_tool.get_artifact_id.side_effect = RuntimeError("artifact error")
monkeypatch.setattr(
"application.agents.tool_executor.ToolManager",
lambda config: Mock(load_tool=Mock(return_value=mock_tool)),
)
tools_dict = {
"t1": {
"name": "custom_tool",
"config": {"key": "val"},
"actions": [
{
"name": "action1",
"active": True,
"parameters": {"properties": {}},
}
],
}
}
# Create a fake call object matching what ToolActionParser expects
call = SimpleNamespace(
id="c1",
function=SimpleNamespace(
name="action1_t1",
arguments="{}",
),
)
events = list(executor.execute(tools_dict, call, "OpenAILLM"))
# Should complete without raising; artifact_id error is logged but not raised
assert any(
isinstance(e, dict) and e.get("type") == "tool_call"
for e in events
)
def test_get_or_load_api_tool_with_body_content_type(self, monkeypatch):
"""Cover lines 256-267: api_tool with body_content_type."""
executor = ToolExecutor(user="user1")
mock_tm = Mock()
mock_tool = Mock()
mock_tm.load_tool.return_value = mock_tool
monkeypatch.setattr(
"application.agents.tool_executor.ToolManager", lambda config: mock_tm
)
tool_data = {
"name": "api_tool",
"config": {
"actions": {
"create": {
"url": "https://api.example.com/items",
"method": "POST",
"body_content_type": "application/json",
"body_encoding_rules": {"encode_as": "json"},
}
}
},
}
result = executor._get_or_load_tool(
tool_data, "t1", "create",
headers={"Authorization": "Bearer tok"},
query_params={"page": "1"},
)
assert result is mock_tool
# Verify config was built with body_content_type
call_args = mock_tm.load_tool.call_args
tool_config = call_args[1].get("tool_config", call_args[0][1] if len(call_args[0]) > 1 else {})
assert tool_config.get("body_content_type") == "application/json"
assert tool_config.get("body_encoding_rules") == {"encode_as": "json"}

View File

@@ -0,0 +1,433 @@
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
WorkflowGraph,
)
def _make_agent(**overrides):
"""Create a WorkflowAgent with mocked base class dependencies."""
defaults = {
"endpoint": "https://api.example.com",
"llm_name": "openai",
"model_id": "gpt-4",
"api_key": "test_key",
"user_api_key": None,
"prompt": "You are helpful.",
"chat_history": [],
"decoded_token": {"sub": "user1"},
"attachments": [],
"json_schema": None,
}
defaults.update(overrides)
with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
from application.agents.workflow_agent import WorkflowAgent
agent = WorkflowAgent(**defaults)
return agent
class TestWorkflowAgentInit:
@pytest.mark.unit
def test_sets_attributes(self):
agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
assert agent.workflow_id == "wf1"
assert agent.workflow_owner == "owner1"
assert agent._engine is None
@pytest.mark.unit
def test_embedded_workflow(self):
wf_data = {"nodes": [], "edges": [], "name": "Test"}
agent = _make_agent(workflow=wf_data)
assert agent._workflow_data == wf_data
class TestParseEmbeddedWorkflow:
@pytest.mark.unit
def test_parses_valid_workflow(self):
wf_data = {
"name": "Test Workflow",
"description": "A test",
"nodes": [
{"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
{"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
],
"edges": [
{"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
],
}
agent = _make_agent(workflow=wf_data, workflow_id="wf1")
graph = agent._parse_embedded_workflow()
assert graph is not None
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
assert graph.workflow.name == "Test Workflow"
@pytest.mark.unit
def test_edge_source_id_alias(self):
wf_data = {
"nodes": [{"id": "n1", "type": "start", "data": {}}],
"edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
}
agent = _make_agent(workflow=wf_data)
graph = agent._parse_embedded_workflow()
assert graph is not None
assert graph.edges[0].source_id == "n1"
@pytest.mark.unit
def test_invalid_data_returns_none(self):
agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
graph = agent._parse_embedded_workflow()
assert graph is None
class TestLoadWorkflowGraph:
@pytest.mark.unit
def test_uses_embedded_when_available(self):
agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
result = agent._load_workflow_graph()
assert result == "parsed_graph"
@pytest.mark.unit
def test_uses_database_when_workflow_id(self):
agent = _make_agent(workflow_id="wf1")
agent._load_from_database = MagicMock(return_value="db_graph")
result = agent._load_workflow_graph()
assert result == "db_graph"
@pytest.mark.unit
def test_returns_none_when_nothing(self):
agent = _make_agent()
result = agent._load_workflow_graph()
assert result is None
class TestLoadFromDatabase:
@pytest.mark.unit
def test_invalid_workflow_id_returns_none(self):
agent = _make_agent(workflow_id="invalid!")
result = agent._load_from_database()
assert result is None
@pytest.mark.unit
def test_no_owner_returns_none(self):
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
agent.workflow_owner = None
result = agent._load_from_database()
assert result is None
@pytest.mark.unit
def test_uses_decoded_token_sub(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
decoded_token={"sub": "user1"},
)
agent.workflow_owner = None
mock_collection = MagicMock()
mock_collection.find_one.return_value = None
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is None # workflow_doc not found
@pytest.mark.unit
def test_successful_load(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "Test WF",
"user": "owner1",
"current_graph_version": 1,
}
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.return_value = [
{"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
"title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
]
mock_edges_coll = MagicMock()
mock_edges_coll.find.return_value = []
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None
assert len(result.nodes) == 1
@pytest.mark.unit
def test_invalid_graph_version(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "WF",
"user": "owner1",
"current_graph_version": "bad",
}
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.return_value = []
mock_edges_coll = MagicMock()
mock_edges_coll.find.return_value = []
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None # Defaults to version 1
@pytest.mark.unit
def test_fallback_nodes_without_version(self):
"""When graph_version=1 finds no nodes, falls back to nodes without version field."""
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "WF",
"user": "owner1",
"current_graph_version": 1,
}
call_count = [0]
def nodes_find(query):
call_count[0] += 1
if call_count[0] == 1:
return [] # No versioned nodes
return [{"id": "n1", "workflow_id": "wf", "type": "start",
"title": "S", "position": {"x": 0, "y": 0}, "config": {}}]
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.side_effect = nodes_find
edge_call = [0]
def edges_find(query):
edge_call[0] += 1
if edge_call[0] == 1:
return []
return []
mock_edges_coll = MagicMock()
mock_edges_coll.find.side_effect = edges_find
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None
assert len(result.nodes) == 1
@pytest.mark.unit
def test_exception_returns_none(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
MockMongo.get_client.side_effect = Exception("db error")
result = agent._load_from_database()
assert result is None
class TestGenInner:
@pytest.mark.unit
def test_no_graph_yields_error(self):
agent = _make_agent()
agent._load_workflow_graph = MagicMock(return_value=None)
events = list(agent._gen_inner("query", None))
assert any(e.get("type") == "error" for e in events)
@pytest.mark.unit
def test_successful_execution(self):
agent = _make_agent(workflow_id="wf1")
mock_graph = MagicMock(spec=WorkflowGraph)
agent._load_workflow_graph = MagicMock(return_value=mock_graph)
agent._save_workflow_run = MagicMock()
mock_engine = MagicMock()
mock_engine.execute.return_value = iter([{"answer": "result"}])
with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine):
events = list(agent._gen_inner("query", None))
assert len(events) == 1
agent._save_workflow_run.assert_called_once_with("query")
class TestSaveWorkflowRun:
@pytest.mark.unit
def test_no_engine_returns_early(self):
agent = _make_agent()
agent._engine = None
agent._save_workflow_run("query") # Should not raise
@pytest.mark.unit
def test_saves_to_mongo(self):
agent = _make_agent(workflow_id="wf1")
mock_engine = MagicMock()
mock_engine.state = {"query": "test"}
mock_engine.execution_log = []
mock_engine.get_execution_summary.return_value = []
agent._engine = mock_engine
mock_collection = MagicMock()
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
agent._save_workflow_run("query")
mock_collection.insert_one.assert_called_once()
@pytest.mark.unit
def test_exception_does_not_propagate(self):
agent = _make_agent(workflow_id="wf1")
mock_engine = MagicMock()
mock_engine.state = {}
mock_engine.execution_log = []
mock_engine.get_execution_summary.return_value = []
agent._engine = mock_engine
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
MockMongo.get_client.side_effect = Exception("db fail")
agent._save_workflow_run("query") # Should not raise
class TestDetermineRunStatus:
@pytest.mark.unit
def test_no_engine_returns_completed(self):
agent = _make_agent()
agent._engine = None
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_empty_log_returns_completed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = []
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_failed_log_returns_failed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = [
{"status": "completed"},
{"status": "failed"},
]
assert agent._determine_run_status() == ExecutionStatus.FAILED
@pytest.mark.unit
def test_all_completed_returns_completed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = [
{"status": "completed"},
{"status": "completed"},
]
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
class TestSerializeState:
@pytest.mark.unit
def test_serializes_primitives(self):
agent = _make_agent()
state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
result = agent._serialize_state(state)
assert result == state
@pytest.mark.unit
def test_serializes_nested_dict(self):
agent = _make_agent()
state = {"nested": {"key": "value"}}
result = agent._serialize_state(state)
assert result["nested"]["key"] == "value"
@pytest.mark.unit
def test_serializes_list(self):
agent = _make_agent()
state = {"items": [1, 2, "three"]}
result = agent._serialize_state(state)
assert result["items"] == [1, 2, "three"]
@pytest.mark.unit
def test_serializes_tuple(self):
agent = _make_agent()
state = {"tup": (1, 2)}
result = agent._serialize_state(state)
assert result["tup"] == [1, 2]
@pytest.mark.unit
def test_serializes_datetime(self):
agent = _make_agent()
dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
state = {"time": dt}
result = agent._serialize_state(state)
assert "2025-01-01" in result["time"]
@pytest.mark.unit
def test_serializes_unknown_to_str(self):
agent = _make_agent()
state = {"obj": object()}
result = agent._serialize_state(state)
assert isinstance(result["obj"], str)

View File

@@ -330,3 +330,180 @@ def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeyp
match="Structured output was expected but response was not valid JSON",
):
list(engine._execute_agent_node(node))
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 204, 213-215, 223, 283-284, 289,
# 355, 375
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestWorkflowEngineAdditionalCoverage:
def test_agent_node_prompt_template_empty_uses_query(self, monkeypatch):
"""Cover line 204: prompt_template is empty, uses state query."""
engine = create_engine()
engine.state["query"] = "What is the answer?"
node = create_agent_node(node_id="n1")
node.config["prompt_template"] = ""
node_events = [{"answer": "42"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_n1_output"] == "42"
def test_agent_node_model_config_override(self, monkeypatch):
"""Cover lines 213-215: node_config with model_id and llm_name."""
engine = create_engine()
engine.state["query"] = "test"
node = create_agent_node(node_id="n2")
node.config["model_id"] = "gpt-4o"
node.config["llm_name"] = "openai"
node_events = [{"answer": "result"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _: "key",
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_n2_output"] == "result"
def test_agent_node_unsupported_structured_output_raises(self, monkeypatch):
"""Cover line 223: model does not support structured output raises."""
engine = create_engine()
engine.state["query"] = "test"
node = create_agent_node(
node_id="n3",
json_schema={"type": "object", "properties": {"a": {"type": "string"}}},
)
node.config["model_id"] = "model-no-struct"
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _: "key",
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": False},
)
with pytest.raises(ValueError, match="does not support structured output"):
list(engine._execute_agent_node(node))
def test_structured_output_with_structured_response(self, monkeypatch):
"""Cover lines 283-284: structured response parsed and validated."""
engine = create_engine()
engine.state["query"] = "test"
node = create_agent_node(
node_id="n4",
output_variable="result",
json_schema={"type": "object", "properties": {"key": {"type": "string"}}},
)
node_events = [
{"answer": '{"key": "val"}', "structured": True},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
assert engine.state["result"] == {"key": "val"}
def test_json_schema_no_structured_flag_parses_response(self, monkeypatch):
"""Cover line 289: json_schema set but no structured flag; non-JSON response raises."""
engine = create_engine()
engine.state["query"] = "test"
node = create_agent_node(
node_id="n5",
json_schema={"type": "object", "properties": {"x": {"type": "string"}}},
)
node_events = [{"answer": "not valid json"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
)
with pytest.raises(
ValueError,
match="Structured output was expected but response was not valid JSON",
):
list(engine._execute_agent_node(node))
def test_parse_structured_output_empty_string(self):
"""Cover line 355: _parse_structured_output with empty string."""
engine = create_engine()
success, result = engine._parse_structured_output("")
assert success is False
assert result is None
def test_normalize_node_json_schema_invalid(self):
"""Cover line 375: _normalize_node_json_schema with invalid schema raises."""
engine = create_engine()
# A non-dict schema triggers JsonSchemaValidationError
with pytest.raises(ValueError, match="Invalid JSON schema"):
engine._normalize_node_json_schema("not_a_dict", "TestNode")

View File

@@ -0,0 +1,833 @@
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
template context, source data, structured output parsing, get_execution_summary."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
NodeType,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
Workflow,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
def _make_graph(nodes, edges):
wf = Workflow(name="Test", description="test workflow")
return WorkflowGraph(workflow=wf, nodes=nodes, edges=edges)
def _make_node(id, type, title="Node", config=None, position=None):
return WorkflowNode(
id=id,
workflow_id="wf1",
type=type,
title=title,
position=position or {"x": 0, "y": 0},
config=config or {},
)
def _make_edge(id, source, target, source_handle=None, target_handle=None):
return WorkflowEdge(
id=id,
workflow_id="wf1",
source=source,
target=target,
sourceHandle=source_handle,
targetHandle=target_handle,
)
def _make_agent():
agent = MagicMock()
agent.chat_history = []
agent.endpoint = "https://api.example.com"
agent.llm_name = "openai"
agent.model_id = "gpt-4"
agent.api_key = "key"
agent.decoded_token = {"sub": "user1"}
agent.retrieved_docs = None
return agent
class TestExecuteLoop:
@pytest.mark.unit
def test_no_start_node_yields_error(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "query"))
assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events)
@pytest.mark.unit
def test_start_to_end(self):
nodes = [
_make_node("n1", NodeType.START, "Start"),
_make_node("n2", NodeType.END, "End", config={"config": {}}),
]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "hello"))
step_events = [e for e in events if e.get("type") == "workflow_step"]
assert len(step_events) >= 2 # At least start + end
@pytest.mark.unit
def test_node_not_found_yields_error(self):
nodes = [_make_node("n1", NodeType.START)]
edges = [_make_edge("e1", "n1", "nonexistent")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
assert any("not found" in e.get("error", "") for e in events)
@pytest.mark.unit
def test_node_execution_error_yields_error(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
failed_events = [e for e in events if e.get("status") == "failed"]
assert len(failed_events) >= 1
@pytest.mark.unit
def test_max_steps_limit(self):
# Create a cycle: start -> state -> state (loop)
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.NOTE, "Note"),
]
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine.MAX_EXECUTION_STEPS = 5
events = list(engine.execute({}, "q"))
# Should not run forever
assert len(events) > 0
@pytest.mark.unit
def test_branch_ends_without_end_node(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.NOTE, "Note"),
]
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing edges
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
assert len(events) > 0
class TestInitializeState:
@pytest.mark.unit
def test_sets_query_and_history(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.chat_history = [{"prompt": "hi", "response": "hey"}]
engine = WorkflowEngine(graph, agent)
engine._initialize_state({"custom": "value"}, "test query")
assert engine.state["query"] == "test query"
assert "custom" in engine.state
assert engine.state["chat_history"] is not None
class TestGetNextNodeId:
@pytest.mark.unit
def test_no_edges_returns_none(self):
nodes = [_make_node("n1", NodeType.START)]
graph = _make_graph(nodes, [])
engine = WorkflowEngine(graph, _make_agent())
assert engine._get_next_node_id("n1") is None
@pytest.mark.unit
def test_returns_first_edge_target(self):
nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
assert engine._get_next_node_id("n1") == "n2"
@pytest.mark.unit
def test_condition_uses_matched_handle(self):
nodes = [
_make_node("n1", NodeType.CONDITION),
_make_node("n2", NodeType.END, "Yes End"),
_make_node("n3", NodeType.END, "No End"),
]
edges = [
_make_edge("e1", "n1", "n2", source_handle="yes"),
_make_edge("e2", "n1", "n3", source_handle="no"),
]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine._condition_result = "no"
assert engine._get_next_node_id("n1") == "n3"
assert engine._condition_result is None # Cleared after use
@pytest.mark.unit
def test_condition_no_matching_handle_returns_none(self):
nodes = [_make_node("n1", NodeType.CONDITION)]
edges = [_make_edge("e1", "n1", "n2", source_handle="yes")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine._condition_result = "nonexistent"
assert engine._get_next_node_id("n1") is None
class TestExecuteStateNode:
@pytest.mark.unit
def test_evaluates_operations(self):
node = _make_node("n1", NodeType.STATE, config={
"config": {
"operations": [
{"expression": "x + 1", "target_variable": "result"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 5}
list(engine._execute_state_node(node))
assert engine.state["result"] == 6
@pytest.mark.unit
def test_skips_empty_expression(self):
node = _make_node("n1", NodeType.STATE, config={
"config": {
"operations": [
{"expression": "", "target_variable": "result"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_state_node(node))
assert "result" not in engine.state
class TestExecuteConditionNode:
@pytest.mark.unit
def test_matches_first_true_case(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "x > 10", "source_handle": "high"},
{"expression": "x > 5", "source_handle": "medium"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 7}
list(engine._execute_condition_node(node))
assert engine._condition_result == "medium"
@pytest.mark.unit
def test_falls_through_to_else(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "x > 100", "source_handle": "high"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 1}
list(engine._execute_condition_node(node))
assert engine._condition_result == "else"
@pytest.mark.unit
def test_skips_empty_expression(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": " ", "source_handle": "a"},
{"expression": "true", "source_handle": "b"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_condition_node(node))
assert engine._condition_result == "b"
@pytest.mark.unit
def test_cel_error_continues(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "bad!!!", "source_handle": "a"},
{"expression": "true", "source_handle": "b"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_condition_node(node))
assert engine._condition_result == "b"
class TestExecuteEndNode:
@pytest.mark.unit
def test_with_output_template(self):
node = _make_node("n1", NodeType.END, config={
"config": {"output_template": "Result: {{ query }}"}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "hello"}
engine._format_template = MagicMock(return_value="Result: hello")
events = list(engine._execute_end_node(node))
assert len(events) == 1
assert events[0]["answer"] == "Result: hello"
@pytest.mark.unit
def test_without_output_template(self):
node = _make_node("n1", NodeType.END, config={"config": {}})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
events = list(engine._execute_end_node(node))
assert len(events) == 0
class TestParseStructuredOutput:
@pytest.mark.unit
def test_valid_json(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output('{"key": "value"}')
assert success is True
assert data == {"key": "value"}
@pytest.mark.unit
def test_invalid_json(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output("not json")
assert success is False
assert data is None
@pytest.mark.unit
def test_empty_string(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output("")
assert success is False
assert data is None
@pytest.mark.unit
def test_whitespace_only(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output(" ")
assert success is False
assert data is None
class TestNormalizeNodeJsonSchema:
@pytest.mark.unit
def test_none_returns_none(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
assert engine._normalize_node_json_schema(None, "Node") is None
@pytest.mark.unit
def test_valid_schema(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
result = engine._normalize_node_json_schema(schema, "Node")
assert result is not None
@pytest.mark.unit
def test_invalid_schema_raises(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm:
from application.core.json_schema_utils import JsonSchemaValidationError
mock_norm.side_effect = JsonSchemaValidationError("bad schema")
with pytest.raises(ValueError, match="Invalid JSON schema"):
engine._normalize_node_json_schema({"bad": True}, "TestNode")
class TestValidateStructuredOutput:
@pytest.mark.unit
def test_valid_output_passes(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
engine._validate_structured_output(schema, {"name": "Alice"}) # Should not raise
@pytest.mark.unit
def test_invalid_output_raises(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
with pytest.raises(ValueError, match="did not match schema"):
engine._validate_structured_output(schema, {})
@pytest.mark.unit
def test_no_jsonschema_module(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
with patch("application.agents.workflows.workflow_engine.jsonschema", None):
engine._validate_structured_output({"type": "object"}, {}) # Should not raise
class TestFormatTemplate:
@pytest.mark.unit
def test_renders_template(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "hello"}
engine._build_template_context = MagicMock(return_value={"query": "hello"})
engine._template_engine = MagicMock()
engine._template_engine.render.return_value = "hello world"
result = engine._format_template("{{ query }} world")
assert result == "hello world"
@pytest.mark.unit
def test_render_error_returns_raw(self):
from application.templates.template_engine import TemplateRenderError
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
engine._build_template_context = MagicMock(return_value={})
engine._template_engine = MagicMock()
engine._template_engine.render.side_effect = TemplateRenderError("fail")
result = engine._format_template("{{ bad }}")
assert result == "{{ bad }}"
class TestBuildTemplateContext:
@pytest.mark.unit
def test_includes_state_variables(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"query": "hello", "custom_var": "value"}
context = engine._build_template_context()
assert context["agent"]["query"] == "hello"
assert "custom_var" in context
@pytest.mark.unit
def test_reserved_namespace_gets_prefixed(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"source": "my_source_val"}
context = engine._build_template_context()
assert context.get("agent_source") == "my_source_val"
@pytest.mark.unit
def test_passthrough_data(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"passthrough": {"key": "val"}}
context = engine._build_template_context()
assert "passthrough" in context or "agent_passthrough" in context
@pytest.mark.unit
def test_tools_data(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"tools": {"tool1": "result"}}
context = engine._build_template_context()
assert "agent" in context
class TestGetSourceTemplateData:
@pytest.mark.unit
def test_no_docs_returns_none(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is None
assert together is None
@pytest.mark.unit
def test_empty_docs_returns_none(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = []
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is None
@pytest.mark.unit
def test_docs_with_filename(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is not None
assert "doc.txt" in together
assert "content" in together
@pytest.mark.unit
def test_docs_without_filename(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content only"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together == "content only"
@pytest.mark.unit
def test_skips_non_dict_docs(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = ["not a dict", {"text": "ok"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together == "ok"
@pytest.mark.unit
def test_skips_non_string_text(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": 123}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together is None
@pytest.mark.unit
def test_doc_with_title_fallback(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "title": "doc_title"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert "doc_title" in together
@pytest.mark.unit
def test_doc_with_source_fallback(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "source": "src"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert "src" in together
class TestGetExecutionSummary:
@pytest.mark.unit
def test_returns_log_entries(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
now = datetime.now(timezone.utc)
engine.execution_log = [
{
"node_id": "n1",
"node_type": "start",
"status": "completed",
"started_at": now,
"completed_at": now,
"error": None,
"state_snapshot": {},
}
]
summary = engine.get_execution_summary()
assert len(summary) == 1
assert summary[0].node_id == "n1"
assert summary[0].status == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_empty_log(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
assert engine.get_execution_summary() == []
class TestAgentNodeExecution:
"""Cover lines 204, 213-215, 223, 232-233, 283-284, 289, 355, 375."""
@pytest.mark.unit
def test_agent_node_without_prompt_template(self):
"""Cover line 204/206: agent node without prompt_template uses query."""
node = _make_node("n1", NodeType.AGENT, "Agent", config={
"config": {
"agent_type": "classic",
"stream_to_user": False,
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "test question"}
mock_agent = MagicMock()
mock_agent.gen.return_value = [{"answer": "response"}]
with patch(
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
) as mock_factory, \
patch(
"application.core.model_utils.get_provider_from_model_id",
return_value="openai",
), \
patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), \
patch(
"application.core.model_utils.get_model_capabilities",
return_value=None,
):
mock_factory.create.return_value = mock_agent
list(engine._execute_agent_node(node))
output_key = f"node_{node.id}_output"
assert output_key in engine.state
@pytest.mark.unit
def test_agent_node_with_structured_output(self):
"""Cover lines 283-284, 289: structured output parsing."""
node = _make_node("n1", NodeType.AGENT, "Agent", config={
"config": {
"agent_type": "classic",
"stream_to_user": False,
"json_schema": {
"type": "object",
"properties": {"name": {"type": "string"}},
},
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "test"}
mock_agent = MagicMock()
mock_agent.gen.return_value = [
{"answer": '{"name": "Alice"}', "structured": True}
]
with patch(
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
) as mock_factory, \
patch(
"application.core.model_utils.get_provider_from_model_id",
return_value="openai",
), \
patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), \
patch(
"application.core.model_utils.get_model_capabilities",
return_value={"supports_structured_output": True},
):
mock_factory.create.return_value = mock_agent
list(engine._execute_agent_node(node))
output_key = f"node_{node.id}_output"
assert engine.state[output_key] == {"name": "Alice"}
@pytest.mark.unit
def test_agent_node_model_no_structured_support_raises(self):
"""Cover lines 223: model without structured output raises ValueError."""
node = _make_node("n1", NodeType.AGENT, "Agent", config={
"config": {
"agent_type": "classic",
"json_schema": {
"type": "object",
"properties": {"x": {"type": "string"}},
},
"model_id": "test-model",
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "test"}
with patch(
"application.core.model_utils.get_provider_from_model_id",
return_value="openai",
), \
patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), \
patch(
"application.core.model_utils.get_model_capabilities",
return_value={"supports_structured_output": False},
):
with pytest.raises(ValueError, match="does not support structured output"):
list(engine._execute_agent_node(node))
@pytest.mark.unit
def test_agent_node_output_variable(self):
"""Cover line 300: output_variable stores result."""
node = _make_node("n1", NodeType.AGENT, "Agent", config={
"config": {
"agent_type": "classic",
"stream_to_user": False,
"output_variable": "my_result",
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "test"}
mock_agent = MagicMock()
mock_agent.gen.return_value = [{"answer": "output text"}]
with patch(
"application.agents.workflows.workflow_engine.WorkflowNodeAgentFactory"
) as mock_factory, \
patch(
"application.core.model_utils.get_provider_from_model_id",
return_value="openai",
), \
patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), \
patch(
"application.core.model_utils.get_model_capabilities",
return_value=None,
):
mock_factory.create.return_value = mock_agent
list(engine._execute_agent_node(node))
assert engine.state["my_result"] == "output text"
@pytest.mark.unit
def test_validate_structured_output_schema_error(self):
"""Cover line 375/382-383: invalid schema raises ValueError."""
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
import jsonschema as js
with patch(
"application.agents.workflows.workflow_engine.normalize_json_schema_payload",
return_value={"type": "invalid_schema_type"},
), \
patch(
"application.agents.workflows.workflow_engine.jsonschema"
) as mock_js:
mock_js.validate.side_effect = js.exceptions.SchemaError("bad schema")
mock_js.exceptions = js.exceptions
with pytest.raises(ValueError, match="Invalid JSON schema"):
engine._validate_structured_output(
{"type": "object"}, {"name": "test"}
)
@pytest.mark.unit
def test_parse_structured_output_invalid_json(self):
"""Cover lines 349-352: invalid JSON returns False."""
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output("not json {")
assert success is False
assert data is None
# ---------------------------------------------------------------------------
# Additional coverage for workflow_engine.py
# Lines: 96-114 (exception in node execution), 122-130 (branch/max steps)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestWorkflowNodeExecutionException:
"""Cover lines 96-114: exception during _execute_node yields error events."""
def test_node_raises_exception_yields_error(self):
"""Force _execute_node to raise, covering lines 96-114."""
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.AGENT, "Agent"),
]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
# Patch _execute_node to raise on agent node
original_execute = engine._execute_node
def patched_execute(node):
if node.type == NodeType.AGENT:
raise RuntimeError("Agent exploded")
yield from original_execute(node)
engine._execute_node = patched_execute
events = list(engine.execute({}, "test query"))
error_events = [e for e in events if e.get("type") == "error"]
assert len(error_events) >= 1
failed_steps = [e for e in events if e.get("status") == "failed"]
assert len(failed_steps) >= 1
@pytest.mark.unit
class TestWorkflowMaxStepsReached:
"""Cover lines 127-130: max steps limit warning."""
def test_max_steps_exactly_reached(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.NOTE, "Note"),
]
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine.MAX_EXECUTION_STEPS = 3
events = list(engine.execute({}, "q"))
# The while loop runs 3 times then exits, steps >= MAX
assert len(events) >= 3
@pytest.mark.unit
class TestWorkflowBranchEndsNonEndNode:
"""Cover lines 122-125: branch ends at non-end node without outgoing edges."""
def test_branch_ends_at_state_node(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node(
"n2",
NodeType.STATE,
"State",
config={"config": {"operations": []}},
),
]
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
# Should complete without crash, branch ended warning logged
assert len(events) > 0

View File

@@ -0,0 +1,556 @@
from datetime import datetime, timezone
import pytest
from bson import ObjectId
from pydantic import ValidationError
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionCase,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
Position,
StateOperation,
Workflow,
WorkflowCreate,
WorkflowEdge,
WorkflowEdgeCreate,
WorkflowGraph,
WorkflowNode,
WorkflowNodeCreate,
WorkflowRun,
WorkflowRunCreate,
)
# ── Enum tests ───────────────────────────────────────────────────────────────
class TestNodeType:
@pytest.mark.unit
def test_values(self):
assert NodeType.START == "start"
assert NodeType.END == "end"
assert NodeType.AGENT == "agent"
assert NodeType.NOTE == "note"
assert NodeType.STATE == "state"
assert NodeType.CONDITION == "condition"
@pytest.mark.unit
def test_all_members(self):
assert set(NodeType) == {
NodeType.START,
NodeType.END,
NodeType.AGENT,
NodeType.NOTE,
NodeType.STATE,
NodeType.CONDITION,
}
class TestAgentType:
@pytest.mark.unit
def test_values(self):
assert AgentType.CLASSIC == "classic"
assert AgentType.REACT == "react"
assert AgentType.AGENTIC == "agentic"
assert AgentType.RESEARCH == "research"
class TestExecutionStatus:
@pytest.mark.unit
def test_values(self):
assert ExecutionStatus.PENDING == "pending"
assert ExecutionStatus.RUNNING == "running"
assert ExecutionStatus.COMPLETED == "completed"
assert ExecutionStatus.FAILED == "failed"
# ── Position ─────────────────────────────────────────────────────────────────
class TestPosition:
@pytest.mark.unit
def test_defaults(self):
p = Position()
assert p.x == 0.0
assert p.y == 0.0
@pytest.mark.unit
def test_custom_values(self):
p = Position(x=10.5, y=-3.2)
assert p.x == 10.5
assert p.y == -3.2
@pytest.mark.unit
def test_extra_fields_forbidden(self):
with pytest.raises(ValidationError):
Position(x=0, y=0, z=1)
# ── AgentNodeConfig ──────────────────────────────────────────────────────────
class TestAgentNodeConfig:
@pytest.mark.unit
def test_defaults(self):
c = AgentNodeConfig()
assert c.agent_type == AgentType.CLASSIC
assert c.llm_name is None
assert c.system_prompt == "You are a helpful assistant."
assert c.prompt_template == ""
assert c.output_variable is None
assert c.stream_to_user is True
assert c.tools == []
assert c.sources == []
assert c.chunks == "2"
assert c.retriever == ""
assert c.model_id is None
assert c.json_schema is None
@pytest.mark.unit
def test_custom_values(self):
c = AgentNodeConfig(
agent_type=AgentType.REACT,
llm_name="gpt-4",
tools=["search"],
sources=["src1"],
chunks="5",
model_id="m1",
json_schema={"type": "object"},
)
assert c.agent_type == AgentType.REACT
assert c.llm_name == "gpt-4"
assert c.tools == ["search"]
assert c.sources == ["src1"]
assert c.chunks == "5"
assert c.model_id == "m1"
assert c.json_schema == {"type": "object"}
@pytest.mark.unit
def test_extra_fields_allowed(self):
c = AgentNodeConfig(custom_field="value")
assert c.custom_field == "value"
# ── ConditionCase / ConditionNodeConfig ──────────────────────────────────────
class TestConditionCase:
@pytest.mark.unit
def test_alias(self):
c = ConditionCase(expression="x > 1", sourceHandle="handle-1")
assert c.source_handle == "handle-1"
@pytest.mark.unit
def test_defaults(self):
c = ConditionCase(sourceHandle="h")
assert c.name is None
assert c.expression == ""
@pytest.mark.unit
def test_extra_forbidden(self):
with pytest.raises(ValidationError):
ConditionCase(sourceHandle="h", extra="nope")
class TestConditionNodeConfig:
@pytest.mark.unit
def test_defaults(self):
c = ConditionNodeConfig()
assert c.mode == "simple"
assert c.cases == []
@pytest.mark.unit
def test_with_cases(self):
c = ConditionNodeConfig(
mode="advanced",
cases=[{"expression": "x > 1", "sourceHandle": "h1"}],
)
assert c.mode == "advanced"
assert len(c.cases) == 1
assert c.cases[0].source_handle == "h1"
# ── StateOperation ───────────────────────────────────────────────────────────
class TestStateOperation:
@pytest.mark.unit
def test_defaults(self):
s = StateOperation()
assert s.expression == ""
assert s.target_variable == ""
@pytest.mark.unit
def test_extra_forbidden(self):
with pytest.raises(ValidationError):
StateOperation(expression="a", target_variable="b", extra="no")
# ── WorkflowEdgeCreate / WorkflowEdge ───────────────────────────────────────
class TestWorkflowEdgeCreate:
@pytest.mark.unit
def test_aliases(self):
e = WorkflowEdgeCreate(
id="e1",
workflow_id="w1",
source="n1",
target="n2",
sourceHandle="sh",
targetHandle="th",
)
assert e.source_id == "n1"
assert e.target_id == "n2"
assert e.source_handle == "sh"
assert e.target_handle == "th"
@pytest.mark.unit
def test_optional_handles_default_none(self):
e = WorkflowEdgeCreate(id="e1", workflow_id="w1", source="n1", target="n2")
assert e.source_handle is None
assert e.target_handle is None
class TestWorkflowEdge:
@pytest.mark.unit
def test_objectid_conversion(self):
oid = ObjectId()
e = WorkflowEdge(
_id=oid,
id="e1",
workflow_id="w1",
source="n1",
target="n2",
)
assert e.mongo_id == str(oid)
@pytest.mark.unit
def test_string_id_passthrough(self):
e = WorkflowEdge(
_id="string-id",
id="e1",
workflow_id="w1",
source="n1",
target="n2",
)
assert e.mongo_id == "string-id"
@pytest.mark.unit
def test_none_id(self):
e = WorkflowEdge(id="e1", workflow_id="w1", source="n1", target="n2")
assert e.mongo_id is None
@pytest.mark.unit
def test_to_mongo_doc(self):
e = WorkflowEdge(
id="e1",
workflow_id="w1",
source="n1",
target="n2",
sourceHandle="sh",
targetHandle="th",
)
doc = e.to_mongo_doc()
assert doc == {
"id": "e1",
"workflow_id": "w1",
"source_id": "n1",
"target_id": "n2",
"source_handle": "sh",
"target_handle": "th",
}
# ── WorkflowNodeCreate / WorkflowNode ───────────────────────────────────────
class TestWorkflowNodeCreate:
@pytest.mark.unit
def test_defaults(self):
n = WorkflowNodeCreate(id="n1", workflow_id="w1", type=NodeType.AGENT)
assert n.title == "Node"
assert n.description is None
assert n.position.x == 0.0
assert n.position.y == 0.0
assert n.config == {}
@pytest.mark.unit
def test_position_from_dict(self):
n = WorkflowNodeCreate(
id="n1",
workflow_id="w1",
type=NodeType.START,
position={"x": 100, "y": 200},
)
assert isinstance(n.position, Position)
assert n.position.x == 100
assert n.position.y == 200
@pytest.mark.unit
def test_position_from_position_object(self):
pos = Position(x=5, y=10)
n = WorkflowNodeCreate(
id="n1", workflow_id="w1", type=NodeType.END, position=pos
)
assert n.position is pos
class TestWorkflowNode:
@pytest.mark.unit
def test_objectid_conversion(self):
oid = ObjectId()
n = WorkflowNode(
_id=oid, id="n1", workflow_id="w1", type=NodeType.AGENT
)
assert n.mongo_id == str(oid)
@pytest.mark.unit
def test_to_mongo_doc(self):
n = WorkflowNode(
id="n1",
workflow_id="w1",
type=NodeType.AGENT,
title="My Node",
description="desc",
position={"x": 10, "y": 20},
config={"key": "val"},
)
doc = n.to_mongo_doc()
assert doc == {
"id": "n1",
"workflow_id": "w1",
"type": "agent",
"title": "My Node",
"description": "desc",
"position": {"x": 10.0, "y": 20.0},
"config": {"key": "val"},
}
# ── WorkflowCreate / Workflow ───────────────────────────────────────────────
class TestWorkflowCreate:
@pytest.mark.unit
def test_defaults(self):
w = WorkflowCreate()
assert w.name == "New Workflow"
assert w.description is None
assert w.user is None
@pytest.mark.unit
def test_custom_values(self):
w = WorkflowCreate(name="Test", description="d", user="u1")
assert w.name == "Test"
assert w.description == "d"
assert w.user == "u1"
class TestWorkflow:
@pytest.mark.unit
def test_objectid_conversion(self):
oid = ObjectId()
w = Workflow(_id=oid)
assert w.id == str(oid)
@pytest.mark.unit
def test_string_id(self):
w = Workflow(_id="abc")
assert w.id == "abc"
@pytest.mark.unit
def test_none_id(self):
w = Workflow()
assert w.id is None
@pytest.mark.unit
def test_datetime_defaults(self):
before = datetime.now(timezone.utc)
w = Workflow()
after = datetime.now(timezone.utc)
assert before <= w.created_at <= after
assert before <= w.updated_at <= after
@pytest.mark.unit
def test_to_mongo_doc(self):
w = Workflow(name="W", description="d", user="u1")
doc = w.to_mongo_doc()
assert doc["name"] == "W"
assert doc["description"] == "d"
assert doc["user"] == "u1"
assert "created_at" in doc
assert "updated_at" in doc
# ── WorkflowGraph ───────────────────────────────────────────────────────────
class TestWorkflowGraph:
@pytest.fixture
def graph(self):
workflow = Workflow(name="test")
nodes = [
WorkflowNode(id="start", workflow_id="w1", type=NodeType.START),
WorkflowNode(id="agent1", workflow_id="w1", type=NodeType.AGENT),
WorkflowNode(id="end", workflow_id="w1", type=NodeType.END),
]
edges = [
WorkflowEdge(
id="e1", workflow_id="w1", source="start", target="agent1"
),
WorkflowEdge(
id="e2", workflow_id="w1", source="agent1", target="end"
),
]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
@pytest.mark.unit
def test_get_node_by_id_found(self, graph):
node = graph.get_node_by_id("agent1")
assert node is not None
assert node.id == "agent1"
assert node.type == NodeType.AGENT
@pytest.mark.unit
def test_get_node_by_id_not_found(self, graph):
assert graph.get_node_by_id("nonexistent") is None
@pytest.mark.unit
def test_get_start_node(self, graph):
start = graph.get_start_node()
assert start is not None
assert start.id == "start"
assert start.type == NodeType.START
@pytest.mark.unit
def test_get_start_node_missing(self):
g = WorkflowGraph(
workflow=Workflow(),
nodes=[
WorkflowNode(id="a", workflow_id="w", type=NodeType.AGENT),
],
)
assert g.get_start_node() is None
@pytest.mark.unit
def test_get_outgoing_edges(self, graph):
edges = graph.get_outgoing_edges("start")
assert len(edges) == 1
assert edges[0].target_id == "agent1"
@pytest.mark.unit
def test_get_outgoing_edges_none(self, graph):
edges = graph.get_outgoing_edges("end")
assert edges == []
@pytest.mark.unit
def test_empty_graph(self):
g = WorkflowGraph(workflow=Workflow())
assert g.nodes == []
assert g.edges == []
assert g.get_start_node() is None
# ── NodeExecutionLog ─────────────────────────────────────────────────────────
class TestNodeExecutionLog:
@pytest.mark.unit
def test_required_fields(self):
now = datetime.now(timezone.utc)
log = NodeExecutionLog(
node_id="n1",
node_type="agent",
status=ExecutionStatus.RUNNING,
started_at=now,
)
assert log.node_id == "n1"
assert log.completed_at is None
assert log.error is None
assert log.state_snapshot == {}
@pytest.mark.unit
def test_full_log(self):
started = datetime.now(timezone.utc)
completed = datetime.now(timezone.utc)
log = NodeExecutionLog(
node_id="n1",
node_type="agent",
status=ExecutionStatus.COMPLETED,
started_at=started,
completed_at=completed,
error=None,
state_snapshot={"key": "value"},
)
assert log.completed_at == completed
assert log.state_snapshot == {"key": "value"}
@pytest.mark.unit
def test_extra_forbidden(self):
with pytest.raises(ValidationError):
NodeExecutionLog(
node_id="n",
node_type="agent",
status=ExecutionStatus.PENDING,
started_at=datetime.now(timezone.utc),
extra="no",
)
# ── WorkflowRunCreate / WorkflowRun ─────────────────────────────────────────
class TestWorkflowRunCreate:
@pytest.mark.unit
def test_defaults(self):
r = WorkflowRunCreate(workflow_id="w1")
assert r.workflow_id == "w1"
assert r.inputs == {}
class TestWorkflowRun:
@pytest.mark.unit
def test_defaults(self):
r = WorkflowRun(workflow_id="w1")
assert r.status == ExecutionStatus.PENDING
assert r.inputs == {}
assert r.outputs == {}
assert r.steps == []
assert r.completed_at is None
@pytest.mark.unit
def test_objectid_conversion(self):
oid = ObjectId()
r = WorkflowRun(_id=oid, workflow_id="w1")
assert r.id == str(oid)
@pytest.mark.unit
def test_to_mongo_doc(self):
now = datetime.now(timezone.utc)
log = NodeExecutionLog(
node_id="n1",
node_type="agent",
status=ExecutionStatus.COMPLETED,
started_at=now,
)
r = WorkflowRun(
workflow_id="w1",
status=ExecutionStatus.RUNNING,
inputs={"q": "hello"},
outputs={"a": "world"},
steps=[log],
)
doc = r.to_mongo_doc()
assert doc["workflow_id"] == "w1"
assert doc["status"] == "running"
assert doc["inputs"] == {"q": "hello"}
assert doc["outputs"] == {"a": "world"}
assert len(doc["steps"]) == 1
assert doc["steps"][0]["node_id"] == "n1"
assert doc["completed_at"] is None

View File

View File

@@ -0,0 +1,620 @@
"""Comprehensive tests for application/agents/tools/api_body_serializer.py
Covers: ContentType enum, RequestBodySerializer (JSON, form-urlencoded,
multipart, text/plain, XML, octet-stream, unknown types), encoding rules,
helper methods (_percent_encode, _escape_xml, _dict_to_xml).
"""
import json
import pytest
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
# =====================================================================
# ContentType Enum
# =====================================================================
@pytest.mark.unit
class TestContentTypeEnum:
def test_json_value(self):
assert ContentType.JSON == "application/json"
def test_form_urlencoded_value(self):
assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"
def test_multipart_value(self):
assert ContentType.MULTIPART_FORM_DATA == "multipart/form-data"
def test_text_plain_value(self):
assert ContentType.TEXT_PLAIN == "text/plain"
def test_xml_value(self):
assert ContentType.XML == "application/xml"
def test_octet_stream_value(self):
assert ContentType.OCTET_STREAM == "application/octet-stream"
def test_str_enum(self):
assert isinstance(ContentType.JSON, str)
# =====================================================================
# JSON Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeJson:
def test_basic_json(self):
body, headers = RequestBodySerializer.serialize(
{"key": "value"}, ContentType.JSON
)
assert json.loads(body) == {"key": "value"}
assert headers["Content-Type"] == "application/json"
def test_nested_json(self):
data = {"user": {"name": "Alice", "age": 30}}
body, headers = RequestBodySerializer.serialize(data, ContentType.JSON)
assert json.loads(body) == data
def test_empty_body_returns_none(self):
body, headers = RequestBodySerializer.serialize({}, ContentType.JSON)
assert body is None
assert headers == {}
def test_none_body(self):
body, headers = RequestBodySerializer.serialize(None, ContentType.JSON)
assert body is None
def test_unknown_content_type_falls_back_to_json(self):
body, headers = RequestBodySerializer.serialize(
{"k": "v"}, "application/vnd.custom+json"
)
assert json.loads(body) == {"k": "v"}
def test_content_type_with_charset_suffix(self):
body, headers = RequestBodySerializer.serialize(
{"k": "v"}, "application/json; charset=utf-8"
)
assert json.loads(body) == {"k": "v"}
def test_compact_json_format(self):
body, _ = RequestBodySerializer.serialize(
{"a": 1, "b": 2}, ContentType.JSON
)
# Should use compact separators
assert " " not in body
def test_unicode_json(self):
body, _ = RequestBodySerializer.serialize(
{"name": "Heisenberg"}, ContentType.JSON
)
parsed = json.loads(body)
assert parsed["name"] == "Heisenberg"
# =====================================================================
# Form URL-Encoded Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeFormUrlencoded:
def test_basic_form(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
)
assert "name=Alice" in body
assert "age=30" in body
assert headers["Content-Type"] == "application/x-www-form-urlencoded"
def test_none_values_skipped(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
)
assert "name=Alice" in body
assert "skip" not in body
def test_list_explode_true(self):
body, headers = RequestBodySerializer.serialize(
{"tags": ["a", "b"]},
ContentType.FORM_URLENCODED,
encoding_rules={"tags": {"style": "form", "explode": True}},
)
assert "tags=a" in body
assert "tags=b" in body
def test_list_explode_false(self):
body, headers = RequestBodySerializer.serialize(
{"tags": ["a", "b"]},
ContentType.FORM_URLENCODED,
encoding_rules={"tags": {"style": "form", "explode": False}},
)
assert "tags=" in body
assert "a" in body and "b" in body
def test_dict_value_json_content_type(self):
body, headers = RequestBodySerializer.serialize(
{"metadata": {"key": "val"}},
ContentType.FORM_URLENCODED,
encoding_rules={"metadata": {"contentType": "application/json"}},
)
assert "metadata" in body
def test_dict_value_xml_content_type(self):
body, headers = RequestBodySerializer.serialize(
{"data": {"name": "test"}},
ContentType.FORM_URLENCODED,
encoding_rules={"data": {"contentType": "application/xml"}},
)
assert "data" in body
def test_dict_value_deep_object_explode(self):
body, headers = RequestBodySerializer.serialize(
{"filter": {"status": "active", "type": "doc"}},
ContentType.FORM_URLENCODED,
encoding_rules={
"filter": {"style": "deepObject", "explode": True}
},
)
assert "filter" in body
def test_dict_value_non_exploded(self):
body, headers = RequestBodySerializer.serialize(
{"obj": {"a": "1", "b": "2"}},
ContentType.FORM_URLENCODED,
encoding_rules={"obj": {"style": "form", "explode": False}},
)
assert "obj" in body
def test_default_explode_for_form_style(self):
"""Default explode should be True when style is 'form'."""
body, headers = RequestBodySerializer.serialize(
{"items": ["x", "y"]},
ContentType.FORM_URLENCODED,
encoding_rules={"items": {"style": "form"}},
)
# explode defaults to True for form style => separate params
assert "items=x" in body
assert "items=y" in body
def test_special_characters_encoded(self):
body, _ = RequestBodySerializer.serialize(
{"q": "hello world&more"}, ContentType.FORM_URLENCODED
)
assert "hello" in body
assert "q=" in body
# =====================================================================
# Text Plain Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeTextPlain:
def test_single_value(self):
body, headers = RequestBodySerializer.serialize(
{"message": "hello"}, ContentType.TEXT_PLAIN
)
assert body == "hello"
assert headers["Content-Type"] == "text/plain"
def test_multiple_values(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
)
assert "name: Alice" in body
assert "age: 30" in body
# =====================================================================
# XML Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeXml:
def test_basic_xml(self):
body, headers = RequestBodySerializer.serialize(
{"name": "Alice"}, ContentType.XML
)
assert '<?xml version="1.0"' in body
assert "<name>Alice</name>" in body
assert headers["Content-Type"] == "application/xml"
def test_nested_xml(self):
body, headers = RequestBodySerializer.serialize(
{"user": {"name": "Alice"}}, ContentType.XML
)
assert "<user>" in body
assert "<name>Alice</name>" in body
def test_xml_escapes_special_chars(self):
body, headers = RequestBodySerializer.serialize(
{"data": "<script>alert('xss')</script>"}, ContentType.XML
)
assert "&lt;script&gt;" in body
def test_xml_with_list(self):
body, _ = RequestBodySerializer.serialize(
{"items": [1, 2, 3]}, ContentType.XML
)
assert "<item>1</item>" in body
assert "<item>2</item>" in body
assert "<item>3</item>" in body
# =====================================================================
# Octet Stream Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeOctetStream:
def test_dict_body(self):
body, headers = RequestBodySerializer.serialize(
{"key": "val"}, ContentType.OCTET_STREAM
)
assert isinstance(body, bytes)
assert headers["Content-Type"] == "application/octet-stream"
def test_bytes_body(self):
body, headers = RequestBodySerializer._serialize_octet_stream(b"\x00\x01")
assert body == b"\x00\x01"
assert headers["Content-Type"] == "application/octet-stream"
def test_string_body(self):
body, headers = RequestBodySerializer._serialize_octet_stream("hello")
assert body == b"hello"
# =====================================================================
# Multipart Form Data Serialization
# =====================================================================
@pytest.mark.unit
class TestSerializeMultipartFormData:
def test_basic_multipart(self):
body, headers = RequestBodySerializer.serialize(
{"field": "value"}, ContentType.MULTIPART_FORM_DATA
)
assert isinstance(body, bytes)
assert "multipart/form-data" in headers["Content-Type"]
assert "boundary=" in headers["Content-Type"]
def test_none_values_skipped(self):
body, headers = RequestBodySerializer.serialize(
{"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
)
body_str = body.decode("utf-8", errors="replace")
assert "field" in body_str
assert "empty" not in body_str
def test_multipart_with_bytes(self):
body, headers = RequestBodySerializer.serialize(
{"file": b"\x00\x01\x02"}, ContentType.MULTIPART_FORM_DATA
)
assert isinstance(body, bytes)
def test_multipart_with_dict_json(self):
body, headers = RequestBodySerializer.serialize(
{"meta": {"key": "val"}},
ContentType.MULTIPART_FORM_DATA,
encoding_rules={"meta": {"contentType": "application/json"}},
)
body_str = body.decode("utf-8", errors="replace")
assert "meta" in body_str
assert "application/json" in body_str
def test_multipart_with_dict_xml(self):
body, headers = RequestBodySerializer.serialize(
{"data": {"name": "test"}},
ContentType.MULTIPART_FORM_DATA,
encoding_rules={"data": {"contentType": "application/xml"}},
)
body_str = body.decode("utf-8", errors="replace")
assert "data" in body_str
def test_multipart_octet_stream_bytes(self):
body, headers = RequestBodySerializer.serialize(
{"bin": b"\xff\xfe"},
ContentType.MULTIPART_FORM_DATA,
encoding_rules={"bin": {"contentType": "application/octet-stream"}},
)
body_str = body.decode("utf-8", errors="replace")
assert "bin" in body_str
assert "Content-Transfer-Encoding: base64" in body_str
def test_multipart_string_with_json_content_type(self):
body, headers = RequestBodySerializer.serialize(
{"json_str": '{"a": 1}'},
ContentType.MULTIPART_FORM_DATA,
encoding_rules={"json_str": {"contentType": "application/json"}},
)
body_str = body.decode("utf-8", errors="replace")
assert "json_str" in body_str
def test_multipart_string_with_non_text_content_type(self):
body, headers = RequestBodySerializer.serialize(
{"custom": "data"},
ContentType.MULTIPART_FORM_DATA,
encoding_rules={"custom": {"contentType": "application/custom"}},
)
body_str = body.decode("utf-8", errors="replace")
assert "custom" in body_str
# =====================================================================
# Helper Methods
# =====================================================================
@pytest.mark.unit
class TestHelpers:
def test_percent_encode_space(self):
assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
def test_percent_encode_slash(self):
assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
def test_percent_encode_safe_chars(self):
assert RequestBodySerializer._percent_encode("a/b", safe_chars="/") == "a/b"
def test_escape_xml_ampersand(self):
assert "&amp;" in RequestBodySerializer._escape_xml("&")
def test_escape_xml_lt(self):
assert "&lt;" in RequestBodySerializer._escape_xml("<")
def test_escape_xml_gt(self):
assert "&gt;" in RequestBodySerializer._escape_xml(">")
def test_escape_xml_quote(self):
assert "&quot;" in RequestBodySerializer._escape_xml('"')
def test_escape_xml_apos(self):
assert "&apos;" in RequestBodySerializer._escape_xml("'")
def test_dict_to_xml_list(self):
xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
assert "<item>1</item>" in xml
assert "<item>2</item>" in xml
def test_dict_to_xml_custom_root(self):
xml = RequestBodySerializer._dict_to_xml({"key": "val"}, root_name="data")
assert "<data>" in xml
assert "<key>val</key>" in xml
def test_dict_to_xml_deeply_nested(self):
xml = RequestBodySerializer._dict_to_xml({"a": {"b": {"c": "deep"}}})
assert "<c>deep</c>" in xml
# =====================================================================
# Error Handling
# =====================================================================
@pytest.mark.unit
class TestSerializationErrors:
def test_serialize_raises_on_internal_error(self):
"""Test that serialization errors are wrapped in ValueError."""
# Patch _serialize_json to raise
with pytest.raises(ValueError, match="Failed to serialize"):
RequestBodySerializer.serialize(
{"key": object()}, # object() is not JSON-serializable
ContentType.JSON,
)
# =====================================================================
# Coverage gap tests (lines 145, 155, 159, 162, 166, 271)
# =====================================================================
@pytest.mark.unit
class TestSerializeFormValueGaps:
def test_dict_explode_without_deep_object(self):
"""Cover line 145: dict with explode=True but style != deepObject."""
result = RequestBodySerializer._serialize_form_value(
value={"a": "1", "b": "2"},
style="form",
explode=True,
content_type="application/x-www-form-urlencoded",
key="data",
)
assert isinstance(result, list)
assert len(result) == 2
def test_list_explode_true(self):
"""Cover line 155: list with explode=True."""
result = RequestBodySerializer._serialize_form_value(
value=["x", "y", "z"],
style="form",
explode=True,
content_type="application/x-www-form-urlencoded",
key="items",
)
assert isinstance(result, list)
assert len(result) == 3
def test_list_explode_false(self):
"""Cover line 159: list with explode=False."""
result = RequestBodySerializer._serialize_form_value(
value=["x", "y"],
style="form",
explode=False,
content_type="application/x-www-form-urlencoded",
key="items",
)
assert isinstance(result, str)
# comma-joined and percent-encoded
assert "x" in result
assert "y" in result
def test_primitive_value(self):
"""Cover line 162: primitive string value."""
result = RequestBodySerializer._serialize_form_value(
value="hello world",
style="form",
explode=False,
content_type="application/x-www-form-urlencoded",
key="name",
)
assert isinstance(result, str)
assert "hello" in result
def test_dict_no_explode(self):
"""Cover line 166 area: dict with explode=False returns comma-joined."""
result = RequestBodySerializer._serialize_form_value(
value={"k1": "v1", "k2": "v2"},
style="form",
explode=False,
content_type="application/x-www-form-urlencoded",
key="data",
)
assert isinstance(result, str)
assert "k1" in result
def test_octet_stream_string_input(self):
"""Cover line 271: _serialize_octet_stream with string input."""
body, headers = RequestBodySerializer._serialize_octet_stream("hello bytes")
assert body == b"hello bytes"
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
def test_octet_stream_dict_input(self):
"""Cover: _serialize_octet_stream with dict input (fallback to JSON)."""
body, headers = RequestBodySerializer._serialize_octet_stream({"key": "val"})
assert isinstance(body, bytes)
import json
parsed = json.loads(body.decode("utf-8"))
assert parsed == {"key": "val"}
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 226, 229, 271, 275, 279
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestApiBodySerializerMultipartParts:
def test_multipart_dict_unknown_content_type(self):
"""Cover line 226: dict with unknown content type uses str()."""
from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
result = RequestBodySerializer._create_multipart_part(
name="field",
value={"key": "val"},
content_type="text/csv",
headers_rule={},
)
assert "text/csv" in result
assert "key" in result
def test_multipart_string_json_content_type(self):
"""Cover line 229: string value with application/json content type."""
from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
result = RequestBodySerializer._create_multipart_part(
name="data",
value='{"a": 1}',
content_type="application/json",
headers_rule={},
)
assert "application/json" in result
assert '{"a": 1}' in result
def test_multipart_string_xml_content_type(self):
"""Cover line 229: string value with application/xml content type."""
from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
result = RequestBodySerializer._create_multipart_part(
name="data",
value="<root/>",
content_type="application/xml",
headers_rule={},
)
assert "application/xml" in result
assert "<root/>" in result
def test_multipart_string_unknown_content_type(self):
"""Cover line 229: string with unknown content type falls through."""
from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
result = RequestBodySerializer._create_multipart_part(
name="data",
value="some text",
content_type="application/custom",
headers_rule={},
)
assert "application/custom" in result
assert "some text" in result
@pytest.mark.unit
class TestApiBodySerializerOctetStreamCoverage:
def test_octet_stream_bytes_input(self):
"""Cover line 271: _serialize_octet_stream with bytes input."""
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
body, headers = RequestBodySerializer._serialize_octet_stream(b"raw bytes")
assert body == b"raw bytes"
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
def test_octet_stream_string_input(self):
"""Cover line 275: _serialize_octet_stream with string input."""
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
body, headers = RequestBodySerializer._serialize_octet_stream("text data")
assert body == b"text data"
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value
def test_octet_stream_dict_input(self):
"""Cover line 279: _serialize_octet_stream with dict input (fallback to JSON)."""
import json
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
body, headers = RequestBodySerializer._serialize_octet_stream({"k": "v"})
assert isinstance(body, bytes)
parsed = json.loads(body.decode("utf-8"))
assert parsed == {"k": "v"}
assert headers["Content-Type"] == ContentType.OCTET_STREAM.value

View File

@@ -0,0 +1,516 @@
"""Comprehensive tests for application/agents/tools/api_tool.py
Covers: APITool initialization, all HTTP methods, path param substitution,
SSRF validation, error handling, response parsing, body serialization.
"""
import json
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.agents.tools.api_tool import APITool, DEFAULT_TIMEOUT
@pytest.fixture
def get_tool():
return APITool(
config={
"url": "https://api.example.com/data",
"method": "GET",
"headers": {"Accept": "application/json"},
"query_params": {},
}
)
@pytest.fixture
def post_tool():
return APITool(
config={
"url": "https://api.example.com/items",
"method": "POST",
"headers": {},
"query_params": {},
}
)
# =====================================================================
# Initialization
# =====================================================================
@pytest.mark.unit
class TestAPIToolInit:
def test_default_values(self):
tool = APITool(config={})
assert tool.url == ""
assert tool.method == "GET"
assert tool.headers == {}
assert tool.query_params == {}
assert tool.body_content_type == "application/json"
assert tool.body_encoding_rules == {}
def test_custom_config(self):
tool = APITool(config={
"url": "https://api.test.com",
"method": "POST",
"headers": {"X-Key": "val"},
"query_params": {"page": "1"},
"body_content_type": "application/xml",
"body_encoding_rules": {"field": {"style": "form"}},
})
assert tool.url == "https://api.test.com"
assert tool.method == "POST"
assert tool.headers == {"X-Key": "val"}
assert tool.query_params == {"page": "1"}
assert tool.body_content_type == "application/xml"
def test_default_timeout_constant(self):
assert DEFAULT_TIMEOUT == 90
# =====================================================================
# HTTP Methods
# =====================================================================
@pytest.mark.unit
class TestMakeApiCall:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_successful_get(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"result": "ok"}
mock_resp.content = b'{"result":"ok"}'
mock_get.return_value = mock_resp
result = get_tool.execute_action("any_action")
assert result["status_code"] == 200
assert result["data"] == {"result": "ok"}
assert result["message"] == "API call successful."
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_successful_post(self, mock_post, mock_validate, post_tool):
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"id": 1}
mock_resp.content = b'{"id":1}'
mock_post.return_value = mock_resp
result = post_tool.execute_action("create", name="test")
assert result["status_code"] == 201
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.put")
def test_put_method(self, mock_put, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_put.return_value = mock_resp
result = tool.execute_action("update", name="new")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.delete")
def test_delete_method(self, mock_delete, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
mock_resp = MagicMock()
mock_resp.status_code = 204
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_delete.return_value = mock_resp
result = tool.execute_action("delete")
assert result["status_code"] == 204
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.patch")
def test_patch_method(self, mock_patch, mock_validate):
tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"patched": True}
mock_resp.content = b'{"patched":true}'
mock_patch.return_value = mock_resp
result = tool.execute_action("patch", field="val")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.head")
def test_head_method(self, mock_head, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
mock_head.return_value = mock_resp
result = tool.execute_action("check")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.options")
def test_options_method(self, mock_options, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.content = b''
mock_options.return_value = mock_resp
result = tool.execute_action("options")
assert result["status_code"] == 200
@patch("application.agents.tools.api_tool.validate_url")
def test_unsupported_method(self, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "CUSTOM"})
result = tool.execute_action("any")
assert result["status_code"] is None
assert "Unsupported" in result["message"]
# =====================================================================
# SSRF Validation
# =====================================================================
@pytest.mark.unit
class TestSSRFValidation:
@patch("application.agents.tools.api_tool.validate_url")
def test_ssrf_blocked_initial_url(self, mock_validate, get_tool):
from application.core.url_validation import SSRFError
mock_validate.side_effect = SSRFError("blocked")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_ssrf_blocked_after_param_substitution(self, mock_get, mock_validate):
from application.core.url_validation import SSRFError
tool = APITool(config={
"url": "https://api.example.com/{host}/data",
"method": "GET",
"query_params": {"host": "169.254.169.254"},
})
call_count = [0]
def side_effect(url):
call_count[0] += 1
if call_count[0] == 2:
raise SSRFError("blocked after substitution")
mock_validate.side_effect = side_effect
result = tool.execute_action("any")
assert result["status_code"] is None
assert "URL validation error" in result["message"]
# =====================================================================
# Error Handling
# =====================================================================
@pytest.mark.unit
class TestErrorHandling:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_timeout_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.Timeout()
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "timeout" in result["message"].lower()
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_connection_error(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.ConnectionError("refused")
result = get_tool.execute_action("any")
assert result["status_code"] is None
assert "Connection error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_with_json(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 422
mock_resp.json.return_value = {"error": "invalid_field"}
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 422
assert result["data"] == {"error": "invalid_field"}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_http_error_non_json_body(self, mock_get, mock_validate, get_tool):
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_resp
)
mock_get.return_value = mock_resp
result = get_tool.execute_action("any")
assert result["status_code"] == 404
assert result["data"] == "Not Found"
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_request_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = requests.exceptions.RequestException("something")
result = get_tool.execute_action("any")
assert "API call failed" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_unexpected_exception(self, mock_get, mock_validate, get_tool):
mock_get.side_effect = RuntimeError("unexpected")
result = get_tool.execute_action("any")
assert "Unexpected error" in result["message"]
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_body_serialization_error(self, mock_post, mock_validate):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
"body_content_type": "application/json",
})
with patch(
"application.agents.tools.api_tool.RequestBodySerializer.serialize",
side_effect=ValueError("serialize fail"),
):
result = tool.execute_action("any", key="val")
assert "serialization error" in result["message"].lower()
# =====================================================================
# Path Param Substitution
# =====================================================================
@pytest.mark.unit
class TestPathParamSubstitution:
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_path_params_substituted(self, mock_get, mock_validate):
tool = APITool(config={
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
"method": "GET",
"query_params": {"user_id": "42", "post_id": "7"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
assert "/users/42/posts/7" in called_url
assert "{user_id}" not in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_remaining_query_params_appended(self, mock_get, mock_validate):
tool = APITool(config={
"url": "https://api.example.com/items",
"method": "GET",
"query_params": {"page": "2", "limit": "10"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
assert "page=2" in called_url
assert "limit=10" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.get")
def test_query_params_append_with_existing_query_string(
self, mock_get, mock_validate
):
tool = APITool(config={
"url": "https://api.example.com/items?existing=true",
"method": "GET",
"query_params": {"page": "1"},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = []
mock_resp.content = b'[]'
mock_get.return_value = mock_resp
tool.execute_action("get")
called_url = mock_get.call_args[0][0]
assert "&page=1" in called_url
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_empty_body_no_serialization(self, mock_post, mock_validate):
tool = APITool(config={"url": "https://example.com", "method": "POST"})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
result = tool.execute_action("create")
assert result["status_code"] == 200
# =====================================================================
# Parse Response
# =====================================================================
@pytest.mark.unit
class TestParseResponse:
def test_json_response(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {"key": "val"}
mock_resp.content = b'{"key":"val"}'
result = get_tool._parse_response(mock_resp)
assert result == {"key": "val"}
def test_json_decode_error_falls_back_to_text(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
mock_resp.text = "not valid json"
mock_resp.content = b"not valid json"
result = get_tool._parse_response(mock_resp)
assert result == "not valid json"
def test_text_response(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "text/plain"}
mock_resp.text = "plain text"
mock_resp.content = b"plain text"
result = get_tool._parse_response(mock_resp)
assert result == "plain text"
def test_xml_response(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/xml"}
mock_resp.text = "<root><item>1</item></root>"
mock_resp.content = b"<root><item>1</item></root>"
result = get_tool._parse_response(mock_resp)
assert "<root>" in result
def test_html_response(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.text = "<html><body>Hi</body></html>"
mock_resp.content = b"<html><body>Hi</body></html>"
result = get_tool._parse_response(mock_resp)
assert "<html>" in result
def test_empty_content(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.content = b""
result = get_tool._parse_response(mock_resp)
assert result is None
def test_binary_response(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "application/octet-stream"}
mock_resp.text = "binary_text"
mock_resp.content = b"\x00\x01\x02"
result = get_tool._parse_response(mock_resp)
assert result is not None
def test_text_xml_content_type(self, get_tool):
mock_resp = MagicMock()
mock_resp.headers = {"Content-Type": "text/xml"}
mock_resp.text = "<data/>"
mock_resp.content = b"<data/>"
result = get_tool._parse_response(mock_resp)
assert result == "<data/>"
# =====================================================================
# Metadata
# =====================================================================
@pytest.mark.unit
class TestAPIToolMetadata:
def test_actions_metadata_empty(self, get_tool):
assert get_tool.get_actions_metadata() == []
def test_config_requirements_empty(self, get_tool):
assert get_tool.get_config_requirements() == {}
@patch("application.agents.tools.api_tool.validate_url")
@patch("application.agents.tools.api_tool.requests.post")
def test_content_type_set_for_post_with_no_headers(
self, mock_post, mock_validate
):
tool = APITool(config={
"url": "https://example.com",
"method": "POST",
"headers": {},
})
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
mock_resp.json.return_value = {}
mock_resp.content = b'{}'
mock_post.return_value = mock_resp
tool.execute_action("create")
call_headers = mock_post.call_args[1]["headers"]
assert "Content-Type" in call_headers

View File

@@ -0,0 +1,596 @@
"""Comprehensive tests for application/agents/tools/internal_search.py
Covers: InternalSearchTool (search, list_files, path_filter, error handling,
directory structure loading), build helpers, add_internal_search_tool,
sources_have_directory_structure.
"""
import json
from unittest.mock import MagicMock, Mock, patch
import pytest
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ENTRY,
INTERNAL_TOOL_ID,
InternalSearchTool,
add_internal_search_tool,
build_internal_tool_config,
build_internal_tool_entry,
sources_have_directory_structure,
)
# =====================================================================
# InternalSearchTool - Search
# =====================================================================
def _make_tool(**config_overrides):
config = {"source": {}, "retriever_name": "classic", "chunks": 2}
config.update(config_overrides)
return InternalSearchTool(config)
@pytest.mark.unit
class TestInternalSearchToolSearch:
def test_search_no_query_returns_error(self):
tool = _make_tool()
result = tool.execute_action("search", query="")
assert "required" in result.lower()
def test_search_returns_formatted_docs(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{
"text": "Hello world",
"title": "Doc1",
"source": "test",
"filename": "doc1.md",
},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="hello")
assert "doc1.md" in result
assert "Hello world" in result
assert len(tool.retrieved_docs) == 1
def test_search_no_results(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = []
tool._retriever = mock_retriever
result = tool.execute_action("search", query="nonexistent")
assert "No documents found" in result
def test_search_accumulates_docs(self):
tool = _make_tool()
mock_retriever = Mock()
tool._retriever = mock_retriever
mock_retriever.search.return_value = [
{"text": "A", "title": "D1", "source": "s1"},
]
tool.execute_action("search", query="first")
mock_retriever.search.return_value = [
{"text": "B", "title": "D2", "source": "s2"},
]
tool.execute_action("search", query="second")
assert len(tool.retrieved_docs) == 2
def test_search_deduplicates_docs(self):
tool = _make_tool()
doc = {"text": "Same", "title": "Same", "source": "same"}
mock_retriever = Mock()
mock_retriever.search.return_value = [doc]
tool._retriever = mock_retriever
tool.execute_action("search", query="q1")
tool.execute_action("search", query="q2")
assert len(tool.retrieved_docs) == 1
def test_search_with_path_filter(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "T", "source": "src/main.py", "filename": "main.py"},
{"text": "B", "title": "T", "source": "docs/readme.md", "filename": "readme.md"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="code", path_filter="src/")
assert "main.py" in result
assert "readme.md" not in result
def test_search_path_filter_matches_title(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "src/main.py", "source": "other", "filename": ""},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="code", path_filter="src/main")
assert "src/main.py" in result
def test_search_path_filter_no_match(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "T", "source": "other/file.txt"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="code", path_filter="src/")
assert "No documents found" in result
def test_search_retriever_error(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.side_effect = Exception("Connection error")
tool._retriever = mock_retriever
result = tool.execute_action("search", query="test")
assert "failed" in result.lower() or "error" in result.lower()
def test_unknown_action(self):
tool = _make_tool()
result = tool.execute_action("nonexistent")
assert "Unknown action" in result
def test_search_formats_with_separator(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "D1", "source": "s1", "filename": "f1.md"},
{"text": "B", "title": "D2", "source": "s2", "filename": "f2.md"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="test")
assert "---" in result
assert "[1]" in result
assert "[2]" in result
def test_search_uses_title_when_no_filename(self):
tool = _make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "Content", "title": "My Title", "source": "src", "filename": ""},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="q")
assert "My Title" in result
# =====================================================================
# InternalSearchTool - List Files
# =====================================================================
@pytest.mark.unit
class TestInternalSearchToolListFiles:
def test_list_files_no_structure(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = None
result = tool.execute_action("list_files")
assert "No file structure" in result
def test_list_files_root(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"src": {"main.py": {}},
"README.md": {"type": "md", "token_count": 100},
}
result = tool.execute_action("list_files")
assert "src/" in result
assert "README.md" in result
def test_list_files_nested_path(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"src": {
"utils": {"helper.py": {}},
},
}
result = tool.execute_action("list_files", path="src")
assert "utils/" in result
def test_list_files_invalid_path(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {"src": {}}
result = tool.execute_action("list_files", path="nonexistent")
assert "not found" in result
def test_list_files_empty_directory(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {"empty_dir": {}}
result = tool.execute_action("list_files", path="empty_dir")
assert "(empty)" in result
def test_list_files_file_with_metadata(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"data.csv": {
"type": "text/csv",
"size_bytes": 1024,
"token_count": 500,
},
}
result = tool.execute_action("list_files")
assert "data.csv" in result
assert "500 tokens" in result
assert "text/csv" in result
def test_list_files_file_is_not_directory(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"src": {
"main.py": "plain_file_value",
},
}
result = tool.execute_action("list_files", path="src/main.py")
assert "is a file" in result
def test_list_files_deep_nested_path_with_slashes(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"a": {"b": {"c": {"file.txt": {"type": "text"}}}},
}
result = tool.execute_action("list_files", path="a/b/c")
assert "file.txt" in result
# =====================================================================
# Count Files Helper
# =====================================================================
@pytest.mark.unit
class TestCountFiles:
def test_count_files_nested(self):
tool = InternalSearchTool({"source": {}})
node = {
"file1.txt": {"type": "text"},
"dir": {
"file2.txt": {"type": "text"},
"file3.txt": "plain_value",
},
}
assert tool._count_files(node) == 3
def test_count_files_empty(self):
tool = InternalSearchTool({"source": {}})
assert tool._count_files({}) == 0
# =====================================================================
# Directory Structure Loading
# =====================================================================
@pytest.mark.unit
class TestGetDirectoryStructure:
def test_loads_from_mongo(self):
tool = InternalSearchTool({
"source": {"active_docs": ["507f1f77bcf86cd799439011"]},
})
mock_collection = MagicMock()
mock_collection.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "test_source",
"directory_structure": {"src": {"main.py": {}}},
}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = tool._get_directory_structure()
assert result is not None
assert "src" in result
def test_returns_none_without_active_docs(self):
tool = InternalSearchTool({"source": {}})
result = tool._get_directory_structure()
assert result is None
assert tool._dir_structure_loaded is True
def test_caches_after_first_load(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {"cached": True}
result = tool._get_directory_structure()
assert result == {"cached": True}
def test_handles_json_string_structure(self):
tool = InternalSearchTool({
"source": {"active_docs": ["507f1f77bcf86cd799439011"]},
})
mock_collection = MagicMock()
mock_collection.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "test_source",
"directory_structure": json.dumps({"src": {"app.py": {}}}),
}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = tool._get_directory_structure()
assert result is not None
assert "src" in result
def test_handles_string_active_docs(self):
tool = InternalSearchTool({
"source": {"active_docs": "507f1f77bcf86cd799439011"},
})
mock_collection = MagicMock()
mock_collection.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"directory_structure": {"dir": {}},
}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = tool._get_directory_structure()
assert result is not None
def test_merges_multiple_sources(self):
tool = InternalSearchTool({
"source": {
"active_docs": [
"507f1f77bcf86cd799439011",
"507f1f77bcf86cd799439012",
],
},
})
mock_collection = MagicMock()
mock_collection.find_one.side_effect = [
{"name": "src1", "directory_structure": {"a": {}}},
{"name": "src2", "directory_structure": {"b": {}}},
]
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = tool._get_directory_structure()
assert "src1" in result
assert "src2" in result
# =====================================================================
# Metadata
# =====================================================================
@pytest.mark.unit
class TestInternalSearchToolMetadata:
def test_actions_without_directory_structure(self):
tool = InternalSearchTool({"has_directory_structure": False})
meta = tool.get_actions_metadata()
action_names = [a["name"] for a in meta]
assert "search" in action_names
assert "list_files" not in action_names
search = meta[0]
assert "path_filter" not in search["parameters"]["properties"]
def test_actions_with_directory_structure(self):
tool = InternalSearchTool({"has_directory_structure": True})
meta = tool.get_actions_metadata()
action_names = [a["name"] for a in meta]
assert "search" in action_names
assert "list_files" in action_names
search = next(a for a in meta if a["name"] == "search")
assert "path_filter" in search["parameters"]["properties"]
def test_config_requirements_empty(self):
tool = InternalSearchTool({})
assert tool.get_config_requirements() == {}
# =====================================================================
# Build Helpers
# =====================================================================
@pytest.mark.unit
class TestBuildHelpers:
def test_build_entry_without_directory_structure(self):
entry = build_internal_tool_entry(has_directory_structure=False)
assert entry["name"] == "internal_search"
action_names = [a["name"] for a in entry["actions"]]
assert "search" in action_names
assert "list_files" not in action_names
assert entry["actions"][0].get("active") is True
def test_build_entry_with_directory_structure(self):
entry = build_internal_tool_entry(has_directory_structure=True)
action_names = [a["name"] for a in entry["actions"]]
assert "list_files" in action_names
# path_filter should be in search params
search_action = next(a for a in entry["actions"] if a["name"] == "search")
assert "path_filter" in search_action["parameters"]["properties"]
def test_build_config(self):
config = build_internal_tool_config(
source={"active_docs": ["abc"]},
retriever_name="semantic",
chunks=4,
)
assert config["source"] == {"active_docs": ["abc"]}
assert config["retriever_name"] == "semantic"
assert config["chunks"] == 4
def test_build_config_defaults(self):
config = build_internal_tool_config(source={"active_docs": ["abc"]})
assert config["retriever_name"] == "classic"
assert config["chunks"] == 2
assert config["doc_token_limit"] == 50000
def test_internal_tool_id(self):
assert INTERNAL_TOOL_ID == "internal"
def test_internal_tool_entry_constant(self):
assert INTERNAL_TOOL_ENTRY["name"] == "internal_search"
def test_add_internal_search_tool_with_sources(self):
tools_dict = {}
retriever_config = {
"source": {"active_docs": ["abc"]},
"retriever_name": "classic",
"chunks": 2,
"model_id": "gpt-4",
"llm_name": "openai",
"api_key": "key",
}
with patch(
"application.agents.tools.internal_search.sources_have_directory_structure",
return_value=False,
):
add_internal_search_tool(tools_dict, retriever_config)
assert INTERNAL_TOOL_ID in tools_dict
assert tools_dict[INTERNAL_TOOL_ID]["name"] == "internal_search"
assert "config" in tools_dict[INTERNAL_TOOL_ID]
def test_add_internal_search_tool_no_sources(self):
tools_dict = {}
retriever_config = {"source": {}}
add_internal_search_tool(tools_dict, retriever_config)
assert INTERNAL_TOOL_ID not in tools_dict
def test_add_internal_search_tool_empty_config(self):
tools_dict = {}
add_internal_search_tool(tools_dict, {})
assert INTERNAL_TOOL_ID not in tools_dict
# =====================================================================
# sources_have_directory_structure
# =====================================================================
@pytest.mark.unit
class TestSourcesHaveDirectoryStructure:
def test_no_active_docs(self):
assert sources_have_directory_structure({}) is False
def test_with_directory_structure(self):
mock_collection = MagicMock()
mock_collection.find_one.return_value = {
"directory_structure": {"src": {}},
}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = sources_have_directory_structure(
{"active_docs": ["507f1f77bcf86cd799439011"]}
)
assert result is True
def test_without_directory_structure(self):
mock_collection = MagicMock()
mock_collection.find_one.return_value = {"directory_structure": None}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = sources_have_directory_structure(
{"active_docs": ["507f1f77bcf86cd799439011"]}
)
assert result is False
def test_handles_exception_gracefully(self):
with patch(
"application.core.mongo_db.MongoDB.get_client",
side_effect=Exception("DB down"),
):
result = sources_have_directory_structure(
{"active_docs": ["507f1f77bcf86cd799439011"]}
)
assert result is False
def test_string_active_docs(self):
mock_collection = MagicMock()
mock_collection.find_one.return_value = {
"directory_structure": {"a": {}},
}
with patch("application.core.mongo_db.MongoDB") as mock_mongo:
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=mock_db)
mock_mongo.get_client.return_value = mock_client
result = sources_have_directory_structure(
{"active_docs": "507f1f77bcf86cd799439011"}
)
assert result is True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,548 @@
"""Comprehensive tests for application/agents/tools/memory.py
Covers: MemoryTool initialization, path validation, all actions
(view, create, str_replace, insert, delete, rename), directory operations,
error handling, and metadata.
"""
import mongomock
import pytest
def _get_settings():
from application.core.settings import settings
return settings
@pytest.fixture
def mock_memory_db(monkeypatch):
"""Set up a mongomock-based memory collection."""
settings = _get_settings()
mock_client = mongomock.MongoClient()
mock_db = mock_client[settings.MONGO_DB_NAME]
def get_mock_client():
return {settings.MONGO_DB_NAME: mock_db}
monkeypatch.setattr(
"application.core.mongo_db.MongoDB.get_client", get_mock_client
)
return mock_db
@pytest.fixture
def memory_tool(mock_memory_db):
from application.agents.tools.memory import MemoryTool
return MemoryTool(
tool_config={"tool_id": "test_tool_001"},
user_id="test_user",
)
# =====================================================================
# Initialization
# =====================================================================
@pytest.mark.unit
class TestMemoryToolInit:
def test_init_with_config(self, mock_memory_db):
from application.agents.tools.memory import MemoryTool
tool = MemoryTool(
tool_config={"tool_id": "custom_id"}, user_id="user1"
)
assert tool.tool_id == "custom_id"
assert tool.user_id == "user1"
def test_init_fallback_to_user_id(self, mock_memory_db):
from application.agents.tools.memory import MemoryTool
tool = MemoryTool(tool_config={}, user_id="user1")
assert tool.tool_id == "default_user1"
def test_init_no_user_no_config(self, mock_memory_db):
from application.agents.tools.memory import MemoryTool
tool = MemoryTool()
assert tool.tool_id is not None # UUID fallback
assert tool.user_id is None
# =====================================================================
# Path Validation
# =====================================================================
@pytest.mark.unit
class TestPathValidation:
def test_valid_path(self, memory_tool):
assert memory_tool._validate_path("/notes.txt") == "/notes.txt"
def test_adds_leading_slash(self, memory_tool):
assert memory_tool._validate_path("notes.txt") == "/notes.txt"
def test_empty_path_returns_none(self, memory_tool):
assert memory_tool._validate_path("") is None
def test_double_dots_rejected(self, memory_tool):
assert memory_tool._validate_path("/../../etc/passwd") is None
def test_double_slash_rejected(self, memory_tool):
assert memory_tool._validate_path("//path") is None
def test_preserves_trailing_slash(self, memory_tool):
result = memory_tool._validate_path("/project/")
assert result.endswith("/")
def test_root_path(self, memory_tool):
assert memory_tool._validate_path("/") == "/"
def test_whitespace_stripped(self, memory_tool):
result = memory_tool._validate_path(" /notes.txt ")
assert result == "/notes.txt"
# =====================================================================
# Execute Action - No User
# =====================================================================
@pytest.mark.unit
class TestNoUser:
def test_requires_user_id(self, mock_memory_db):
from application.agents.tools.memory import MemoryTool
tool = MemoryTool(tool_config={"tool_id": "t"}, user_id=None)
result = tool.execute_action("view", path="/")
assert "Error" in result
assert "user_id" in result
def test_unknown_action(self, memory_tool):
result = memory_tool.execute_action("fly")
assert "Unknown action" in result
# =====================================================================
# View Action
# =====================================================================
@pytest.mark.unit
class TestViewAction:
def test_view_empty_directory(self, memory_tool):
result = memory_tool.execute_action("view", path="/")
assert "Directory: /" in result
assert "(empty)" in result
def test_view_directory_with_files(self, memory_tool):
memory_tool.execute_action("create", path="/notes.txt", file_text="content")
memory_tool.execute_action("create", path="/todo.txt", file_text="tasks")
result = memory_tool.execute_action("view", path="/")
assert "notes.txt" in result
assert "todo.txt" in result
def test_view_file_content(self, memory_tool):
memory_tool.execute_action("create", path="/hello.txt", file_text="Hello World")
result = memory_tool.execute_action("view", path="/hello.txt")
assert "Hello World" in result
def test_view_nonexistent_file(self, memory_tool):
result = memory_tool.execute_action("view", path="/missing.txt")
assert "Error" in result
assert "not found" in result.lower()
def test_view_file_with_range(self, memory_tool):
memory_tool.execute_action(
"create", path="/lines.txt", file_text="line1\nline2\nline3\nline4"
)
result = memory_tool.execute_action(
"view", path="/lines.txt", view_range=[2, 3]
)
assert "line2" in result
assert "line3" in result
def test_view_file_range_out_of_bounds(self, memory_tool):
memory_tool.execute_action("create", path="/short.txt", file_text="only")
result = memory_tool.execute_action(
"view", path="/short.txt", view_range=[100, 200]
)
assert "out of bounds" in result.lower()
def test_view_invalid_path(self, memory_tool):
result = memory_tool.execute_action("view", path="")
assert "Error" in result
def test_view_subdirectory(self, memory_tool):
memory_tool.execute_action(
"create", path="/project/src/main.py", file_text="code"
)
result = memory_tool.execute_action("view", path="/project/")
assert "src/main.py" in result
# =====================================================================
# Create Action
# =====================================================================
@pytest.mark.unit
class TestCreateAction:
def test_create_file(self, memory_tool):
result = memory_tool.execute_action(
"create", path="/test.txt", file_text="content"
)
assert "File created" in result
content = memory_tool.execute_action("view", path="/test.txt")
assert "content" in content
def test_overwrite_file(self, memory_tool):
memory_tool.execute_action("create", path="/test.txt", file_text="old")
memory_tool.execute_action("create", path="/test.txt", file_text="new")
content = memory_tool.execute_action("view", path="/test.txt")
assert "new" in content
def test_create_at_directory_path(self, memory_tool):
result = memory_tool.execute_action("create", path="/dir/", file_text="text")
assert "Error" in result
assert "directory path" in result.lower()
def test_create_invalid_path(self, memory_tool):
result = memory_tool.execute_action("create", path="", file_text="text")
assert "Error" in result
def test_create_nested_path(self, memory_tool):
result = memory_tool.execute_action(
"create", path="/a/b/c/file.txt", file_text="deep"
)
assert "File created" in result
# =====================================================================
# String Replace Action
# =====================================================================
@pytest.mark.unit
class TestStrReplaceAction:
def test_replace_text(self, memory_tool):
memory_tool.execute_action(
"create", path="/doc.txt", file_text="Hello World"
)
result = memory_tool.execute_action(
"str_replace", path="/doc.txt", old_str="Hello", new_str="Hi"
)
assert "File updated" in result
content = memory_tool.execute_action("view", path="/doc.txt")
assert "Hi World" in content
def test_replace_not_found(self, memory_tool):
memory_tool.execute_action("create", path="/doc.txt", file_text="Hello")
result = memory_tool.execute_action(
"str_replace", path="/doc.txt", old_str="Missing", new_str="X"
)
assert "not found" in result.lower()
def test_replace_empty_old_str(self, memory_tool):
memory_tool.execute_action("create", path="/doc.txt", file_text="Hello")
result = memory_tool.execute_action(
"str_replace", path="/doc.txt", old_str="", new_str="X"
)
assert "Error" in result
def test_replace_file_not_found(self, memory_tool):
result = memory_tool.execute_action(
"str_replace", path="/missing.txt", old_str="a", new_str="b"
)
assert "not found" in result.lower()
def test_replace_case_insensitive(self, memory_tool):
memory_tool.execute_action(
"create", path="/doc.txt", file_text="Hello World"
)
result = memory_tool.execute_action(
"str_replace", path="/doc.txt", old_str="hello", new_str="Hi"
)
assert "File updated" in result
# =====================================================================
# Insert Action
# =====================================================================
@pytest.mark.unit
class TestInsertAction:
def test_insert_text(self, memory_tool):
memory_tool.execute_action(
"create", path="/doc.txt", file_text="line1\nline2"
)
result = memory_tool.execute_action(
"insert", path="/doc.txt", insert_line=2, insert_text="inserted"
)
assert "inserted" in result.lower()
content = memory_tool.execute_action("view", path="/doc.txt")
assert "inserted" in content
def test_insert_empty_text(self, memory_tool):
memory_tool.execute_action("create", path="/doc.txt", file_text="line1")
result = memory_tool.execute_action(
"insert", path="/doc.txt", insert_line=1, insert_text=""
)
assert "Error" in result
def test_insert_file_not_found(self, memory_tool):
result = memory_tool.execute_action(
"insert", path="/missing.txt", insert_line=1, insert_text="text"
)
assert "not found" in result.lower()
def test_insert_invalid_line_number(self, memory_tool):
memory_tool.execute_action("create", path="/doc.txt", file_text="line1")
result = memory_tool.execute_action(
"insert", path="/doc.txt", insert_line=-5, insert_text="text"
)
assert "Error" in result
# =====================================================================
# Delete Action
# =====================================================================
@pytest.mark.unit
class TestDeleteAction:
def test_delete_file(self, memory_tool):
memory_tool.execute_action("create", path="/test.txt", file_text="data")
result = memory_tool.execute_action("delete", path="/test.txt")
assert "Deleted" in result
content = memory_tool.execute_action("view", path="/test.txt")
assert "not found" in content.lower()
def test_delete_nonexistent_file(self, memory_tool):
result = memory_tool.execute_action("delete", path="/missing.txt")
assert "not found" in result.lower()
def test_delete_root_clears_all(self, memory_tool):
memory_tool.execute_action("create", path="/a.txt", file_text="a")
memory_tool.execute_action("create", path="/b.txt", file_text="b")
result = memory_tool.execute_action("delete", path="/")
assert "Deleted" in result
assert "2" in result
def test_delete_directory(self, memory_tool):
memory_tool.execute_action("create", path="/dir/f1.txt", file_text="1")
memory_tool.execute_action("create", path="/dir/f2.txt", file_text="2")
result = memory_tool.execute_action("delete", path="/dir/")
assert "Deleted" in result
def test_delete_directory_without_trailing_slash(self, memory_tool):
memory_tool.execute_action("create", path="/dir/f1.txt", file_text="1")
result = memory_tool.execute_action("delete", path="/dir")
assert "Deleted" in result
def test_delete_invalid_path(self, memory_tool):
result = memory_tool.execute_action("delete", path="")
assert "Error" in result
# =====================================================================
# Rename Action
# =====================================================================
@pytest.mark.unit
class TestRenameAction:
def test_rename_file(self, memory_tool):
memory_tool.execute_action("create", path="/old.txt", file_text="data")
result = memory_tool.execute_action(
"rename", old_path="/old.txt", new_path="/new.txt"
)
assert "Renamed" in result
content = memory_tool.execute_action("view", path="/new.txt")
assert "data" in content
def test_rename_file_not_found(self, memory_tool):
result = memory_tool.execute_action(
"rename", old_path="/missing.txt", new_path="/new.txt"
)
assert "not found" in result.lower()
def test_rename_target_exists(self, memory_tool):
memory_tool.execute_action("create", path="/a.txt", file_text="a")
memory_tool.execute_action("create", path="/b.txt", file_text="b")
result = memory_tool.execute_action(
"rename", old_path="/a.txt", new_path="/b.txt"
)
assert "already exists" in result.lower()
def test_rename_root_rejected(self, memory_tool):
result = memory_tool.execute_action(
"rename", old_path="/", new_path="/new/"
)
assert "Cannot rename root" in result
def test_rename_directory(self, memory_tool):
memory_tool.execute_action("create", path="/old/f.txt", file_text="data")
result = memory_tool.execute_action(
"rename", old_path="/old/", new_path="/new/"
)
assert "Renamed" in result
content = memory_tool.execute_action("view", path="/new/f.txt")
assert "data" in content
def test_rename_directory_not_found(self, memory_tool):
result = memory_tool.execute_action(
"rename", old_path="/missing/", new_path="/new/"
)
assert "not found" in result.lower()
def test_rename_invalid_path(self, memory_tool):
result = memory_tool.execute_action(
"rename", old_path="", new_path="/new.txt"
)
assert "Error" in result
# =====================================================================
# Metadata
# =====================================================================
@pytest.mark.unit
class TestMemoryToolMetadata:
def test_actions_metadata(self, memory_tool):
meta = memory_tool.get_actions_metadata()
action_names = [a["name"] for a in meta]
assert "view" in action_names
assert "create" in action_names
assert "str_replace" in action_names
assert "insert" in action_names
assert "delete" in action_names
assert "rename" in action_names
assert len(meta) == 6
def test_config_requirements(self, memory_tool):
assert memory_tool.get_config_requirements() == {}
# =====================================================================
# Coverage gap tests (lines 254, 257, 271, 275, 279)
# =====================================================================
@pytest.mark.unit
class TestMemoryToolValidatePath:
def test_validate_path_with_traversal_returns_none(self, memory_tool):
"""Cover line 244-245: path with .. returns None."""
result = memory_tool._validate_path("/some/../etc/passwd")
assert result is None
def test_validate_path_with_directory_trailing_slash(self, memory_tool):
"""Cover line 257-258: trailing slash is preserved."""
result = memory_tool._validate_path("/some/dir/")
assert result is not None
assert result.endswith("/")
def test_validate_path_empty_returns_none(self, memory_tool):
"""Cover: empty path returns None."""
result = memory_tool._validate_path("")
assert result is None
def test_validate_path_none_returns_none(self, memory_tool):
"""Cover: None path returns None."""
result = memory_tool._validate_path(None)
assert result is None
def test_validate_path_relative_gets_prefixed(self, memory_tool):
"""Cover line 241: relative path gets / prepended."""
result = memory_tool._validate_path("relative/path")
assert result == "/relative/path"
def test_validate_path_double_slash_returns_none(self, memory_tool):
"""Cover line 244: double slash returns None."""
result = memory_tool._validate_path("//etc/passwd")
assert result is None
@pytest.mark.unit
class TestMemoryToolViewDirectory:
def test_view_with_directory_path(self, memory_tool):
"""Cover line 271-275: _view with directory path delegates to _view_directory."""
result = memory_tool._view("/")
assert isinstance(result, str)
def test_view_with_file_path(self, memory_tool):
"""Cover line 279: _view with non-directory path delegates to _view_file."""
# _view on a non-existent file path still exercises the _view_file path
result = memory_tool._view("/nonexistent.txt")
assert "Error" in result or "not found" in result.lower()
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 254, 257, 271, 275, 279
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestMemoryToolValidatePathCoverage:
def test_validate_path_traversal_returns_none(self, memory_tool):
"""Cover line 244: path with directory traversal returns None."""
result = memory_tool._validate_path("/etc/../passwd")
assert result is None
def test_validate_path_directory_appends_slash(self, memory_tool):
"""Cover line 257: path ending with / preserves trailing slash."""
result = memory_tool._validate_path("/some/dir/")
assert result is not None
assert result.endswith("/")
def test_validate_path_root_directory(self, memory_tool):
"""Cover line 257: root directory preserved as-is."""
result = memory_tool._validate_path("/")
assert result == "/"
@pytest.mark.unit
class TestMemoryToolViewCoverage:
def test_view_invalid_path_returns_error(self, memory_tool):
"""Cover line 271: _view with invalid path returns error."""
result = memory_tool._view("//bad//path")
assert "Error" in result
def test_view_root_directory(self, memory_tool):
"""Cover line 275: _view with root directory."""
result = memory_tool._view("/")
assert isinstance(result, str)
def test_view_file_path(self, memory_tool):
"""Cover line 279: _view with file path delegates to _view_file."""
result = memory_tool._view("/some/file.txt")
assert isinstance(result, str)

View File

@@ -0,0 +1,235 @@
"""Tests for application/api/answer/routes/answer.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.fixture
def mock_stream_processor():
"""Create a mock StreamProcessor."""
with patch(
"application.api.answer.routes.answer.StreamProcessor"
) as MockProcessor:
processor = MagicMock()
processor.decoded_token = {"sub": "test_user"}
processor.conversation_id = str(ObjectId())
processor.agent_config = {}
processor.agent_id = str(ObjectId())
processor.is_shared_usage = False
processor.shared_token = None
processor.model_id = "gpt-4"
processor.build_agent.return_value = MagicMock()
MockProcessor.return_value = processor
yield processor
@pytest.fixture
def answer_client(mock_mongo_db, flask_app):
"""Create a test client with the answer route registered."""
from flask_restx import Api
from application.api.answer.routes.answer import answer_ns
api = Api(flask_app)
api.add_namespace(answer_ns)
flask_app.config["TESTING"] = True
return flask_app.test_client()
@pytest.mark.unit
class TestAnswerResourcePost:
def test_missing_question_returns_400(self, answer_client, mock_stream_processor):
resp = answer_client.post(
"/api/answer",
data=json.dumps({}),
content_type="application/json",
)
assert resp.status_code == 400
def test_successful_answer(self, answer_client, mock_stream_processor):
conv_id = str(ObjectId())
with patch.object(
mock_stream_processor.build_agent.return_value,
"gen",
return_value=iter([]),
):
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.complete_stream",
return_value=iter(
[
f'data: {json.dumps({"type": "answer", "answer": "Hello"})}\n\n',
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(conv_id, "Hello", [], [], "", None),
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "What is Python?"}),
content_type="application/json",
)
assert resp.status_code == 200
data = resp.get_json()
assert data["answer"] == "Hello"
assert data["conversation_id"] == conv_id
def test_unauthorized_returns_401(self, answer_client, mock_stream_processor):
mock_stream_processor.decoded_token = None
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 401
assert resp.get_json()["error"] == "Unauthorized"
def test_usage_exceeded_returns_error(self, answer_client, mock_stream_processor):
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
) as mock_check:
with flask_app_context(answer_client):
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 429
def test_stream_error_returns_400(self, answer_client, mock_stream_processor):
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.complete_stream",
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(None, None, None, None, None, "Stream error"),
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 400
assert resp.get_json()["error"] == "Stream error"
def test_exception_returns_500(self, answer_client, mock_stream_processor):
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.complete_stream",
side_effect=RuntimeError("unexpected"),
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 500
assert "error" in resp.get_json()
def test_structured_info_merged_into_result(
self, answer_client, mock_stream_processor
):
conv_id = str(ObjectId())
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.complete_stream",
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
'{"key": "val"}',
[],
[],
"",
None,
{"structured": True, "schema": {"type": "object"}},
),
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 200
data = resp.get_json()
assert data["structured"] is True
assert data["schema"] == {"type": "object"}
def test_result_contains_all_expected_fields(
self, answer_client, mock_stream_processor
):
conv_id = str(ObjectId())
with patch(
"application.api.answer.routes.answer.AnswerResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.answer.AnswerResource.complete_stream",
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
"answer text",
[{"title": "src"}],
[{"tool": "t"}],
"thinking...",
None,
),
):
resp = answer_client.post(
"/api/answer",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
data = resp.get_json()
assert data["conversation_id"] == conv_id
assert data["answer"] == "answer text"
assert data["sources"] == [{"title": "src"}]
assert data["tool_calls"] == [{"tool": "t"}]
assert data["thought"] == "thinking..."
def flask_app_context(client):
"""Helper to get app context from test client."""
return client.application.app_context()

View File

@@ -0,0 +1,195 @@
"""Tests for application/api/answer/routes/stream.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.fixture
def mock_stream_processor():
"""Create a mock StreamProcessor for stream tests."""
with patch(
"application.api.answer.routes.stream.StreamProcessor"
) as MockProcessor:
processor = MagicMock()
processor.decoded_token = {"sub": "test_user"}
processor.conversation_id = str(ObjectId())
processor.agent_config = {}
processor.agent_id = str(ObjectId())
processor.is_shared_usage = False
processor.shared_token = None
processor.model_id = "gpt-4"
processor.build_agent.return_value = MagicMock()
MockProcessor.return_value = processor
yield processor
@pytest.fixture
def stream_client(mock_mongo_db, flask_app):
"""Create a test client with the stream route registered."""
from flask_restx import Api
from application.api.answer.routes.stream import answer_ns
api = Api(flask_app)
api.add_namespace(answer_ns)
flask_app.config["TESTING"] = True
return flask_app.test_client()
@pytest.mark.unit
class TestStreamResourcePost:
def test_missing_question_returns_400(self, stream_client, mock_stream_processor):
resp = stream_client.post(
"/stream",
data=json.dumps({}),
content_type="application/json",
)
assert resp.status_code == 400
def test_successful_stream(self, stream_client, mock_stream_processor):
def fake_stream(*args, **kwargs):
yield f'data: {json.dumps({"type": "answer", "answer": "Hi"})}\n\n'
yield f'data: {json.dumps({"type": "end"})}\n\n'
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.stream.StreamResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.stream.StreamResource.complete_stream",
side_effect=fake_stream,
):
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "What is Python?"}),
content_type="application/json",
)
assert resp.status_code == 200
assert "text/event-stream" in resp.content_type
data = resp.get_data(as_text=True)
assert '"type": "answer"' in data
assert '"answer": "Hi"' in data
def test_unauthorized_returns_401_stream(
self, stream_client, mock_stream_processor
):
mock_stream_processor.decoded_token = None
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
):
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 401
assert "text/event-stream" in resp.content_type
data = resp.get_data(as_text=True)
assert "Unauthorized" in data
def test_usage_exceeded_returns_error(
self, stream_client, mock_stream_processor
):
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.stream.StreamResource.check_usage",
) as mock_check:
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 429
def test_value_error_returns_400_stream(
self, stream_client, mock_stream_processor
):
mock_stream_processor.build_agent.side_effect = ValueError("bad data")
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
):
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 400
assert "text/event-stream" in resp.content_type
data = resp.get_data(as_text=True)
assert "Malformed request body" in data
def test_general_exception_returns_400_stream(
self, stream_client, mock_stream_processor
):
mock_stream_processor.build_agent.side_effect = RuntimeError("crash")
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
):
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "test"}),
content_type="application/json",
)
assert resp.status_code == 400
assert "text/event-stream" in resp.content_type
data = resp.get_data(as_text=True)
assert "Unknown error occurred" in data
def test_index_in_data_requires_conversation_id(
self, stream_client, mock_stream_processor
):
"""When 'index' is present, validate_request is called with require_conversation_id=True."""
resp = stream_client.post(
"/stream",
data=json.dumps({"question": "test", "index": 0}),
content_type="application/json",
)
# Should get 400 since conversation_id is missing
assert resp.status_code == 400
def test_stream_passes_attachments_and_index(
self, stream_client, mock_stream_processor
):
"""Verify attachments and index params are forwarded to complete_stream."""
def fake_stream(*args, **kwargs):
yield f'data: {json.dumps({"type": "end"})}\n\n'
conv_id = str(ObjectId())
with patch(
"application.api.answer.routes.stream.StreamResource.validate_request",
return_value=None,
), patch(
"application.api.answer.routes.stream.StreamResource.check_usage",
return_value=None,
), patch(
"application.api.answer.routes.stream.StreamResource.complete_stream",
side_effect=fake_stream,
) as mock_complete:
resp = stream_client.post(
"/stream",
data=json.dumps(
{
"question": "test",
"conversation_id": conv_id,
"index": 3,
"attachments": ["att1", "att2"],
}
),
content_type="application/json",
)
assert resp.status_code == 200
call_kwargs = mock_complete.call_args
assert call_kwargs.kwargs.get("index") == 3
assert call_kwargs.kwargs.get("attachment_ids") == ["att1", "att2"]

View File

@@ -0,0 +1,303 @@
"""Tests for application/api/answer/services/compression/message_builder.py"""
import pytest
from application.api.answer.services.compression.message_builder import MessageBuilder
@pytest.mark.unit
class TestBuildFromCompressedContext:
def test_no_compression_returns_system_only(self):
messages = MessageBuilder.build_from_compressed_context(
system_prompt="You are helpful.",
compressed_summary=None,
recent_queries=[],
)
assert len(messages) == 1
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "You are helpful."
def test_with_recent_queries_no_compression(self):
queries = [
{"prompt": "Hello", "response": "Hi there!"},
{"prompt": "How are you?", "response": "I'm fine."},
]
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System prompt",
compressed_summary=None,
recent_queries=queries,
)
# system + 2 * (user + assistant) = 5
assert len(messages) == 5
assert messages[1] == {"role": "user", "content": "Hello"}
assert messages[2] == {"role": "assistant", "content": "Hi there!"}
assert messages[3] == {"role": "user", "content": "How are you?"}
assert messages[4] == {"role": "assistant", "content": "I'm fine."}
def test_with_compressed_summary_appended_to_system(self):
messages = MessageBuilder.build_from_compressed_context(
system_prompt="You are helpful.",
compressed_summary="Previous: user asked about Python.",
recent_queries=[{"prompt": "More?", "response": "Sure."}],
)
system_content = messages[0]["content"]
assert "This session is being continued" in system_content
assert "Previous: user asked about Python." in system_content
def test_mid_execution_context_type(self):
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System",
compressed_summary="Summary here",
recent_queries=[{"prompt": "q", "response": "r"}],
context_type="mid_execution",
)
system_content = messages[0]["content"]
assert "Context window limit reached" in system_content
def test_include_tool_calls(self):
queries = [
{
"prompt": "Search for X",
"response": "Found X",
"tool_calls": [
{
"call_id": "call-1",
"action_name": "search",
"arguments": {"q": "X"},
"result": "X found",
}
],
}
]
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System",
compressed_summary=None,
recent_queries=queries,
include_tool_calls=True,
)
# system + user + assistant + tool_call_assistant + tool_response = 5
assert len(messages) == 5
assert messages[3]["role"] == "assistant"
assert "function_call" in messages[3]["content"][0]
assert messages[4]["role"] == "tool"
assert "function_response" in messages[4]["content"][0]
def test_tool_calls_not_included_by_default(self):
queries = [
{
"prompt": "Search",
"response": "Found",
"tool_calls": [
{
"call_id": "c1",
"action_name": "search",
"arguments": {},
"result": "ok",
}
],
}
]
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System",
compressed_summary=None,
recent_queries=queries,
include_tool_calls=False,
)
# system + user + assistant = 3 (no tool messages)
assert len(messages) == 3
def test_tool_call_without_call_id_generates_uuid(self):
queries = [
{
"prompt": "q",
"response": "r",
"tool_calls": [
{
"action_name": "act",
"arguments": {},
"result": "res",
}
],
}
]
messages = MessageBuilder.build_from_compressed_context(
system_prompt="S",
compressed_summary=None,
recent_queries=queries,
include_tool_calls=True,
)
tool_msg = messages[3]["content"][0]
call_id = tool_msg["function_call"]["call_id"]
assert call_id is not None
assert len(call_id) > 0
def test_continuation_message_when_no_recent_queries_but_has_summary(self):
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System",
compressed_summary="Everything was compressed",
recent_queries=[],
)
# system + continuation user message = 2
assert len(messages) == 2
assert messages[1]["role"] == "user"
assert "continue" in messages[1]["content"].lower()
def test_no_continuation_when_no_summary(self):
messages = MessageBuilder.build_from_compressed_context(
system_prompt="System",
compressed_summary=None,
recent_queries=[],
)
assert len(messages) == 1
def test_queries_without_prompt_or_response_skipped(self):
queries = [
{"other_field": "value"},
{"prompt": "real", "response": "answer"},
]
messages = MessageBuilder.build_from_compressed_context(
system_prompt="S",
compressed_summary=None,
recent_queries=queries,
)
# system + 1 valid query (user + assistant) = 3
assert len(messages) == 3
@pytest.mark.unit
class TestAppendCompressionContext:
def test_pre_request_context(self):
result = MessageBuilder._append_compression_context(
"Original prompt", "Summary text", "pre_request"
)
assert "This session is being continued" in result
assert "Summary text" in result
assert result.startswith("Original prompt")
def test_mid_execution_context(self):
result = MessageBuilder._append_compression_context(
"Original prompt", "Summary text", "mid_execution"
)
assert "Context window limit reached" in result
assert "Summary text" in result
def test_removes_existing_compression_context(self):
prompt_with_existing = (
"Original prompt\n\n---\n\nThis session is being continued from old"
)
result = MessageBuilder._append_compression_context(
prompt_with_existing, "New summary", "pre_request"
)
# Should not contain old context twice
assert result.count("This session is being continued") == 1
assert "New summary" in result
def test_removes_mid_execution_context(self):
prompt_with_existing = (
"Original\n\n---\n\nContext window limit reached during execution. Old."
)
result = MessageBuilder._append_compression_context(
prompt_with_existing, "New", "mid_execution"
)
assert result.count("Context window limit reached") == 1
@pytest.mark.unit
class TestRebuildMessagesAfterCompression:
def test_basic_rebuild(self):
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "old message"},
{"role": "assistant", "content": "old reply"},
]
recent = [{"prompt": "new q", "response": "new r"}]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary="Everything was compressed.",
recent_queries=recent,
)
assert result is not None
# system + user + assistant = 3
assert len(result) == 3
assert "Context window limit reached" in result[0]["content"]
assert result[1] == {"role": "user", "content": "new q"}
assert result[2] == {"role": "assistant", "content": "new r"}
def test_returns_none_without_system_message(self):
messages = [
{"role": "user", "content": "hello"},
]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary="summary",
recent_queries=[],
)
assert result is None
def test_no_summary_keeps_system_unchanged(self):
messages = [{"role": "system", "content": "Be helpful."}]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=None,
recent_queries=[],
)
assert result is not None
assert result[0]["content"] == "Be helpful."
def test_include_tool_calls_in_rebuild(self):
messages = [{"role": "system", "content": "S"}]
recent = [
{
"prompt": "q",
"response": "r",
"tool_calls": [
{
"call_id": "c1",
"action_name": "act",
"arguments": {"a": 1},
"result": "done",
}
],
}
]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary="s",
recent_queries=recent,
include_tool_calls=True,
)
# system + user + assistant + tool_call + tool_response = 5
assert len(result) == 5
def test_continuation_added_when_no_recent_queries(self):
messages = [{"role": "system", "content": "S"}]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary="All compressed",
recent_queries=[],
)
assert len(result) == 2
assert result[1]["role"] == "user"
assert "continue" in result[1]["content"].lower()
def test_include_current_execution_preserves_extra_messages(self):
messages = [
{"role": "system", "content": "S"},
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "r1"},
{"role": "user", "content": "current execution msg"},
]
recent = [{"prompt": "q1", "response": "r1"}]
result = MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary="summary",
recent_queries=recent,
include_current_execution=True,
)
assert result is not None
# Should include the current execution message
contents = [m.get("content") for m in result]
assert "current execution msg" in contents

View File

@@ -0,0 +1,447 @@
"""Tests for application/api/answer/services/compression/orchestrator.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.types import (
CompressionMetadata,
CompressionResult,
)
@pytest.fixture
def mock_conversation_service():
svc = MagicMock()
return svc
@pytest.fixture
def mock_threshold_checker():
checker = MagicMock()
return checker
@pytest.fixture
def orchestrator(mock_conversation_service, mock_threshold_checker):
return CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
@pytest.fixture
def sample_conversation():
return {
"queries": [
{"prompt": "q0", "response": "r0"},
{"prompt": "q1", "response": "r1"},
{"prompt": "q2", "response": "r2"},
],
"compression_metadata": {},
"agent_id": "agent-1",
}
@pytest.fixture
def decoded_token():
return {"sub": "user123"}
@pytest.mark.unit
class TestCompressIfNeeded:
def test_conversation_not_found_returns_failure(
self, orchestrator, mock_conversation_service
):
mock_conversation_service.get_conversation.return_value = None
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is False
assert "not found" in result.error
def test_no_compression_needed(
self,
orchestrator,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
mock_threshold_checker.should_compress.return_value = False
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is True
assert result.compression_performed is False
assert len(result.recent_queries) == 3
def test_compression_performed_successfully(
self,
orchestrator,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
decoded_token,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
mock_threshold_checker.should_compress.return_value = True
mock_metadata = MagicMock(spec=CompressionMetadata)
mock_metadata.compression_ratio = 5.0
mock_metadata.original_token_count = 1000
mock_metadata.compressed_token_count = 200
mock_metadata.to_dict.return_value = {"query_index": 2}
with patch.object(
orchestrator, "_perform_compression"
) as mock_perform:
mock_perform.return_value = CompressionResult.success_with_compression(
"compressed summary",
[{"prompt": "q2", "response": "r2"}],
mock_metadata,
)
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
)
assert result.success is True
assert result.compression_performed is True
assert result.compressed_summary == "compressed summary"
mock_perform.assert_called_once()
def test_exception_returns_failure(
self,
orchestrator,
mock_conversation_service,
):
mock_conversation_service.get_conversation.side_effect = RuntimeError("DB down")
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is False
assert "DB down" in result.error
def test_custom_query_tokens(
self,
orchestrator,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
mock_threshold_checker.should_compress.return_value = False
orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
current_query_tokens=1000,
)
mock_threshold_checker.should_compress.assert_called_once_with(
sample_conversation, "gpt-4", 1000
)
@pytest.mark.unit
class TestPerformCompression:
@patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
)
@patch(
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
)
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
@patch("application.api.answer.services.compression.orchestrator.settings")
def test_successful_compression(
self,
mock_settings,
MockCompressionService,
MockLLMCreator,
mock_get_api_key,
mock_get_provider,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
decoded_token,
):
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
mock_get_provider.return_value = "openai"
mock_get_api_key.return_value = "sk-test"
MockLLMCreator.create_llm.return_value = MagicMock()
mock_metadata = MagicMock(spec=CompressionMetadata)
mock_metadata.compression_ratio = 5.0
mock_metadata.original_token_count = 500
mock_metadata.compressed_token_count = 100
mock_svc_instance = MagicMock()
mock_svc_instance.compress_and_save.return_value = mock_metadata
mock_svc_instance.get_compressed_context.return_value = (
"compressed text",
[{"prompt": "q2", "response": "r2"}],
)
MockCompressionService.return_value = mock_svc_instance
# After compression, reload conversation
mock_conversation_service.get_conversation.return_value = sample_conversation
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
result = orch._perform_compression(
"conv1", sample_conversation, "gpt-4", decoded_token
)
assert result.success is True
assert result.compression_performed is True
assert result.compressed_summary == "compressed text"
mock_svc_instance.compress_and_save.assert_called_once()
@patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
)
@patch(
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
)
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
@patch("application.api.answer.services.compression.orchestrator.settings")
def test_uses_compression_model_override(
self,
mock_settings,
MockLLMCreator,
mock_get_api_key,
mock_get_provider,
mock_conversation_service,
mock_threshold_checker,
decoded_token,
):
mock_settings.COMPRESSION_MODEL_OVERRIDE = "gpt-3.5-turbo"
mock_get_provider.return_value = "openai"
mock_get_api_key.return_value = "sk-test"
MockLLMCreator.create_llm.return_value = MagicMock()
conversation = {"queries": [{"prompt": "q", "response": "r"}], "agent_id": "a"}
with patch(
"application.api.answer.services.compression.orchestrator.CompressionService"
) as MockCS:
mock_svc = MagicMock()
mock_svc.compress_and_save.return_value = MagicMock(
compression_ratio=3.0,
original_token_count=300,
compressed_token_count=100,
)
mock_svc.get_compressed_context.return_value = ("s", [])
MockCS.return_value = mock_svc
mock_conversation_service.get_conversation.return_value = conversation
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
# Verify the override model was used
mock_get_provider.assert_called_with("gpt-3.5-turbo")
@patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
)
@patch(
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
)
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
@patch("application.api.answer.services.compression.orchestrator.settings")
def test_no_queries_returns_no_compression(
self,
mock_settings,
MockCompressionService,
MockLLMCreator,
mock_get_api_key,
mock_get_provider,
mock_conversation_service,
mock_threshold_checker,
decoded_token,
):
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
mock_get_provider.return_value = "openai"
mock_get_api_key.return_value = "sk-test"
MockLLMCreator.create_llm.return_value = MagicMock()
conversation = {"queries": [], "agent_id": "a"}
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
result = orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
assert result.success is True
assert result.compression_performed is False
def test_exception_returns_failure(
self,
mock_conversation_service,
mock_threshold_checker,
decoded_token,
):
conversation = {
"queries": [{"prompt": "q", "response": "r"}],
"agent_id": "a",
}
with patch(
"application.api.answer.services.compression.orchestrator.settings"
) as mock_settings, patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id",
side_effect=RuntimeError("provider error"),
):
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
result = orch._perform_compression(
"c1", conversation, "gpt-4", decoded_token
)
assert result.success is False
assert "provider error" in result.error
@pytest.mark.unit
class TestCompressMidExecution:
def test_with_provided_conversation(
self,
orchestrator,
sample_conversation,
decoded_token,
):
with patch.object(
orchestrator, "_perform_compression"
) as mock_perform:
mock_perform.return_value = CompressionResult.success_no_compression([])
orchestrator.compress_mid_execution(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
current_conversation=sample_conversation,
)
mock_perform.assert_called_once_with(
"conv1", sample_conversation, "gpt-4", decoded_token
)
def test_loads_conversation_when_not_provided(
self,
orchestrator,
mock_conversation_service,
sample_conversation,
decoded_token,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
with patch.object(
orchestrator, "_perform_compression"
) as mock_perform:
mock_perform.return_value = CompressionResult.success_no_compression([])
orchestrator.compress_mid_execution(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
)
mock_conversation_service.get_conversation.assert_called_once_with(
"conv1", "user1"
)
mock_perform.assert_called_once()
def test_conversation_not_found_returns_failure(
self,
orchestrator,
mock_conversation_service,
decoded_token,
):
mock_conversation_service.get_conversation.return_value = None
result = orchestrator.compress_mid_execution(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
)
assert result.success is False
assert "not found" in result.error
def test_exception_returns_failure(
self,
orchestrator,
mock_conversation_service,
decoded_token,
):
mock_conversation_service.get_conversation.side_effect = RuntimeError("fail")
result = orchestrator.compress_mid_execution(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
)
assert result.success is False
assert "fail" in result.error
@pytest.mark.unit
class TestOrchestratorInit:
def test_default_threshold_checker(self, mock_conversation_service):
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service
)
assert orch.threshold_checker is not None
assert orch.conversation_service is mock_conversation_service
def test_custom_threshold_checker(
self, mock_conversation_service, mock_threshold_checker
):
orch = CompressionOrchestrator(
conversation_service=mock_conversation_service,
threshold_checker=mock_threshold_checker,
)
assert orch.threshold_checker is mock_threshold_checker

View File

@@ -0,0 +1,423 @@
"""Tests for application/api/answer/services/compression/service.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import CompressionMetadata
@pytest.fixture
def mock_llm():
llm = MagicMock()
llm.gen.return_value = "<summary>Compressed summary content</summary>"
return llm
@pytest.fixture
def mock_conversation_service():
svc = MagicMock()
svc.update_compression_metadata = MagicMock()
return svc
@pytest.fixture
def sample_conversation():
return {
"queries": [
{"prompt": "What is Python?", "response": "A programming language."},
{"prompt": "Tell me more.", "response": "It's versatile and popular."},
{
"prompt": "What about tools?",
"response": "Python has many tools.",
"tool_calls": [
{
"tool_name": "search",
"action_name": "web_search",
"arguments": {"q": "python tools"},
"result": "Found 10 results",
"status": "success",
}
],
},
],
"compression_metadata": {},
}
@pytest.mark.unit
class TestCompressionServiceInit:
@patch("application.api.answer.services.compression.service.settings")
def test_default_prompt_builder(self, mock_settings, mock_llm):
mock_settings.COMPRESSION_PROMPT_VERSION = "v1.0"
with patch(
"application.api.answer.services.compression.service.CompressionPromptBuilder"
):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
assert svc.llm is mock_llm
assert svc.model_id == "gpt-4"
def test_custom_prompt_builder(self, mock_llm):
custom_builder = MagicMock()
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=custom_builder
)
assert svc.prompt_builder is custom_builder
@pytest.mark.unit
class TestCompressConversation:
def test_successful_compression(self, mock_llm, sample_conversation):
mock_builder = MagicMock()
mock_builder.build_prompt.return_value = [
{"role": "system", "content": "Compress"},
{"role": "user", "content": "Conversation..."},
]
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with patch(
"application.api.answer.services.compression.service.TokenCounter"
) as MockTC:
MockTC.count_query_tokens.return_value = 1000
MockTC.count_message_tokens.return_value = 100
result = svc.compress_conversation(sample_conversation, 2)
assert isinstance(result, CompressionMetadata)
assert result.query_index == 2
assert result.compressed_summary == "Compressed summary content"
assert result.original_token_count == 1000
assert result.compressed_token_count == 100
assert result.compression_ratio == 10.0
assert result.model_used == "gpt-4"
assert result.compression_prompt_version == "v1.0"
def test_invalid_index_negative(self, mock_llm, sample_conversation):
mock_builder = MagicMock()
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
svc.compress_conversation(sample_conversation, -1)
def test_invalid_index_too_large(self, mock_llm, sample_conversation):
mock_builder = MagicMock()
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
svc.compress_conversation(sample_conversation, 10)
def test_with_existing_compressions(self, mock_llm):
conversation = {
"queries": [
{"prompt": "q1", "response": "r1"},
{"prompt": "q2", "response": "r2"},
],
"compression_metadata": {
"compression_points": [
{
"query_index": 0,
"compressed_summary": "Previous summary",
}
]
},
}
mock_builder = MagicMock()
mock_builder.build_prompt.return_value = [
{"role": "system", "content": "Compress"},
{"role": "user", "content": "..."},
]
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with patch(
"application.api.answer.services.compression.service.TokenCounter"
) as MockTC:
MockTC.count_query_tokens.return_value = 500
MockTC.count_message_tokens.return_value = 50
result = svc.compress_conversation(conversation, 1)
assert isinstance(result, CompressionMetadata)
# Verify existing compressions were passed to prompt builder
call_args = mock_builder.build_prompt.call_args
assert call_args[0][1] == [
{"query_index": 0, "compressed_summary": "Previous summary"}
]
def test_zero_compressed_tokens_ratio(self, mock_llm, sample_conversation):
mock_builder = MagicMock()
mock_builder.build_prompt.return_value = [
{"role": "system", "content": "C"},
{"role": "user", "content": "..."},
]
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with patch(
"application.api.answer.services.compression.service.TokenCounter"
) as MockTC:
MockTC.count_query_tokens.return_value = 1000
MockTC.count_message_tokens.return_value = 0
result = svc.compress_conversation(sample_conversation, 2)
assert result.compression_ratio == 0
def test_llm_error_propagates(self, sample_conversation):
llm = MagicMock()
llm.gen.side_effect = RuntimeError("LLM error")
mock_builder = MagicMock()
mock_builder.build_prompt.return_value = [
{"role": "system", "content": "C"},
{"role": "user", "content": "..."},
]
mock_builder.version = "v1.0"
svc = CompressionService(
llm=llm, model_id="gpt-4", prompt_builder=mock_builder
)
with patch(
"application.api.answer.services.compression.service.TokenCounter"
) as MockTC:
MockTC.count_query_tokens.return_value = 100
with pytest.raises(RuntimeError, match="LLM error"):
svc.compress_conversation(sample_conversation, 2)
@pytest.mark.unit
class TestCompressAndSave:
def test_saves_metadata_to_db(
self, mock_llm, mock_conversation_service, sample_conversation
):
mock_builder = MagicMock()
mock_builder.build_prompt.return_value = [
{"role": "system", "content": "C"},
{"role": "user", "content": "..."},
]
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm,
model_id="gpt-4",
conversation_service=mock_conversation_service,
prompt_builder=mock_builder,
)
with patch(
"application.api.answer.services.compression.service.TokenCounter"
) as MockTC:
MockTC.count_query_tokens.return_value = 500
MockTC.count_message_tokens.return_value = 50
result = svc.compress_and_save("conv_123", sample_conversation, 2)
assert isinstance(result, CompressionMetadata)
mock_conversation_service.update_compression_metadata.assert_called_once_with(
"conv_123", result.to_dict()
)
def test_raises_without_conversation_service(self, mock_llm, sample_conversation):
mock_builder = MagicMock()
mock_builder.version = "v1.0"
svc = CompressionService(
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
)
with pytest.raises(ValueError, match="conversation_service required"):
svc.compress_and_save("conv_123", sample_conversation, 2)
@pytest.mark.unit
class TestGetCompressedContext:
def test_no_compression_returns_full_history(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
conversation = {
"queries": [{"prompt": "q1", "response": "r1"}],
"compression_metadata": {},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary is None
assert queries == [{"prompt": "q1", "response": "r1"}]
def test_no_compression_points_returns_full_history(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
conversation = {
"queries": [{"prompt": "q1", "response": "r1"}],
"compression_metadata": {"is_compressed": True, "compression_points": []},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary is None
assert len(queries) == 1
def test_with_compression_returns_summary_and_recent(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
conversation = {
"queries": [
{"prompt": "q0", "response": "r0"},
{"prompt": "q1", "response": "r1"},
{"prompt": "q2", "response": "r2"},
],
"compression_metadata": {
"is_compressed": True,
"compression_points": [
{
"query_index": 1,
"compressed_summary": "Summary of q0 and q1",
"compressed_token_count": 50,
"original_token_count": 500,
}
],
},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary == "Summary of q0 and q1"
assert len(queries) == 1
assert queries[0]["prompt"] == "q2"
def test_none_queries_returns_empty(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
conversation = {
"queries": None,
"compression_metadata": {},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary is None
assert queries == []
def test_exception_falls_back_to_full_history(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
conversation = {
"queries": [{"prompt": "q", "response": "r"}],
"compression_metadata": {
"is_compressed": True,
"compression_points": "invalid", # This will cause an error
},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary is None
assert queries == [{"prompt": "q", "response": "r"}]
def test_exception_with_none_queries_returns_empty(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
# Force exception by making compression_points non-iterable
conversation = {
"queries": None,
"compression_metadata": {
"is_compressed": True,
"compression_points": "bad",
},
}
summary, queries = svc.get_compressed_context(conversation)
assert summary is None
assert queries == []
@pytest.mark.unit
class TestExtractSummary:
def test_extracts_from_summary_tags(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
response = "<analysis>Some analysis</analysis><summary>The actual summary</summary>"
result = svc._extract_summary(response)
assert result == "The actual summary"
def test_removes_analysis_tags_when_no_summary(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
response = "<analysis>analysis text</analysis>Raw summary text here"
result = svc._extract_summary(response)
assert result == "Raw summary text here"
def test_returns_full_response_when_no_tags(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
response = "Just a plain text response"
result = svc._extract_summary(response)
assert result == "Just a plain text response"
def test_multiline_summary(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
response = "<summary>Line 1\nLine 2\nLine 3</summary>"
result = svc._extract_summary(response)
assert "Line 1" in result
assert "Line 3" in result
def test_strips_whitespace(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
response = "<summary> Trimmed </summary>"
result = svc._extract_summary(response)
assert result == "Trimmed"
@pytest.mark.unit
class TestLogToolCallStats:
def test_no_tool_calls(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
queries = [{"prompt": "q", "response": "r"}]
# Should not raise
svc._log_tool_call_stats(queries)
def test_with_tool_calls(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
queries = [
{
"prompt": "q",
"response": "r",
"tool_calls": [
{
"tool_name": "search",
"action_name": "web",
"result": "result text",
},
{
"tool_name": "search",
"action_name": "web",
"result": "more text",
},
],
}
]
# Should not raise - just logs
svc._log_tool_call_stats(queries)
def test_empty_queries(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
svc._log_tool_call_stats([])
def test_tool_call_with_none_result(self, mock_llm):
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
queries = [
{
"prompt": "q",
"response": "r",
"tool_calls": [
{
"tool_name": "t",
"action_name": "a",
"result": None,
}
],
}
]
svc._log_tool_call_stats(queries)

View File

@@ -0,0 +1,46 @@
from unittest.mock import patch
import pytest
@pytest.mark.unit
class TestCompressionThresholdChecker:
def _make_checker(self, pct=0.7):
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
return CompressionThresholdChecker(threshold_percentage=pct)
@patch(
"application.api.answer.services.compression.threshold_checker.get_token_limit",
return_value=8000,
)
@patch(
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
return_value=6000,
)
def test_check_message_tokens_above_threshold(self, mock_count, mock_limit):
checker = self._make_checker(0.7)
assert checker.check_message_tokens([{"role": "user"}], "gpt-4") is True
@patch(
"application.api.answer.services.compression.threshold_checker.get_token_limit",
return_value=8000,
)
@patch(
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
return_value=1000,
)
def test_check_message_tokens_below_threshold(self, mock_count, mock_limit):
checker = self._make_checker(0.7)
assert checker.check_message_tokens([{"role": "user"}], "gpt-4") is False
@patch(
"application.api.answer.services.compression.threshold_checker.TokenCounter.count_message_tokens",
side_effect=Exception("Token error"),
)
def test_check_message_tokens_exception_returns_false(self, mock_count):
checker = self._make_checker(0.7)
assert checker.check_message_tokens([], "gpt-4") is False

View File

@@ -0,0 +1,131 @@
"""Tests for application/api/answer/services/compression/types.py"""
from datetime import datetime, timezone
import pytest
from application.api.answer.services.compression.types import (
CompressionMetadata,
CompressionResult,
)
@pytest.mark.unit
class TestCompressionMetadata:
def _make_metadata(self, **overrides):
defaults = dict(
timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc),
query_index=5,
compressed_summary="Summary of conversation",
original_token_count=5000,
compressed_token_count=500,
compression_ratio=10.0,
model_used="gpt-4",
compression_prompt_version="v1.0",
)
defaults.update(overrides)
return CompressionMetadata(**defaults)
def test_to_dict_contains_all_fields(self):
meta = self._make_metadata()
d = meta.to_dict()
assert d["timestamp"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
assert d["query_index"] == 5
assert d["compressed_summary"] == "Summary of conversation"
assert d["original_token_count"] == 5000
assert d["compressed_token_count"] == 500
assert d["compression_ratio"] == 10.0
assert d["model_used"] == "gpt-4"
assert d["compression_prompt_version"] == "v1.0"
def test_to_dict_returns_dict_type(self):
meta = self._make_metadata()
assert isinstance(meta.to_dict(), dict)
def test_to_dict_field_count(self):
meta = self._make_metadata()
d = meta.to_dict()
assert len(d) == 8
def test_attributes_accessible(self):
meta = self._make_metadata(query_index=10, compression_ratio=5.5)
assert meta.query_index == 10
assert meta.compression_ratio == 5.5
def test_zero_compressed_tokens(self):
meta = self._make_metadata(compressed_token_count=0, compression_ratio=0)
d = meta.to_dict()
assert d["compressed_token_count"] == 0
assert d["compression_ratio"] == 0
@pytest.mark.unit
class TestCompressionResult:
def test_success_with_compression(self):
meta = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=3,
compressed_summary="summary",
original_token_count=1000,
compressed_token_count=100,
compression_ratio=10.0,
model_used="gpt-4",
compression_prompt_version="v1.0",
)
queries = [{"prompt": "q1", "response": "r1"}]
result = CompressionResult.success_with_compression("summary", queries, meta)
assert result.success is True
assert result.compressed_summary == "summary"
assert result.recent_queries == queries
assert result.metadata is meta
assert result.compression_performed is True
assert result.error is None
def test_success_no_compression(self):
queries = [{"prompt": "q1", "response": "r1"}]
result = CompressionResult.success_no_compression(queries)
assert result.success is True
assert result.compressed_summary is None
assert result.recent_queries == queries
assert result.metadata is None
assert result.compression_performed is False
assert result.error is None
def test_failure(self):
result = CompressionResult.failure("something went wrong")
assert result.success is False
assert result.error == "something went wrong"
assert result.compression_performed is False
assert result.compressed_summary is None
assert result.recent_queries == []
assert result.metadata is None
def test_as_history_extracts_prompt_response(self):
queries = [
{"prompt": "Hello", "response": "Hi", "extra": "ignored"},
{"prompt": "How?", "response": "Fine"},
]
result = CompressionResult.success_no_compression(queries)
history = result.as_history()
assert len(history) == 2
assert history[0] == {"prompt": "Hello", "response": "Hi"}
assert history[1] == {"prompt": "How?", "response": "Fine"}
def test_as_history_empty_queries(self):
result = CompressionResult.success_no_compression([])
assert result.as_history() == []
def test_default_recent_queries_is_empty_list(self):
result = CompressionResult(success=True)
assert result.recent_queries == []
assert result.as_history() == []
def test_success_no_compression_with_empty_list(self):
result = CompressionResult.success_no_compression([])
assert result.success is True
assert result.recent_queries == []

View File

@@ -0,0 +1,578 @@
"""Unit tests for application/api/answer/routes/base.py — BaseAnswerResource.
Additional coverage beyond tests/api/answer/routes/test_base.py:
- _prepare_tool_calls_for_logging: truncation, non-dict items
- complete_stream: tool_calls, thoughts, structured output, metadata,
isNoneDoc, GeneratorExit handling, compression metadata
- process_response_stream: structured answer, incomplete stream
- error_stream_generate: format
- check_usage: string boolean parsing ("True" strings)
"""
import json
from unittest.mock import MagicMock
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestPrepareToolCallsForLogging:
def test_empty_list(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert resource._prepare_tool_calls_for_logging([]) == []
def test_none_returns_empty(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert resource._prepare_tool_calls_for_logging(None) == []
def test_truncates_long_result(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
tool_calls = [{"result": "x" * 20000}]
prepared = resource._prepare_tool_calls_for_logging(tool_calls, max_chars=100)
assert len(prepared[0]["result"]) == 100
def test_truncates_result_full(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
tool_calls = [{"result_full": "y" * 20000}]
prepared = resource._prepare_tool_calls_for_logging(tool_calls, max_chars=50)
assert len(prepared[0]["result_full"]) == 50
def test_non_dict_items_wrapped(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
tool_calls = ["string_item", 42]
prepared = resource._prepare_tool_calls_for_logging(tool_calls)
assert prepared[0] == {"result": "string_item"}
assert prepared[1] == {"result": "42"}
def test_preserves_short_results(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
tool_calls = [{"tool_name": "search", "result": "short text"}]
prepared = resource._prepare_tool_calls_for_logging(tool_calls)
assert prepared[0]["result"] == "short text"
assert prepared[0]["tool_name"] == "search"
@pytest.mark.unit
class TestCompleteStreamToolCalls:
def test_streams_tool_calls(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Using tool..."},
{"tool_calls": [{"name": "search", "result": "found"}]},
]
)
stream = list(
resource.complete_stream(
question="Search for X",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
tool_chunks = [s for s in stream if '"type": "tool_calls"' in s]
assert len(tool_chunks) == 1
def test_streams_thought_events(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"thought": "Let me think..."},
{"answer": "Here is the answer"},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
thought_chunks = [s for s in stream if '"type": "thought"' in s]
assert len(thought_chunks) == 1
assert "Let me think" in thought_chunks[0]
@pytest.mark.unit
class TestCompleteStreamStructuredOutput:
def test_streams_structured_answer(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{
"answer": '{"key": "value"}',
"structured": True,
"schema": {"type": "object"},
},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
structured_chunks = [
s for s in stream if '"type": "structured_answer"' in s
]
assert len(structured_chunks) == 1
@pytest.mark.unit
class TestCompleteStreamMetadata:
def test_metadata_collected(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"metadata": {"search_query": "test"}},
{"answer": "result"},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
# Should not crash, metadata handled silently
answer_chunks = [s for s in stream if '"type": "answer"' in s]
assert len(answer_chunks) == 1
@pytest.mark.unit
class TestCompleteStreamIsNoneDoc:
def test_isNoneDoc_sets_source_to_none(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "answer"},
{"sources": [{"text": "doc", "source": "real_source"}]},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
isNoneDoc=True,
should_save_conversation=False,
)
)
# Verify stream completes without error
end_chunks = [s for s in stream if '"type": "end"' in s]
assert len(end_chunks) == 1
@pytest.mark.unit
class TestCompleteStreamErrorType:
def test_error_type_event_sanitized(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"type": "error", "error": "API key invalid: sk-xxx"},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
error_chunks = [s for s in stream if '"type": "error"' in s]
assert len(error_chunks) == 1
def test_non_error_type_event_passed_through(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"type": "custom_event", "data": "value"},
]
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
custom_chunks = [s for s in stream if '"type": "custom_event"' in s]
assert len(custom_chunks) == 1
@pytest.mark.unit
class TestProcessResponseStreamExtended:
def test_handles_structured_answer(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "structured_answer", "answer": "{}", "structured": True, "schema": None})}\n\n',
f'data: {json.dumps({"type": "id", "id": str(ObjectId())})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[1] == "{}"
# Structured output adds extra tuple element
assert len(result) == 7
assert result[6]["structured"] is True
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "answer", "answer": "result"})}\n\n',
f'data: {json.dumps({"type": "tool_calls", "tool_calls": [{"name": "t1"}]})}\n\n',
f'data: {json.dumps({"type": "id", "id": "conv1"})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[3] == [{"name": "t1"}]
def test_incomplete_stream(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "Stream ended unexpectedly"
def test_handles_thought_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "thought", "thought": "thinking..."})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "thinking..."
@pytest.mark.unit
class TestCheckUsageStringBooleans:
def test_string_true_parsed_correctly(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agents_collection.insert_one(
{
"_id": ObjectId(),
"key": "str_bool_key",
"limited_token_mode": "True",
"token_limit": 1000000,
"limited_request_mode": "True",
"request_limit": 1000000,
}
)
resource = BaseAnswerResource()
result = resource.check_usage({"user_api_key": "str_bool_key"})
# Should not exceed limits, so returns None
assert result is None
@pytest.mark.unit
class TestCompleteStreamCompressionMetadata:
"""Cover lines 307-319 (compression metadata persistence in complete_stream)."""
def test_compression_metadata_persisted(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "compressed answer"},
]
)
mock_agent.compression_metadata = {"ratio": 2.5}
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv123"
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
)
# Verify compression metadata was persisted
resource.conversation_service.update_compression_metadata.assert_called_once_with(
"conv123", {"ratio": 2.5}
)
resource.conversation_service.append_compression_message.assert_called_once()
assert mock_agent.compression_saved is True
end_chunks = [s for s in stream if '"type": "end"' in s]
assert len(end_chunks) == 1
def test_compression_metadata_error_handled(self, mock_mongo_db, flask_app):
"""Cover lines 318-322: compression metadata persistence error."""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter([{"answer": "answer"}])
mock_agent.compression_metadata = {"ratio": 2.5}
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv123"
resource.conversation_service.update_compression_metadata.side_effect = (
Exception("db error")
)
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
)
# Stream should still complete despite compression error
end_chunks = [s for s in stream if '"type": "end"' in s]
assert len(end_chunks) == 1
@pytest.mark.unit
class TestCompleteStreamLogTruncation:
"""Cover line 354: log data truncation for long values."""
def test_long_response_truncated_in_log(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
long_answer = "x" * 20000
mock_agent.gen.return_value = iter([{"answer": long_answer}])
mock_agent.tool_calls = []
stream = list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=False,
)
)
end_chunks = [s for s in stream if '"type": "end"' in s]
assert len(end_chunks) == 1
@pytest.mark.unit
class TestCompleteStreamGeneratorExit:
"""Cover lines 360-416 (GeneratorExit handling in complete_stream)."""
def test_generator_exit_saves_partial_response(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
def gen_with_answers():
yield {"answer": "partial"}
yield {"answer": " answer"}
# Simulating a long stream that gets interrupted
yield {"answer": " more"}
mock_agent.gen.return_value = gen_with_answers()
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv1"
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
# Read first chunk and then close (simulating client disconnect)
chunk = next(gen)
assert "partial" in chunk
gen.close() # This triggers GeneratorExit
def test_generator_exit_with_compression_metadata(self, mock_mongo_db, flask_app):
"""Cover lines 393-411: GeneratorExit with compression metadata."""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
def gen_answers():
yield {"answer": "partial answer"}
mock_agent.gen.return_value = gen_answers()
mock_agent.compression_metadata = {"ratio": 3.0}
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv1"
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
isNoneDoc=True,
)
next(gen)
gen.close()
def test_generator_exit_save_error_handled(self, mock_mongo_db, flask_app):
"""Cover lines 412-415: exception during partial save."""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
def gen_answers():
yield {"answer": "partial"}
mock_agent.gen.return_value = gen_answers()
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.side_effect = Exception(
"save error"
)
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
next(gen)
gen.close() # Should not crash even with save error

View File

@@ -0,0 +1,478 @@
"""Unit tests for application/api/answer/services/conversation_service.py.
Additional coverage beyond tests/api/answer/services/test_conversation_service.py:
- save_conversation: index-based update, metadata persistence, agent key tracking
- update_compression_metadata
- append_compression_message
- get_compression_metadata
- Edge cases: None token, empty summary, shared_with access
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestConversationServiceGetExtended:
def test_returns_conversation_for_shared_user(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{
"_id": conv_id,
"user": "owner_123",
"shared_with": ["shared_user"],
"name": "Shared Conv",
"queries": [],
}
)
result = service.get_conversation(str(conv_id), "shared_user")
assert result is not None
assert result["name"] == "Shared Conv"
def test_handles_exception_gracefully(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
# Pass an invalid ObjectId
result = service.get_conversation("not-an-objectid", "user_123")
assert result is None
@pytest.mark.unit
class TestSaveConversationExtended:
def test_raises_for_none_token(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
with pytest.raises(ValueError, match="Invalid or missing authentication"):
service.save_conversation(
conversation_id=None,
question="Q",
response="A",
thought="",
sources=[],
tool_calls=[],
llm=Mock(),
model_id="m",
decoded_token=None,
)
def test_update_existing_at_index(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{
"_id": conv_id,
"user": "user_123",
"name": "Conv",
"queries": [
{
"prompt": "Q1",
"response": "A1",
"thought": "",
"sources": [],
"tool_calls": [],
},
{
"prompt": "Q2",
"response": "A2",
"thought": "",
"sources": [],
"tool_calls": [],
},
],
}
)
result = service.save_conversation(
conversation_id=str(conv_id),
question="Q1_updated",
response="A1_updated",
thought="thinking",
sources=[],
tool_calls=[],
llm=Mock(),
model_id="gpt-4",
decoded_token={"sub": "user_123"},
index=0,
)
assert result == str(conv_id)
saved = collection.find_one({"_id": conv_id})
assert saved["queries"][0]["prompt"] == "Q1_updated"
assert saved["queries"][0]["response"] == "A1_updated"
def test_update_at_index_unauthorized(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{
"_id": conv_id,
"user": "owner",
"queries": [{"prompt": "Q", "response": "A"}],
}
)
with pytest.raises(ValueError, match="not found or unauthorized"):
service.save_conversation(
conversation_id=str(conv_id),
question="Hack",
response="Attempt",
thought="",
sources=[],
tool_calls=[],
llm=Mock(),
model_id="m",
decoded_token={"sub": "hacker"},
index=0,
)
def test_saves_metadata(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Title"
conv_id = service.save_conversation(
conversation_id=None,
question="Q",
response="A",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
model_id="m",
decoded_token={"sub": "user_123"},
metadata={"search_query": "rewritten query"},
)
saved = collection.find_one({"_id": ObjectId(conv_id)})
assert saved["queries"][0]["metadata"] == {"search_query": "rewritten query"}
def test_no_metadata_when_none(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Title"
conv_id = service.save_conversation(
conversation_id=None,
question="Q",
response="A",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
model_id="m",
decoded_token={"sub": "user_123"},
metadata=None,
)
saved = collection.find_one({"_id": ObjectId(conv_id)})
assert "metadata" not in saved["queries"][0]
def test_saves_with_api_key_and_agent(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{"_id": agent_id, "key": "agent_key_123", "user": "user_123"}
)
mock_llm = Mock()
mock_llm.gen.return_value = "Title"
conv_id = service.save_conversation(
conversation_id=None,
question="Q",
response="A",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
model_id="m",
decoded_token={"sub": "user_123"},
api_key="agent_key_123",
agent_id=str(agent_id),
)
saved = collection.find_one({"_id": ObjectId(conv_id)})
assert saved["api_key"] == "agent_key_123"
assert saved["agent_id"] == str(agent_id)
def test_empty_completion_uses_question_prefix(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = " " # whitespace only
conv_id = service.save_conversation(
conversation_id=None,
question="What is the meaning of life in programming?",
response="42",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
model_id="m",
decoded_token={"sub": "user_123"},
)
saved = collection.find_one({"_id": ObjectId(conv_id)})
assert saved["name"] == "What is the meaning of life in programming?"[:50]
@pytest.mark.unit
class TestUpdateCompressionMetadata:
def test_updates_compression_fields(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{"_id": conv_id, "user": "u", "queries": []}
)
meta = {
"timestamp": datetime.now(timezone.utc),
"compressed_summary": "Summary of conversation",
"model_used": "gpt-4",
}
service.update_compression_metadata(str(conv_id), meta)
saved = collection.find_one({"_id": conv_id})
assert saved["compression_metadata"]["is_compressed"] is True
assert len(saved["compression_metadata"]["compression_points"]) == 1
@pytest.mark.unit
class TestAppendCompressionMessage:
def test_appends_summary_query(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{"_id": conv_id, "user": "u", "queries": []}
)
meta = {
"compressed_summary": "This is the summary",
"timestamp": datetime.now(timezone.utc),
"model_used": "gpt-4",
}
service.append_compression_message(str(conv_id), meta)
saved = collection.find_one({"_id": conv_id})
assert len(saved["queries"]) == 1
assert saved["queries"][0]["prompt"] == "[Context Compression Summary]"
assert saved["queries"][0]["response"] == "This is the summary"
def test_empty_summary_does_nothing(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{"_id": conv_id, "user": "u", "queries": []}
)
service.append_compression_message(str(conv_id), {"compressed_summary": ""})
saved = collection.find_one({"_id": conv_id})
assert len(saved["queries"]) == 0
@pytest.mark.unit
class TestGetCompressionMetadata:
def test_returns_metadata(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{
"_id": conv_id,
"user": "u",
"compression_metadata": {"is_compressed": True},
}
)
result = service.get_compression_metadata(str(conv_id))
assert result is not None
assert result["is_compressed"] is True
def test_returns_none_for_no_metadata(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one({"_id": conv_id, "user": "u"})
result = service.get_compression_metadata(str(conv_id))
assert result is None
def test_returns_none_for_missing_conversation(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_compression_metadata(str(ObjectId()))
assert result is None
def test_handles_invalid_id(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_compression_metadata("invalid-id")
assert result is None
# =====================================================================
# Coverage gap tests (lines 233-237, 258, 261)
# =====================================================================
@pytest.mark.unit
class TestConversationServiceGaps:
def test_update_compression_metadata_exception_raises(self, mock_mongo_db):
"""Cover lines 233-237: exception during update raises."""
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
service.conversations_collection = MagicMock()
service.conversations_collection.update_one.side_effect = Exception("db error")
with pytest.raises(Exception, match="db error"):
service.update_compression_metadata(
str(ObjectId()),
{
"compressed_summary": "summary",
"query_index": 5,
"compressed_token_count": 100,
"original_token_count": 1000,
},
)
def test_append_compression_message_with_summary(self, mock_mongo_db):
"""Cover lines 258, 261: appends compression message to conversation."""
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
service.conversations_collection = MagicMock()
conv_id = str(ObjectId())
metadata = {
"compressed_summary": "This is a summary of the conversation.",
"timestamp": "2024-01-01T00:00:00",
"model_used": "gpt-4",
}
service.append_compression_message(conv_id, metadata)
service.conversations_collection.update_one.assert_called_once()
def test_append_compression_message_empty_summary_skips(self, mock_mongo_db):
"""Cover: empty summary does not insert."""
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
service.conversations_collection = MagicMock()
service.append_compression_message(str(ObjectId()), {"compressed_summary": ""})
service.conversations_collection.update_one.assert_not_called()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,821 @@
"""Tests for application/api/connector/routes.py"""
import base64
import json
from unittest.mock import MagicMock, patch
import mongomock
import pytest
@pytest.fixture
def app():
with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
from application.app import app as flask_app
flask_app.config["TESTING"] = True
yield flask_app
@pytest.fixture
def client(app):
return app.test_client()
@pytest.fixture
def mock_sessions(monkeypatch):
mock_client = mongomock.MongoClient()
mock_db = mock_client["docsgpt"]
sessions = mock_db["connector_sessions"]
sources = mock_db["sources"]
monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions)
monkeypatch.setattr("application.api.connector.routes.sources_collection", sources)
return {"sessions": sessions, "sources": sources}
class TestConnectorAuth:
@pytest.mark.unit
def test_missing_provider(self, client):
resp = client.get("/api/connectors/auth")
assert resp.status_code == 400
@pytest.mark.unit
def test_unsupported_provider(self, client):
resp = client.get("/api/connectors/auth?provider=dropbox")
assert resp.status_code == 400
@pytest.mark.unit
def test_unauthorized(self, client, app):
with patch("application.app.handle_auth", return_value=None):
resp = client.get("/api/connectors/auth?provider=google_drive")
data = json.loads(resp.data)
# decoded_token is None -> 401
assert resp.status_code == 401 or data.get("error") == "Unauthorized"
@pytest.mark.unit
def test_success(self, client, mock_sessions):
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
MockCC.create_auth.return_value = mock_auth
resp = client.get("/api/connectors/auth?provider=google_drive")
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert "authorization_url" in data
@pytest.mark.unit
def test_exception_returns_500(self, client, mock_sessions):
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.is_supported.return_value = True
MockCC.create_auth.side_effect = Exception("oauth fail")
resp = client.get("/api/connectors/auth?provider=google_drive")
assert resp.status_code == 500
class TestConnectorFiles:
@pytest.mark.unit
def test_missing_params(self, client):
resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
assert resp.status_code == 400
@pytest.mark.unit
def test_invalid_session(self, client, mock_sessions):
resp = client.post("/api/connectors/files", json={
"provider": "google_drive",
"session_token": "bad_token",
})
assert resp.status_code == 401
@pytest.mark.unit
def test_success(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "valid_tok",
"user": "test_user",
"provider": "google_drive",
})
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {
"file_name": "test.pdf",
"mime_type": "application/pdf",
"size": 1024,
"modified_time": "2025-01-01T12:00:00.000Z",
"is_folder": False,
}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = None
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post("/api/connectors/files", json={
"provider": "google_drive",
"session_token": "valid_tok",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert len(data["files"]) == 1
@pytest.mark.unit
def test_no_modified_time(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "tok2",
"user": "test_user",
"provider": "google_drive",
})
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = None
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post("/api/connectors/files", json={
"provider": "google_drive", "session_token": "tok2",
})
assert resp.status_code == 200
class TestConnectorValidateSession:
@pytest.mark.unit
def test_missing_params(self, client):
resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
assert resp.status_code == 400
@pytest.mark.unit
def test_invalid_session(self, client, mock_sessions):
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "bad",
})
assert resp.status_code == 401
@pytest.mark.unit
def test_valid_non_expired(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "valid",
"user": "test_user",
"provider": "google_drive",
"token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
"user_email": "user@example.com",
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = False
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "valid",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert data["expired"] is False
@pytest.mark.unit
def test_expired_with_refresh(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "expired_tok",
"user": "test_user",
"provider": "google_drive",
"token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = True
mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "expired_tok",
})
assert resp.status_code == 200
@pytest.mark.unit
def test_expired_no_refresh(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "exp_no_ref",
"user": "test_user",
"token_info": {"access_token": "at", "expiry": 100},
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = True
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "exp_no_ref",
})
assert resp.status_code == 401
class TestConnectorDisconnect:
@pytest.mark.unit
def test_missing_provider(self, client):
resp = client.post("/api/connectors/disconnect", json={})
assert resp.status_code == 400
@pytest.mark.unit
def test_success_with_session(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"})
resp = client.post("/api/connectors/disconnect", json={
"provider": "google_drive", "session_token": "del_me",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
@pytest.mark.unit
def test_success_without_session(self, client, mock_sessions):
resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
assert resp.status_code == 200
class TestConnectorSync:
@pytest.mark.unit
def test_missing_params(self, client, mock_sessions):
resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
assert resp.status_code == 400
@pytest.mark.unit
def test_source_not_found(self, client, mock_sessions):
from bson.objectid import ObjectId
resp = client.post("/api/connectors/sync", json={
"source_id": str(ObjectId()), "session_token": "tok",
})
assert resp.status_code == 404
@pytest.mark.unit
def test_unauthorized_source(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 403
@pytest.mark.unit
def test_missing_provider_in_remote_data(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({
"user": "test_user", "name": "src", "remote_data": json.dumps({}),
}).inserted_id
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 400
@pytest.mark.unit
def test_success(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({
"user": "test_user",
"name": "src",
"remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
}).inserted_id
mock_task = MagicMock()
mock_task.id = "task_123"
with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
mock_ingest.delay.return_value = mock_task
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["task_id"] == "task_123"
class TestConnectorCallbackStatus:
@pytest.mark.unit
def test_success_status(self, client):
resp = client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com")
assert resp.status_code == 200
assert b"success" in resp.data
@pytest.mark.unit
def test_error_status(self, client):
resp = client.get("/api/connectors/callback-status?status=error&message=Failed")
assert resp.status_code == 200
assert b"error" in resp.data
@pytest.mark.unit
def test_cancelled_status(self, client):
resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
assert resp.status_code == 200
assert b"cancelled" in resp.data
@pytest.mark.unit
def test_unknown_status_defaults_to_error(self, client):
resp = client.get("/api/connectors/callback-status?status=badvalue")
assert resp.status_code == 200
assert b"error" in resp.data
@pytest.mark.unit
def test_html_escaping(self, client):
resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
assert resp.status_code == 200
# The raw <script> tag should be escaped (not executable)
assert b"<script>alert(1)</script>" not in resp.data
class TestBuildCallbackRedirect:
@pytest.mark.unit
def test_builds_url(self):
from application.api.connector.routes import build_callback_redirect
url = build_callback_redirect({"status": "success", "message": "OK"})
assert url.startswith("/api/connectors/callback-status?")
assert "status=success" in url
@pytest.mark.unit
class TestConnectorsCallback:
"""Tests for the ConnectorsCallback OAuth callback route."""
def _encode_state(self, state_dict):
return base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
def _patch_connector_creator(self):
"""Patch ConnectorCreator at both module-level and local-import locations."""
return patch(
"application.parser.connectors.connector_creator.ConnectorCreator",
)
def test_callback_invalid_provider_redirects_error(self, client, mock_sessions):
state = self._encode_state({"provider": "dropbox", "object_id": "abc123"})
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = False
resp = client.get(
f"/api/connectors/callback?code=auth_code&state={state}"
)
assert resp.status_code == 302
assert "error" in resp.headers.get("Location", "")
def test_callback_access_denied_redirects_cancelled(self, client, mock_sessions):
state = self._encode_state(
{"provider": "google_drive", "object_id": "abc123"}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
resp = client.get(
f"/api/connectors/callback?error=access_denied&state={state}"
)
assert resp.status_code == 302
assert "cancelled" in resp.headers.get("Location", "")
def test_callback_other_error_redirects_error(self, client, mock_sessions):
state = self._encode_state(
{"provider": "google_drive", "object_id": "abc123"}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
resp = client.get(
f"/api/connectors/callback?error=server_error&state={state}"
)
assert resp.status_code == 302
assert "error" in resp.headers.get("Location", "")
def test_callback_missing_code_redirects_error(self, client, mock_sessions):
state = self._encode_state(
{"provider": "google_drive", "object_id": "abc123"}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
resp = client.get(f"/api/connectors/callback?state={state}")
assert resp.status_code == 302
assert "error" in resp.headers.get("Location", "")
def test_callback_success_google_drive(self, client, mock_sessions):
oid = mock_sessions["sessions"].insert_one(
{
"provider": "google_drive",
"user": "test_user",
"status": "pending",
}
).inserted_id
state = self._encode_state(
{"provider": "google_drive", "object_id": str(oid)}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.exchange_code_for_tokens.return_value = {
"access_token": "at",
"refresh_token": "rt",
}
mock_creds = MagicMock()
mock_auth.create_credentials_from_token_info.return_value = mock_creds
mock_service = MagicMock()
mock_service.about.return_value.get.return_value.execute.return_value = {
"user": {"emailAddress": "user@example.com"}
}
mock_auth.build_drive_service.return_value = mock_service
mock_auth.sanitize_token_info.return_value = {
"access_token": "at",
"refresh_token": "rt",
}
MockCC.create_auth.return_value = mock_auth
resp = client.get(
f"/api/connectors/callback?code=auth_code&state={state}"
)
assert resp.status_code == 302
assert "success" in resp.headers.get("Location", "")
def test_callback_success_non_google_provider(self, client, mock_sessions):
oid = mock_sessions["sessions"].insert_one(
{
"provider": "other_provider",
"user": "test_user",
"status": "pending",
}
).inserted_id
state = self._encode_state(
{"provider": "other_provider", "object_id": str(oid)}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.exchange_code_for_tokens.return_value = {
"access_token": "at",
"user_info": {"email": "other@example.com"},
}
mock_auth.sanitize_token_info.return_value = {"access_token": "at"}
MockCC.create_auth.return_value = mock_auth
resp = client.get(
f"/api/connectors/callback?code=auth_code&state={state}"
)
assert resp.status_code == 302
assert "success" in resp.headers.get("Location", "")
def test_callback_exchange_tokens_fails(self, client, mock_sessions):
oid = mock_sessions["sessions"].insert_one(
{
"provider": "google_drive",
"user": "test_user",
"status": "pending",
}
).inserted_id
state = self._encode_state(
{"provider": "google_drive", "object_id": str(oid)}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.exchange_code_for_tokens.side_effect = Exception("token error")
MockCC.create_auth.return_value = mock_auth
resp = client.get(
f"/api/connectors/callback?code=auth_code&state={state}"
)
assert resp.status_code == 302
assert "error" in resp.headers.get("Location", "")
def test_callback_bad_state_returns_error(self, client, mock_sessions):
resp = client.get("/api/connectors/callback?code=auth_code&state=badbase64!!!")
assert resp.status_code == 302
assert "error" in resp.headers.get("Location", "")
def test_callback_user_info_fails_gracefully(self, client, mock_sessions):
oid = mock_sessions["sessions"].insert_one(
{
"provider": "google_drive",
"user": "test_user",
"status": "pending",
}
).inserted_id
state = self._encode_state(
{"provider": "google_drive", "object_id": str(oid)}
)
with self._patch_connector_creator() as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.exchange_code_for_tokens.return_value = {
"access_token": "at",
"refresh_token": "rt",
}
mock_auth.create_credentials_from_token_info.side_effect = Exception(
"cred error"
)
mock_auth.sanitize_token_info.return_value = {
"access_token": "at",
}
MockCC.create_auth.return_value = mock_auth
resp = client.get(
f"/api/connectors/callback?code=auth_code&state={state}"
)
assert resp.status_code == 302
assert "success" in resp.headers.get("Location", "")
@pytest.mark.unit
class TestConnectorFilesAdditional:
"""Additional tests for ConnectorFiles."""
def test_unauthorized_user(self, client, mock_sessions):
with patch("application.app.handle_auth", return_value=None):
resp = client.post(
"/api/connectors/files",
json={
"provider": "google_drive",
"session_token": "tok",
},
)
assert resp.status_code == 401
def test_files_with_pagination(self, client, mock_sessions):
mock_sessions["sessions"].insert_one(
{
"session_token": "pag_tok",
"user": "test_user",
"provider": "google_drive",
}
)
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {
"file_name": "test.pdf",
"mime_type": "application/pdf",
"size": 1024,
"modified_time": "2025-01-01T12:00:00.000Z",
"is_folder": False,
}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = "next_token_123"
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post(
"/api/connectors/files",
json={
"provider": "google_drive",
"session_token": "pag_tok",
"page_token": "prev_token",
},
)
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["has_more"] is True
assert data["next_page_token"] == "next_token_123"
def test_files_exception_returns_500(self, client, mock_sessions):
mock_sessions["sessions"].insert_one(
{
"session_token": "err_tok",
"user": "test_user",
"provider": "google_drive",
}
)
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.side_effect = Exception("connector error")
resp = client.post(
"/api/connectors/files",
json={
"provider": "google_drive",
"session_token": "err_tok",
},
)
assert resp.status_code == 500
@pytest.mark.unit
class TestConnectorFilesSearchQuery:
"""Test ConnectorFiles with search_query parameter."""
def test_files_with_search_query(self, client, mock_sessions):
mock_sessions["sessions"].insert_one(
{
"session_token": "search_tok",
"user": "test_user",
"provider": "google_drive",
}
)
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {
"file_name": "result.pdf",
"mime_type": "application/pdf",
"size": 512,
"is_folder": False,
}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = None
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post(
"/api/connectors/files",
json={
"provider": "google_drive",
"session_token": "search_tok",
"search_query": "test search",
},
)
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
# Verify search_query was passed in input_config
call_args = mock_loader.load_data.call_args[0][0]
assert call_args.get("search_query") == "test search"
# ---------------------------------------------------------------------------
# Additional coverage tests
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestConnectorValidateSessionAdditional:
"""Cover uncovered branches in ConnectorValidateSession."""
def test_unauthorized_returns_401(self, client, mock_sessions):
"""Line 288: decoded_token is None -> 401."""
with patch("application.app.handle_auth", return_value=None):
resp = client.post(
"/api/connectors/validate-session",
json={
"provider": "google_drive",
"session_token": "tok",
},
)
assert resp.status_code == 401
def test_refresh_token_failure_still_expired(self, client, mock_sessions):
"""Lines 299-310: refresh attempt fails, token stays expired."""
mock_sessions["sessions"].insert_one({
"session_token": "rf_fail_tok",
"user": "test_user",
"provider": "google_drive",
"token_info": {
"access_token": "old_at",
"refresh_token": "rt",
"expiry": 100,
},
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = True
mock_auth.refresh_access_token.side_effect = Exception("refresh failed")
MockCC.create_auth.return_value = mock_auth
resp = client.post(
"/api/connectors/validate-session",
json={
"provider": "google_drive",
"session_token": "rf_fail_tok",
},
)
assert resp.status_code == 401
data = json.loads(resp.data)
assert data["expired"] is True
def test_provider_extras_in_response(self, client, mock_sessions):
"""Lines 319-327: provider_extras are included in response."""
mock_sessions["sessions"].insert_one({
"session_token": "extras_tok",
"user": "test_user",
"provider": "google_drive",
"token_info": {
"access_token": "at",
"refresh_token": "rt",
"token_uri": "uri",
"expiry": None,
"custom_field": "custom_value",
},
"user_email": "user@test.com",
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = False
MockCC.create_auth.return_value = mock_auth
resp = client.post(
"/api/connectors/validate-session",
json={
"provider": "google_drive",
"session_token": "extras_tok",
},
)
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert data["custom_field"] == "custom_value"
assert data["user_email"] == "user@test.com"
def test_exception_returns_500(self, client, mock_sessions):
"""Lines 331-333: general exception -> 500."""
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_auth.side_effect = Exception("total failure")
mock_sessions["sessions"].insert_one({
"session_token": "err_tok",
"user": "test_user",
"provider": "google_drive",
"token_info": {"access_token": "at"},
})
resp = client.post(
"/api/connectors/validate-session",
json={
"provider": "google_drive",
"session_token": "err_tok",
},
)
assert resp.status_code == 500
@pytest.mark.unit
class TestConnectorDisconnectAdditional:
"""Cover uncovered branches in ConnectorDisconnect."""
def test_exception_returns_500(self, client, mock_sessions):
"""Lines 353-355: exception in disconnect -> 500."""
with patch(
"application.api.connector.routes.sessions_collection"
) as mock_col:
mock_col.delete_one.side_effect = Exception("db down")
resp = client.post(
"/api/connectors/disconnect",
json={
"provider": "google_drive",
"session_token": "tok",
},
)
assert resp.status_code == 500
def test_unauthorized_still_works(self, client, mock_sessions):
"""ConnectorDisconnect doesn't check decoded_token, just data parsing.
No auth check branch to cover, but confirm basic flow."""
resp = client.post(
"/api/connectors/disconnect",
json={"provider": "google_drive"},
)
assert resp.status_code == 200
@pytest.mark.unit
class TestConnectorSyncAdditional:
"""Cover uncovered branches in ConnectorSync."""
def test_unauthorized_returns_401(self, client, mock_sessions):
"""Line 373: decoded_token is None -> 401."""
from bson.objectid import ObjectId as ObjId
with patch("application.app.handle_auth", return_value=None):
resp = client.post(
"/api/connectors/sync",
json={
"source_id": str(ObjId()),
"session_token": "tok",
},
)
assert resp.status_code == 401
def test_exception_returns_400(self, client, mock_sessions):
"""Lines 453-464: general exception returns 400."""
sid = mock_sessions["sources"].insert_one({
"user": "test_user",
"name": "src",
"remote_data": json.dumps({
"provider": "google_drive",
"file_ids": ["f1"],
}),
}).inserted_id
with patch(
"application.api.connector.routes.ingest_connector_task"
) as mock_ingest:
mock_ingest.delay.side_effect = Exception("task error")
resp = client.post(
"/api/connectors/sync",
json={
"source_id": str(sid),
"session_token": "tok",
},
)
assert resp.status_code == 400
def test_invalid_remote_data_json(self, client, mock_sessions):
"""Line 411-413: invalid remote_data JSON."""
sid = mock_sessions["sources"].insert_one({
"user": "test_user",
"name": "src",
"remote_data": "not-valid-json{",
}).inserted_id
resp = client.post(
"/api/connectors/sync",
json={
"source_id": str(sid),
"session_token": "tok",
},
)
# remote_data parsing fails, remote_data = {}, no provider -> 400
assert resp.status_code == 400

View File

@@ -0,0 +1,579 @@
"""Unit tests for application/api/internal/routes.py.
Covers:
- verify_internal_key: key validation
- /api/download: file download
- /api/upload_index: index file upload (existing & new entries)
"""
import io
import json
from unittest.mock import MagicMock
import pytest
from bson.objectid import ObjectId
@pytest.fixture
def internal_app(monkeypatch, mock_mongo_db):
"""Create a Flask app with the internal blueprint registered."""
from flask import Flask
# Patch module-level MongoDB references before importing routes
from application.core.settings import settings
db = mock_mongo_db[settings.MONGO_DB_NAME]
monkeypatch.setattr(
"application.api.internal.routes.conversations_collection",
db["conversations"],
)
monkeypatch.setattr(
"application.api.internal.routes.sources_collection",
db["sources"],
)
from application.api.internal.routes import internal
app = Flask(__name__)
app.register_blueprint(internal)
app.config["TESTING"] = True
return app, db
# ---------------------------------------------------------------------------
# verify_internal_key
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestVerifyInternalKey:
def test_no_internal_key_configured_allows_access(
self, internal_app, monkeypatch
):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY=None,
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
# download will fail for missing file but should not be 401
resp = client.get("/api/download?user=u&name=n&file=f")
assert resp.status_code != 401
def test_missing_key_returns_401(self, internal_app, monkeypatch):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY="secret123",
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
resp = client.get("/api/download?user=u&name=n&file=f")
assert resp.status_code == 401
def test_wrong_key_returns_401(self, internal_app, monkeypatch):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY="secret123",
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
resp = client.get(
"/api/download?user=u&name=n&file=f",
headers={"X-Internal-Key": "wrong"},
)
assert resp.status_code == 401
def test_correct_key_allows_access(self, internal_app, monkeypatch):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY="secret123",
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
# Will 404 for missing file, but should pass auth check
resp = client.get(
"/api/download?user=u&name=n&file=f",
headers={"X-Internal-Key": "secret123"},
)
assert resp.status_code != 401
# ---------------------------------------------------------------------------
# /api/upload_index
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestUploadIndex:
def _make_settings(self, vector_store="faiss"):
return MagicMock(
INTERNAL_KEY=None,
UPLOAD_FOLDER="uploads",
VECTOR_STORE=vector_store,
EMBEDDINGS_NAME="test_embeddings",
MONGO_DB_NAME="docsgpt",
)
def test_missing_user_returns_no_user(self, internal_app, monkeypatch):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={})
assert resp.json["status"] == "no user"
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={"user": "testuser"})
assert resp.json["status"] == "no name"
def test_creates_new_source_entry(self, internal_app, monkeypatch):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "testuser",
"name": "testjob",
"tokens": "100",
"retriever": "classic",
"id": doc_id,
"type": "local",
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry is not None
assert entry["user"] == "testuser"
assert entry["name"] == "testjob"
def test_updates_existing_source_entry(self, internal_app, monkeypatch):
app, db = internal_app
doc_id = ObjectId()
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
# Insert existing entry
db["sources"].insert_one(
{"_id": doc_id, "user": "old_user", "name": "old_name"}
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "new_user",
"name": "new_name",
"tokens": "200",
"retriever": "hybrid",
"id": str(doc_id),
"type": "remote",
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": doc_id})
assert entry["user"] == "new_user"
assert entry["name"] == "new_name"
def test_parses_directory_structure_json(self, internal_app, monkeypatch):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
dir_struct = {"root": {"files": ["a.txt", "b.txt"]}}
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"directory_structure": json.dumps(dir_struct),
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry["directory_structure"] == dir_struct
def test_invalid_directory_structure_defaults_empty(
self, internal_app, monkeypatch
):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"directory_structure": "not valid json",
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry["directory_structure"] == {}
def test_file_name_map_parsed(self, internal_app, monkeypatch):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
fmap = {"hash1": "file1.txt"}
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_name_map": json.dumps(fmap),
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry["file_name_map"] == fmap
def test_faiss_missing_files_returns_no_file(
self, internal_app, monkeypatch
):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="faiss")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
},
)
assert resp.json["status"] == "no file"
def test_faiss_empty_filename_returns_no_file_name(
self, internal_app, monkeypatch
):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="faiss")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_faiss": (io.BytesIO(b""), ""),
},
)
assert resp.json["status"] == "no file name"
def test_remote_data_and_sync_frequency(self, internal_app, monkeypatch):
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "remote",
"remote_data": '{"url":"http://example.com"}',
"sync_frequency": "daily",
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry["sync_frequency"] == "daily"
assert entry["remote_data"] == '{"url":"http://example.com"}'
def test_faiss_upload_with_valid_files(self, internal_app, monkeypatch):
"""Cover lines 93-104: FAISS upload with both faiss and pkl files."""
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="faiss")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
},
content_type="multipart/form-data",
)
assert resp.json["status"] == "ok"
mock_storage.save_file.assert_called()
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert entry is not None
def test_faiss_pkl_missing_returns_no_file(self, internal_app, monkeypatch):
"""Cover lines 93-95: FAISS upload with faiss file but no pkl file."""
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="faiss")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
},
content_type="multipart/form-data",
)
assert resp.json["status"] == "no file"
def test_faiss_pkl_empty_name_returns_no_file_name(self, internal_app, monkeypatch):
"""Cover lines 97-98: FAISS upload with pkl but empty filename."""
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="faiss")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
"file_pkl": (io.BytesIO(b""), ""),
},
content_type="multipart/form-data",
)
assert resp.json["status"] == "no file name"
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
"""Cover line 124: update existing entry with file_name_map."""
app, db = internal_app
doc_id = ObjectId()
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
db["sources"].insert_one({"_id": doc_id, "user": "old_user", "name": "old"})
fmap = {"hash1": "file1.txt"}
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": str(doc_id),
"type": "local",
"file_name_map": json.dumps(fmap),
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": doc_id})
assert entry["file_name_map"] == fmap
def test_invalid_file_name_map_defaults_none(self, internal_app, monkeypatch):
"""Cover lines 77-79: invalid file_name_map JSON defaults to None."""
app, db = internal_app
doc_id = str(ObjectId())
settings_mock = self._make_settings(vector_store="other")
monkeypatch.setattr(
"application.api.internal.routes.settings", settings_mock
)
mock_storage = MagicMock()
monkeypatch.setattr(
"application.api.internal.routes.StorageCreator",
MagicMock(get_storage=MagicMock(return_value=mock_storage)),
)
with app.test_client() as client:
resp = client.post(
"/api/upload_index",
data={
"user": "u",
"name": "n",
"tokens": "0",
"retriever": "classic",
"id": doc_id,
"type": "local",
"file_name_map": "not valid json{{{",
},
)
assert resp.json["status"] == "ok"
entry = db["sources"].find_one({"_id": ObjectId(doc_id)})
assert "file_name_map" not in entry

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,879 @@
"""Tests for source chunk management routes."""
import pytest
from unittest.mock import Mock, patch
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
def _status(response):
if isinstance(response, tuple):
return response[1]
return response.status_code
def _json(response):
if isinstance(response, tuple):
return response[0].json
return response.json
# ---------------------------------------------------------------------------
# GetChunks
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetChunks:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.chunks import GetChunks
with app.test_request_context("/api/get_chunks?id=abc"):
from flask import request
request.decoded_token = None
response = GetChunks().get()
assert _status(response) == 401
def test_returns_400_for_invalid_doc_id(self, app):
from application.api.user.sources.chunks import GetChunks
with app.test_request_context("/api/get_chunks?id=invalid"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetChunks().get()
assert _status(response) == 400
assert "Invalid doc_id" in _json(response)["error"]
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
):
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetChunks().get()
assert _status(response) == 404
assert "not found" in _json(response)["error"]
def test_returns_paginated_chunks(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
chunks = [
{"text": f"chunk {i}", "metadata": {}, "doc_id": f"c{i}"}
for i in range(25)
]
mock_store = Mock()
mock_store.get_chunks.return_value = chunks
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/get_chunks?id={doc_id}&page=2&per_page=10"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
assert _status(response) == 200
data = _json(response)
assert data["total"] == 25
assert data["page"] == 2
assert data["per_page"] == 10
assert len(data["chunks"]) == 10
def test_filters_chunks_by_path(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
chunks = [
{"text": "a", "metadata": {"source": "inputs/dir/file.pdf"}, "doc_id": "c1"},
{"text": "b", "metadata": {"source": "inputs/other.txt"}, "doc_id": "c2"},
{"text": "c", "metadata": {"file_path": "guides/setup.md"}, "doc_id": "c3"},
]
mock_store = Mock()
mock_store.get_chunks.return_value = chunks
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/get_chunks?id={doc_id}&path=file.pdf"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
data = _json(response)
assert data["total"] == 1
assert data["chunks"][0]["text"] == "a"
assert data["path"] == "file.pdf"
def test_filters_chunks_by_file_path_metadata(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
chunks = [
{"text": "a", "metadata": {"source": "inputs/dir/file.pdf"}, "doc_id": "c1"},
{"text": "c", "metadata": {"file_path": "guides/setup.md"}, "doc_id": "c3"},
]
mock_store = Mock()
mock_store.get_chunks.return_value = chunks
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/get_chunks?id={doc_id}&path=setup.md"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
data = _json(response)
assert data["total"] == 1
assert data["chunks"][0]["text"] == "c"
def test_filters_chunks_by_search_term(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
chunks = [
{"text": "Python is great", "metadata": {"title": "intro"}, "doc_id": "c1"},
{"text": "Java tutorial", "metadata": {"title": "java guide"}, "doc_id": "c2"},
{"text": "Hello world", "metadata": {"title": "Python Basics"}, "doc_id": "c3"},
]
mock_store = Mock()
mock_store.get_chunks.return_value = chunks
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/get_chunks?id={doc_id}&search=python"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
data = _json(response)
assert data["total"] == 2
assert data["search"] == "python"
def test_combines_path_and_search_filters(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
chunks = [
{"text": "Python intro", "metadata": {"source": "dir/intro.md", "title": ""}, "doc_id": "c1"},
{"text": "Python deep", "metadata": {"source": "dir/deep.md", "title": ""}, "doc_id": "c2"},
{"text": "Java intro", "metadata": {"source": "dir/intro.md", "title": ""}, "doc_id": "c3"},
]
mock_store = Mock()
mock_store.get_chunks.return_value = chunks
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/get_chunks?id={doc_id}&path=intro.md&search=python"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
data = _json(response)
assert data["total"] == 1
assert data["chunks"][0]["doc_id"] == "c1"
def test_returns_500_on_store_error(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.side_effect = Exception("Store error")
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
assert _status(response) == 500
def test_no_path_or_search_returns_null_fields(self, app):
from application.api.user.sources.chunks import GetChunks
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = []
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(f"/api/get_chunks?id={doc_id}"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = GetChunks().get()
data = _json(response)
assert data["path"] is None
assert data["search"] is None
# ---------------------------------------------------------------------------
# AddChunk
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestAddChunk:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.chunks import AddChunk
with app.test_request_context(
"/api/add_chunk", method="POST", json={"id": "abc", "text": "hi"}
):
from flask import request
request.decoded_token = None
response = AddChunk().post()
assert _status(response) == 401
def test_returns_400_missing_required_fields(self, app):
from application.api.user.sources.chunks import AddChunk
with app.test_request_context(
"/api/add_chunk", method="POST", json={"id": str(ObjectId())}
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
# check_required_fields returns a tuple (response, status)
assert response is not None
def test_returns_400_for_invalid_doc_id(self, app):
from application.api.user.sources.chunks import AddChunk
with app.test_request_context(
"/api/add_chunk", method="POST", json={"id": "bad", "text": "hi"}
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
assert _status(response) == 400
assert "Invalid doc_id" in _json(response)["error"]
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.chunks import AddChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/add_chunk", method="POST",
json={"id": doc_id, "text": "hello"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
assert _status(response) == 404
def test_adds_chunk_successfully(self, app):
from application.api.user.sources.chunks import AddChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.add_chunk.return_value = "new-chunk-id"
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
), patch(
"application.api.user.sources.chunks.num_tokens_from_string",
return_value=5,
):
with app.test_request_context(
"/api/add_chunk", method="POST",
json={"id": doc_id, "text": "hello world"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
assert _status(response) == 201
data = _json(response)
assert data["chunk_id"] == "new-chunk-id"
assert "successfully" in data["message"]
call_args = mock_store.add_chunk.call_args
assert call_args[0][0] == "hello world"
assert call_args[0][1]["token_count"] == 5
def test_adds_chunk_with_custom_metadata(self, app):
from application.api.user.sources.chunks import AddChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.add_chunk.return_value = "cid"
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
), patch(
"application.api.user.sources.chunks.num_tokens_from_string",
return_value=3,
):
with app.test_request_context(
"/api/add_chunk", method="POST",
json={
"id": doc_id,
"text": "hi",
"metadata": {"source": "test.pdf"},
},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
assert _status(response) == 201
meta = mock_store.add_chunk.call_args[0][1]
assert meta["source"] == "test.pdf"
assert meta["token_count"] == 3
def test_returns_500_on_store_error(self, app):
from application.api.user.sources.chunks import AddChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.add_chunk.side_effect = Exception("fail")
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
), patch(
"application.api.user.sources.chunks.num_tokens_from_string",
return_value=1,
):
with app.test_request_context(
"/api/add_chunk", method="POST",
json={"id": doc_id, "text": "hello"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = AddChunk().post()
assert _status(response) == 500
# ---------------------------------------------------------------------------
# DeleteChunk
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeleteChunk:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.chunks import DeleteChunk
with app.test_request_context("/api/delete_chunk?id=abc&chunk_id=xyz"):
from flask import request
request.decoded_token = None
response = DeleteChunk().delete()
assert _status(response) == 401
def test_returns_400_for_invalid_doc_id(self, app):
from application.api.user.sources.chunks import DeleteChunk
with app.test_request_context("/api/delete_chunk?id=bad&chunk_id=xyz"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteChunk().delete()
assert _status(response) == 400
assert "Invalid doc_id" in _json(response)["error"]
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.chunks import DeleteChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
):
with app.test_request_context(
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteChunk().delete()
assert _status(response) == 404
def test_deletes_chunk_successfully(self, app):
from application.api.user.sources.chunks import DeleteChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.delete_chunk.return_value = True
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteChunk().delete()
assert _status(response) == 200
assert "successfully" in _json(response)["message"]
mock_store.delete_chunk.assert_called_once_with("cid")
def test_returns_404_when_chunk_not_deleted(self, app):
from application.api.user.sources.chunks import DeleteChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.delete_chunk.return_value = False
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/delete_chunk?id={doc_id}&chunk_id=missing"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteChunk().delete()
assert _status(response) == 404
assert "not found" in _json(response)["message"]
def test_returns_500_on_store_error(self, app):
from application.api.user.sources.chunks import DeleteChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.delete_chunk.side_effect = Exception("boom")
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
f"/api/delete_chunk?id={doc_id}&chunk_id=cid"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteChunk().delete()
assert _status(response) == 500
# ---------------------------------------------------------------------------
# UpdateChunk
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestUpdateChunk:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.chunks import UpdateChunk
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": "abc", "chunk_id": "cid"},
):
from flask import request
request.decoded_token = None
response = UpdateChunk().put()
assert _status(response) == 401
def test_returns_400_missing_required_fields(self, app):
from application.api.user.sources.chunks import UpdateChunk
with app.test_request_context(
"/api/update_chunk", method="PUT", json={"id": str(ObjectId())}
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert response is not None
def test_returns_400_for_invalid_doc_id(self, app):
from application.api.user.sources.chunks import UpdateChunk
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": "bad", "chunk_id": "cid"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 400
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "cid"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 404
def test_returns_404_when_chunk_not_found(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = [
{"doc_id": "other", "text": "x", "metadata": {}},
]
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "missing"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 404
assert "Chunk not found" in _json(response)["error"]
def test_updates_chunk_text_successfully(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = [
{"doc_id": "cid", "text": "old text", "metadata": {"source": "f.pdf"}},
]
mock_store.add_chunk.return_value = "new-cid"
mock_store.delete_chunk.return_value = True
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
), patch(
"application.api.user.sources.chunks.num_tokens_from_string",
return_value=7,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "cid", "text": "new text"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 200
data = _json(response)
assert data["chunk_id"] == "new-cid"
assert data["original_chunk_id"] == "cid"
# Verify add was called with new text and merged metadata
add_call = mock_store.add_chunk.call_args
assert add_call[0][0] == "new text"
assert add_call[0][1]["source"] == "f.pdf"
assert add_call[0][1]["token_count"] == 7
mock_store.delete_chunk.assert_called_once_with("cid")
def test_updates_chunk_metadata_only(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = [
{"doc_id": "cid", "text": "keep me", "metadata": {"source": "f.pdf"}},
]
mock_store.add_chunk.return_value = "new-cid"
mock_store.delete_chunk.return_value = True
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={
"id": doc_id,
"chunk_id": "cid",
"metadata": {"title": "new title"},
},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 200
add_call = mock_store.add_chunk.call_args
# text should be preserved
assert add_call[0][0] == "keep me"
# metadata should be merged
assert add_call[0][1]["source"] == "f.pdf"
assert add_call[0][1]["title"] == "new title"
def test_update_warns_when_old_chunk_delete_fails(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = [
{"doc_id": "cid", "text": "text", "metadata": {}},
]
mock_store.add_chunk.return_value = "new-cid"
mock_store.delete_chunk.return_value = False # delete fails
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "cid"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
# Still returns 200 with a warning logged
assert _status(response) == 200
def test_returns_500_when_add_chunk_fails(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.return_value = [
{"doc_id": "cid", "text": "text", "metadata": {}},
]
mock_store.add_chunk.side_effect = Exception("add failed")
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "cid", "text": "new"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 500
assert "addition failed" in _json(response)["error"]
def test_returns_500_on_general_store_error(self, app):
from application.api.user.sources.chunks import UpdateChunk
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = {"_id": ObjectId(doc_id), "user": "u1"}
mock_store = Mock()
mock_store.get_chunks.side_effect = Exception("connection lost")
with patch(
"application.api.user.sources.chunks.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.chunks.get_vector_store",
return_value=mock_store,
):
with app.test_request_context(
"/api/update_chunk", method="PUT",
json={"id": doc_id, "chunk_id": "cid"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = UpdateChunk().put()
assert _status(response) == 500

View File

@@ -0,0 +1,965 @@
"""Tests for source management routes (CombinedJson, PaginatedSources,
DeleteByIds, DeleteOldIndexes, ManageSync, DirectoryStructure).
Note: SyncSource and _get_provider_from_remote_data are already covered in
test_routes.py and are NOT duplicated here.
"""
import json
import pytest
from unittest.mock import MagicMock, Mock, patch
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
def _status(response):
if isinstance(response, tuple):
return response[1]
return response.status_code
def _json(response):
if isinstance(response, tuple):
return response[0].json
return response.json
# ---------------------------------------------------------------------------
# CombinedJson (/api/sources)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCombinedJson:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import CombinedJson
with app.test_request_context("/api/sources"):
from flask import request
request.decoded_token = None
response = CombinedJson().get()
assert _status(response) == 401
def test_returns_default_source_plus_user_sources(self, app):
from application.api.user.sources.routes import CombinedJson
src_id = ObjectId()
mock_cursor = MagicMock()
mock_cursor.sort.return_value = [
{
"_id": src_id,
"name": "My Doc",
"date": "2024-01-01",
"tokens": "100",
"retriever": "classic",
"sync_frequency": "daily",
"remote_data": json.dumps({"provider": "github"}),
"directory_structure": None,
"type": "file",
}
]
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = CombinedJson().get()
assert _status(response) == 200
data = _json(response)
# First entry is always the Default
assert data[0]["name"] == "Default"
assert data[0]["date"] == "default"
# Second entry is user source
assert data[1]["id"] == str(src_id)
assert data[1]["name"] == "My Doc"
assert data[1]["provider"] == "github"
assert data[1]["syncFrequency"] == "daily"
assert data[1]["is_nested"] is False
def test_is_nested_true_when_directory_structure_present(self, app):
from application.api.user.sources.routes import CombinedJson
mock_cursor = MagicMock()
mock_cursor.sort.return_value = [
{
"_id": ObjectId(),
"name": "Nested",
"date": "2024-01-01",
"directory_structure": {"files": ["a.txt"]},
}
]
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = CombinedJson().get()
data = _json(response)
assert data[1]["is_nested"] is True
def test_returns_400_on_db_error(self, app):
from application.api.user.sources.routes import CombinedJson
mock_collection = Mock()
mock_collection.find.side_effect = Exception("db err")
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = CombinedJson().get()
assert _status(response) == 400
def test_type_defaults_to_file(self, app):
from application.api.user.sources.routes import CombinedJson
mock_cursor = MagicMock()
mock_cursor.sort.return_value = [
{"_id": ObjectId(), "name": "X", "date": "d"}
]
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = CombinedJson().get()
data = _json(response)
assert data[1]["type"] == "file"
# ---------------------------------------------------------------------------
# PaginatedSources (/api/sources/paginated)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestPaginatedSources:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import PaginatedSources
with app.test_request_context("/api/sources/paginated"):
from flask import request
request.decoded_token = None
response = PaginatedSources().get()
assert _status(response) == 401
def test_returns_paginated_results(self, app):
from application.api.user.sources.routes import PaginatedSources
ids = [ObjectId() for _ in range(3)]
docs = [
{"_id": ids[i], "name": f"Doc{i}", "date": f"2024-0{i + 1}-01"}
for i in range(3)
]
mock_cursor = MagicMock()
mock_cursor.sort.return_value = mock_cursor
mock_cursor.skip.return_value = mock_cursor
mock_cursor.limit.return_value = docs
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
mock_collection.count_documents.return_value = 3
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/sources/paginated?page=1&rows=10&sort=date&order=desc"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
assert _status(response) == 200
data = _json(response)
assert data["total"] == 3
assert data["totalPages"] == 1
assert data["currentPage"] == 1
assert len(data["paginated"]) == 3
def test_search_filter_applies_regex(self, app):
from application.api.user.sources.routes import PaginatedSources
mock_cursor = MagicMock()
mock_cursor.sort.return_value = mock_cursor
mock_cursor.skip.return_value = mock_cursor
mock_cursor.limit.return_value = []
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
mock_collection.count_documents.return_value = 0
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/sources/paginated?search=test%20doc"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
assert _status(response) == 200
# Verify search query was passed
query_arg = mock_collection.count_documents.call_args[0][0]
assert query_arg["name"]["$regex"] == "test doc"
assert query_arg["name"]["$options"] == "i"
def test_ascending_sort_order(self, app):
from application.api.user.sources.routes import PaginatedSources
mock_cursor = MagicMock()
mock_cursor.sort.return_value = mock_cursor
mock_cursor.skip.return_value = mock_cursor
mock_cursor.limit.return_value = []
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
mock_collection.count_documents.return_value = 0
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/sources/paginated?order=asc&sort=name"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
assert _status(response) == 200
mock_cursor.sort.assert_called_once_with("name", 1)
def test_page_clamped_to_valid_range(self, app):
from application.api.user.sources.routes import PaginatedSources
mock_cursor = MagicMock()
mock_cursor.sort.return_value = mock_cursor
mock_cursor.skip.return_value = mock_cursor
mock_cursor.limit.return_value = []
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
mock_collection.count_documents.return_value = 5 # 1 page with default 10 rows
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/sources/paginated?page=999"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
data = _json(response)
assert data["currentPage"] == 1 # clamped
def test_returns_400_on_db_error(self, app):
from application.api.user.sources.routes import PaginatedSources
mock_collection = Mock()
mock_collection.count_documents.return_value = 0
mock_collection.find.side_effect = Exception("db error")
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources/paginated"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
assert _status(response) == 400
def test_paginated_includes_provider_and_is_nested(self, app):
from application.api.user.sources.routes import PaginatedSources
doc = {
"_id": ObjectId(),
"name": "S3 Src",
"date": "2024-01-01",
"remote_data": {"provider": "s3"},
"directory_structure": {"dirs": ["a"]},
"type": "s3",
}
mock_cursor = MagicMock()
mock_cursor.sort.return_value = mock_cursor
mock_cursor.skip.return_value = mock_cursor
mock_cursor.limit.return_value = [doc]
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
mock_collection.count_documents.return_value = 1
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/sources/paginated"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = PaginatedSources().get()
data = _json(response)
entry = data["paginated"][0]
assert entry["provider"] == "s3"
assert entry["isNested"] is True
assert entry["type"] == "s3"
# ---------------------------------------------------------------------------
# DeleteByIds (/api/delete_by_ids)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeleteByIds:
def test_returns_400_when_path_missing(self, app):
from application.api.user.sources.routes import DeleteByIds
with app.test_request_context("/api/delete_by_ids"):
response = DeleteByIds().get()
assert _status(response) == 400
assert "Missing" in _json(response)["message"]
def test_returns_200_on_successful_delete(self, app):
from application.api.user.sources.routes import DeleteByIds
mock_collection = Mock()
mock_collection.delete_index.return_value = True
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/delete_by_ids?path=id1,id2"):
response = DeleteByIds().get()
assert _status(response) == 200
assert _json(response)["success"] is True
mock_collection.delete_index.assert_called_once_with(ids="id1,id2")
def test_returns_400_when_delete_returns_false(self, app):
from application.api.user.sources.routes import DeleteByIds
mock_collection = Mock()
mock_collection.delete_index.return_value = False
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/delete_by_ids?path=id1"):
response = DeleteByIds().get()
assert _status(response) == 400
def test_returns_400_on_exception(self, app):
from application.api.user.sources.routes import DeleteByIds
mock_collection = Mock()
mock_collection.delete_index.side_effect = Exception("fail")
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context("/api/delete_by_ids?path=id1"):
response = DeleteByIds().get()
assert _status(response) == 400
# ---------------------------------------------------------------------------
# DeleteOldIndexes (/api/delete_old)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDeleteOldIndexes:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
with app.test_request_context("/api/delete_old?source_id=abc"):
from flask import request
request.decoded_token = None
response = DeleteOldIndexes().get()
assert _status(response) == 401
def test_returns_400_when_source_id_missing(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
with app.test_request_context("/api/delete_old"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 400
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(f"/api/delete_old?source_id={source_id}"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 404
def test_deletes_faiss_index_and_file(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": source_id,
"user": "u1",
"file_path": "uploads/u1/doc.pdf",
}
mock_storage = Mock()
mock_storage.file_exists.return_value = True
mock_storage.is_directory.return_value = False
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.routes.StorageCreator.get_storage",
return_value=mock_storage,
), patch(
"application.api.user.sources.routes.settings"
) as mock_settings:
mock_settings.VECTOR_STORE = "faiss"
with app.test_request_context(
f"/api/delete_old?source_id={source_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 200
assert _json(response)["success"] is True
# Should have checked and deleted faiss files
assert mock_storage.delete_file.call_count >= 1
mock_collection.delete_one.assert_called_once()
def test_deletes_non_faiss_vector_index(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": source_id,
"user": "u1",
}
mock_storage = Mock()
mock_vectorstore = Mock()
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.routes.StorageCreator.get_storage",
return_value=mock_storage,
), patch(
"application.api.user.sources.routes.VectorCreator.create_vectorstore",
return_value=mock_vectorstore,
), patch(
"application.api.user.sources.routes.settings"
) as mock_settings:
mock_settings.VECTOR_STORE = "elasticsearch"
with app.test_request_context(
f"/api/delete_old?source_id={source_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 200
mock_vectorstore.delete_index.assert_called_once()
def test_deletes_directory_of_files(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": source_id,
"user": "u1",
"file_path": "uploads/u1/mydir",
}
mock_storage = Mock()
mock_storage.is_directory.return_value = True
mock_storage.list_files.return_value = ["uploads/u1/mydir/a.txt", "uploads/u1/mydir/b.txt"]
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.routes.StorageCreator.get_storage",
return_value=mock_storage,
), patch(
"application.api.user.sources.routes.settings"
) as mock_settings:
mock_settings.VECTOR_STORE = "faiss"
mock_storage.file_exists.return_value = False
with app.test_request_context(
f"/api/delete_old?source_id={source_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 200
# Each file in directory should be deleted
assert mock_storage.delete_file.call_count == 2
def test_handles_file_not_found_gracefully(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": source_id,
"user": "u1",
"file_path": "uploads/missing.pdf",
}
mock_storage = Mock()
mock_storage.is_directory.side_effect = FileNotFoundError()
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.routes.StorageCreator.get_storage",
return_value=mock_storage,
), patch(
"application.api.user.sources.routes.settings"
) as mock_settings:
mock_settings.VECTOR_STORE = "faiss"
mock_storage.file_exists.return_value = False
with app.test_request_context(
f"/api/delete_old?source_id={source_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 200
mock_collection.delete_one.assert_called_once()
def test_returns_400_on_general_error(self, app):
from application.api.user.sources.routes import DeleteOldIndexes
source_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": source_id,
"user": "u1",
}
mock_storage = Mock()
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
), patch(
"application.api.user.sources.routes.StorageCreator.get_storage",
return_value=mock_storage,
), patch(
"application.api.user.sources.routes.settings"
) as mock_settings:
mock_settings.VECTOR_STORE = "faiss"
mock_storage.file_exists.side_effect = RuntimeError("disk error")
with app.test_request_context(
f"/api/delete_old?source_id={source_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DeleteOldIndexes().get()
assert _status(response) == 400
# ---------------------------------------------------------------------------
# ManageSync (/api/manage_sync)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestManageSync:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import ManageSync
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": "x", "sync_frequency": "daily"},
):
from flask import request
request.decoded_token = None
response = ManageSync().post()
assert _status(response) == 401
def test_returns_400_missing_fields(self, app):
from application.api.user.sources.routes import ManageSync
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": "abc"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = ManageSync().post()
assert response is not None
def test_returns_400_for_invalid_frequency(self, app):
from application.api.user.sources.routes import ManageSync
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": str(ObjectId()), "sync_frequency": "hourly"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = ManageSync().post()
assert _status(response) == 400
assert "Invalid frequency" in _json(response)["message"]
def test_updates_sync_frequency_successfully(self, app):
from application.api.user.sources.routes import ManageSync
source_id = str(ObjectId())
mock_collection = Mock()
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": source_id, "sync_frequency": "weekly"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = ManageSync().post()
assert _status(response) == 200
assert _json(response)["success"] is True
call_args = mock_collection.update_one.call_args
assert call_args[0][0]["_id"] == ObjectId(source_id)
assert call_args[0][0]["user"] == "u1"
assert call_args[0][1]["$set"]["sync_frequency"] == "weekly"
def test_accepts_all_valid_frequencies(self, app):
from application.api.user.sources.routes import ManageSync
mock_collection = Mock()
for freq in ["never", "daily", "weekly", "monthly"]:
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": str(ObjectId()), "sync_frequency": freq},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = ManageSync().post()
assert _status(response) == 200
def test_returns_400_on_db_error(self, app):
from application.api.user.sources.routes import ManageSync
mock_collection = Mock()
mock_collection.update_one.side_effect = Exception("db err")
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
"/api/manage_sync", method="POST",
json={"source_id": str(ObjectId()), "sync_frequency": "daily"},
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = ManageSync().post()
assert _status(response) == 400
# ---------------------------------------------------------------------------
# RedirectToSources (/api/combine)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRedirectToSources:
def test_redirects_to_sources(self, app):
from application.api.user.sources.routes import RedirectToSources
with app.test_request_context("/api/combine"):
response = RedirectToSources().get()
assert response.status_code == 301
assert response.location == "/api/sources"
# ---------------------------------------------------------------------------
# DirectoryStructure (/api/directory_structure)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestDirectoryStructure:
def test_returns_401_unauthenticated(self, app):
from application.api.user.sources.routes import DirectoryStructure
with app.test_request_context("/api/directory_structure?id=abc"):
from flask import request
request.decoded_token = None
response = DirectoryStructure().get()
assert _status(response) == 401
def test_returns_400_when_id_missing(self, app):
from application.api.user.sources.routes import DirectoryStructure
with app.test_request_context("/api/directory_structure"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
assert _status(response) == 400
assert "required" in _json(response)["error"]
def test_returns_400_for_invalid_doc_id(self, app):
from application.api.user.sources.routes import DirectoryStructure
with app.test_request_context("/api/directory_structure?id=invalid"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
assert _status(response) == 400
assert "Invalid" in _json(response)["error"]
def test_returns_404_when_doc_not_found(self, app):
from application.api.user.sources.routes import DirectoryStructure
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(f"/api/directory_structure?id={doc_id}"):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
assert _status(response) == 404
assert "not found" in _json(response)["error"]
def test_returns_directory_structure(self, app):
from application.api.user.sources.routes import DirectoryStructure
doc_id = ObjectId()
dir_struct = {"dirs": ["a", "b"], "files": ["c.txt"]}
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": doc_id,
"user": "u1",
"directory_structure": dir_struct,
"file_path": "uploads/u1/mydir",
"remote_data": json.dumps({"provider": "github"}),
}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
f"/api/directory_structure?id={doc_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
assert _status(response) == 200
data = _json(response)
assert data["success"] is True
assert data["directory_structure"] == dir_struct
assert data["base_path"] == "uploads/u1/mydir"
assert data["provider"] == "github"
def test_returns_none_provider_when_no_remote_data(self, app):
from application.api.user.sources.routes import DirectoryStructure
doc_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": doc_id,
"user": "u1",
"directory_structure": {},
"file_path": "path",
}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
f"/api/directory_structure?id={doc_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
data = _json(response)
assert data["provider"] is None
def test_handles_invalid_remote_data_json(self, app):
from application.api.user.sources.routes import DirectoryStructure
doc_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": doc_id,
"user": "u1",
"directory_structure": {},
"file_path": "path",
"remote_data": "not-valid-json{",
}
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
f"/api/directory_structure?id={doc_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
data = _json(response)
assert data["success"] is True
assert data["provider"] is None
def test_returns_500_on_general_error(self, app):
from application.api.user.sources.routes import DirectoryStructure
doc_id = str(ObjectId())
mock_collection = Mock()
mock_collection.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.sources.routes.sources_collection",
mock_collection,
):
with app.test_request_context(
f"/api/directory_structure?id={doc_id}"
):
from flask import request
request.decoded_token = {"sub": "u1"}
response = DirectoryStructure().get()
assert _status(response) == 500

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,768 @@
"""Tests for application.api.user.agents.sharing module."""
from unittest.mock import Mock, patch
import pytest
from bson import DBRef, ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
# ---------------------------------------------------------------------------
# SharedAgent (GET /shared_agent)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSharedAgent:
def test_returns_400_missing_token(self, app):
from application.api.user.agents.sharing import SharedAgent
with app.test_request_context("/api/shared_agent"):
response = SharedAgent().get()
assert response.status_code == 400
def test_returns_404_agent_not_found(self, app):
from application.api.user.agents.sharing import SharedAgent
mock_col = Mock()
mock_col.find_one.return_value = None
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context("/api/shared_agent?token=abc123"):
response = SharedAgent().get()
assert response.status_code == 404
def test_returns_shared_agent_data(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Shared Agent",
"description": "A shared agent",
"chunks": "5",
"retriever": "classic",
"prompt_id": "default",
"tools": [],
"agent_type": "classic",
"status": "published",
"shared_publicly": True,
"shared_token": "abc123",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
):
with app.test_request_context("/api/shared_agent?token=abc123"):
from flask import request
# No decoded_token -> anonymous access
request.decoded_token = None
response = SharedAgent().get()
assert response.status_code == 200
data = response.json
assert data["id"] == str(agent_id)
assert data["name"] == "Shared Agent"
assert data["shared"] is True
def test_adds_to_shared_with_me_for_different_user(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"tools": [],
"shared_publicly": True,
"shared_token": "abc123",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
mock_ensure = Mock(return_value={"user_id": "user2"})
mock_users_col = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
), patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.users_collection", mock_users_col
):
with app.test_request_context("/api/shared_agent?token=abc123"):
from flask import request
request.decoded_token = {"sub": "user2"}
response = SharedAgent().get()
assert response.status_code == 200
mock_ensure.assert_called_once_with("user2")
mock_users_col.update_one.assert_called_once()
def test_does_not_add_to_shared_for_owner(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"tools": [],
"shared_publicly": True,
"shared_token": "abc123",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
mock_ensure = Mock()
mock_users_col = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
), patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.users_collection", mock_users_col
):
with app.test_request_context("/api/shared_agent?token=abc123"):
from flask import request
request.decoded_token = {"sub": "owner1"}
response = SharedAgent().get()
assert response.status_code == 200
mock_ensure.assert_not_called()
mock_users_col.update_one.assert_not_called()
def test_enriches_tool_names(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
tool_id = str(ObjectId())
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"tools": [tool_id],
"shared_publicly": True,
"shared_token": "tok",
}
mock_tools_col = Mock()
mock_tools_col.find_one.return_value = {
"_id": ObjectId(tool_id),
"name": "calculator",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.user_tools_collection", mock_tools_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
):
with app.test_request_context("/api/shared_agent?token=tok"):
from flask import request
request.decoded_token = None
response = SharedAgent().get()
assert response.status_code == 200
assert response.json["tools"] == ["calculator"]
def test_handles_source_dbref(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
source_id = ObjectId()
source_ref = DBRef("sources", source_id)
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"source": source_ref,
"tools": [],
"shared_publicly": True,
"shared_token": "tok",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
mock_db.dereference.return_value = {"_id": source_id}
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
):
with app.test_request_context("/api/shared_agent?token=tok"):
from flask import request
request.decoded_token = None
response = SharedAgent().get()
assert response.status_code == 200
assert response.json["source"] == str(source_id)
def test_returns_400_on_exception(self, app):
from application.api.user.agents.sharing import SharedAgent
mock_col = Mock()
mock_col.find_one.side_effect = Exception("DB error")
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context("/api/shared_agent?token=tok"):
response = SharedAgent().get()
assert response.status_code == 400
def test_tool_enrichment_handles_missing_tool(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
tool_id = str(ObjectId())
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"tools": [tool_id],
"shared_publicly": True,
"shared_token": "tok",
}
mock_tools_col = Mock()
mock_tools_col.find_one.return_value = None
mock_resolve = Mock(return_value=[])
mock_db = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.user_tools_collection", mock_tools_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
):
with app.test_request_context("/api/shared_agent?token=tok"):
from flask import request
request.decoded_token = None
response = SharedAgent().get()
assert response.status_code == 200
# Missing tools are skipped
assert response.json["tools"] == []
def test_image_url_generated_when_present(self, app):
from application.api.user.agents.sharing import SharedAgent
agent_id = ObjectId()
mock_agents_col = Mock()
mock_agents_col.find_one.return_value = {
"_id": agent_id,
"user": "owner1",
"name": "Agent",
"image": "path/to/img.png",
"tools": [],
"shared_publicly": True,
"shared_token": "tok",
}
mock_resolve = Mock(return_value=[])
mock_db = Mock()
mock_generate = Mock(return_value="http://example.com/img.png")
with patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.db", mock_db
), patch(
"application.api.user.agents.sharing.generate_image_url", mock_generate
):
with app.test_request_context("/api/shared_agent?token=tok"):
from flask import request
request.decoded_token = None
response = SharedAgent().get()
assert response.status_code == 200
assert response.json["image"] == "http://example.com/img.png"
mock_generate.assert_called_once_with("path/to/img.png")
# ---------------------------------------------------------------------------
# SharedAgents (GET /shared_agents)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSharedAgents:
def test_returns_401_unauthenticated(self, app):
from application.api.user.agents.sharing import SharedAgents
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = None
response = SharedAgents().get()
assert response.status_code == 401
def test_returns_shared_agents_list(self, app):
from application.api.user.agents.sharing import SharedAgents
agent_id = ObjectId()
mock_ensure = Mock(
return_value={
"user_id": "user1",
"agent_preferences": {
"shared_with_me": [str(agent_id)],
"pinned": [str(agent_id)],
},
}
)
mock_agents_col = Mock()
mock_agents_col.find.return_value = [
{
"_id": agent_id,
"name": "Shared Agent",
"description": "desc",
"tools": [],
"agent_type": "classic",
"status": "published",
"shared_publicly": True,
"shared_token": "tok123",
}
]
mock_resolve = Mock(return_value=[])
mock_users_col = Mock()
with patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.users_collection", mock_users_col
):
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SharedAgents().get()
assert response.status_code == 200
data = response.json
assert len(data) == 1
assert data[0]["name"] == "Shared Agent"
assert data[0]["pinned"] is True
def test_removes_stale_shared_ids(self, app):
from application.api.user.agents.sharing import SharedAgents
stale_id = str(ObjectId())
mock_ensure = Mock(
return_value={
"user_id": "user1",
"agent_preferences": {
"shared_with_me": [stale_id],
"pinned": [],
},
}
)
mock_agents_col = Mock()
mock_agents_col.find.return_value = [] # None found
mock_users_col = Mock()
with patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.users_collection", mock_users_col
):
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SharedAgents().get()
assert response.status_code == 200
mock_users_col.update_one.assert_called_once()
call_args = mock_users_col.update_one.call_args
assert stale_id in call_args[0][1]["$pullAll"][
"agent_preferences.shared_with_me"
]
def test_returns_empty_when_no_shared_ids(self, app):
from application.api.user.agents.sharing import SharedAgents
mock_ensure = Mock(
return_value={
"user_id": "user1",
"agent_preferences": {"shared_with_me": [], "pinned": []},
}
)
mock_agents_col = Mock()
mock_agents_col.find.return_value = []
with patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
):
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SharedAgents().get()
assert response.status_code == 200
assert response.json == []
def test_returns_400_on_exception(self, app):
from application.api.user.agents.sharing import SharedAgents
mock_ensure = Mock(side_effect=Exception("DB error"))
with patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
):
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SharedAgents().get()
assert response.status_code == 400
def test_image_url_generated(self, app):
from application.api.user.agents.sharing import SharedAgents
agent_id = ObjectId()
mock_ensure = Mock(
return_value={
"user_id": "user1",
"agent_preferences": {
"shared_with_me": [str(agent_id)],
"pinned": [],
},
}
)
mock_agents_col = Mock()
mock_agents_col.find.return_value = [
{
"_id": agent_id,
"name": "Agent",
"image": "path.png",
"tools": [],
"shared_publicly": True,
}
]
mock_resolve = Mock(return_value=[])
mock_generate = Mock(return_value="http://example.com/path.png")
with patch(
"application.api.user.agents.sharing.ensure_user_doc", mock_ensure
), patch(
"application.api.user.agents.sharing.agents_collection", mock_agents_col
), patch(
"application.api.user.agents.sharing.resolve_tool_details", mock_resolve
), patch(
"application.api.user.agents.sharing.generate_image_url", mock_generate
), patch(
"application.api.user.agents.sharing.users_collection", Mock()
):
with app.test_request_context("/api/shared_agents"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SharedAgents().get()
assert response.status_code == 200
assert response.json[0]["image"] == "http://example.com/path.png"
# ---------------------------------------------------------------------------
# ShareAgent (PUT /share_agent)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestShareAgent:
def test_returns_401_unauthenticated(self, app):
from application.api.user.agents.sharing import ShareAgent
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={"id": "abc", "shared": True},
):
from flask import request
request.decoded_token = None
response = ShareAgent().put()
assert response.status_code == 401
def test_returns_400_missing_json_body(self, app):
from application.api.user.agents.sharing import ShareAgent
with app.test_request_context(
"/api/share_agent",
method="PUT",
content_type="application/json",
data=b"{}",
):
from flask import request
request.decoded_token = {"sub": "user1"}
# Empty JSON object -> no id, no shared -> 400
response = ShareAgent().put()
assert response.status_code == 400
assert response.json["success"] is False
def test_returns_400_missing_id(self, app):
from application.api.user.agents.sharing import ShareAgent
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={"shared": True},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 400
def test_returns_400_missing_shared_param(self, app):
from application.api.user.agents.sharing import ShareAgent
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={"id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 400
def test_returns_400_invalid_agent_id(self, app):
from application.api.user.agents.sharing import ShareAgent
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={"id": "invalid-oid", "shared": True},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 400
def test_returns_404_agent_not_found(self, app):
from application.api.user.agents.sharing import ShareAgent
mock_col = Mock()
mock_col.find_one.return_value = None
agent_id = str(ObjectId())
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={"id": agent_id, "shared": True},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 404
def test_shares_agent_success(self, app):
from application.api.user.agents.sharing import ShareAgent
agent_id = ObjectId()
mock_col = Mock()
mock_col.find_one.return_value = {
"_id": agent_id,
"user": "user1",
}
mock_col.update_one.return_value = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={
"id": str(agent_id),
"shared": True,
"username": "TestUser",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 200
data = response.json
assert data["success"] is True
assert data["shared_token"] is not None
mock_col.update_one.assert_called_once()
def test_unshares_agent_success(self, app):
from application.api.user.agents.sharing import ShareAgent
agent_id = ObjectId()
mock_col = Mock()
mock_col.find_one.return_value = {
"_id": agent_id,
"user": "user1",
}
mock_col.update_one.return_value = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={
"id": str(agent_id),
"shared": False,
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 200
data = response.json
assert data["success"] is True
assert data["shared_token"] is None
def test_returns_400_on_db_exception(self, app):
from application.api.user.agents.sharing import ShareAgent
agent_id = ObjectId()
mock_col = Mock()
mock_col.find_one.return_value = {
"_id": agent_id,
"user": "user1",
}
mock_col.update_one.side_effect = Exception("DB error")
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={
"id": str(agent_id),
"shared": True,
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 400
def test_share_with_username(self, app):
from application.api.user.agents.sharing import ShareAgent
agent_id = ObjectId()
mock_col = Mock()
mock_col.find_one.return_value = {
"_id": agent_id,
"user": "user1",
}
mock_col.update_one.return_value = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={
"id": str(agent_id),
"shared": True,
"username": "SharedByUser",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 200
# Verify the update call includes shared_metadata with username
update_call = mock_col.update_one.call_args[0][1]["$set"]
assert update_call["shared_metadata"]["shared_by"] == "SharedByUser"
assert update_call["shared_publicly"] is True
assert "shared_token" in update_call
def test_shared_false_explicitly(self, app):
from application.api.user.agents.sharing import ShareAgent
agent_id = ObjectId()
mock_col = Mock()
mock_col.find_one.return_value = {
"_id": agent_id,
"user": "user1",
}
mock_col.update_one.return_value = Mock()
with patch(
"application.api.user.agents.sharing.agents_collection", mock_col
):
with app.test_request_context(
"/api/share_agent",
method="PUT",
json={
"id": str(agent_id),
"shared": False,
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareAgent().put()
assert response.status_code == 200
update_call = mock_col.update_one.call_args[0][1]
assert update_call["$set"]["shared_publicly"] is False
assert update_call["$set"]["shared_token"] is None

View File

@@ -0,0 +1,899 @@
import datetime
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestGetMessageAnalytics:
def test_returns_message_analytics_last_30_days(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = [
{"_id": "2024-06-01", "count": 5},
{"_id": "2024-06-02", "count": 3},
]
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 200
assert response.json["success"] is True
assert "messages" in response.json
def test_returns_401_unauthenticated(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = None
response = GetMessageAnalytics().post()
assert response.status_code == 401
def test_returns_400_invalid_filter_option(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "invalid_option"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 400
def test_filters_by_api_key(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
agent_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {
"_id": agent_id,
"key": "api_key_value",
}
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={
"filter_option": "last_7_days",
"api_key_id": str(agent_id),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 200
pipeline = mock_conversations.aggregate.call_args[0][0]
assert pipeline[0]["$match"].get("api_key") == "api_key_value"
def test_last_hour_filter(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 200
def test_last_24_hour_filter(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_24_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 200
@pytest.mark.unit
class TestGetTokenAnalytics:
def test_returns_token_analytics(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_token_usage = Mock()
mock_token_usage.aggregate.return_value = [
{"_id": {"day": "2024-06-01"}, "total_tokens": 1000}
]
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 200
assert response.json["success"] is True
assert "token_usage" in response.json
def test_returns_400_invalid_filter(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "invalid"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 400
@pytest.mark.unit
class TestGetFeedbackAnalytics:
def test_returns_feedback_analytics(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = [
{"_id": "2024-06-01", "positive": 10, "negative": 2}
]
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 200
assert response.json["success"] is True
assert "feedback" in response.json
def test_returns_400_invalid_filter(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "bad"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 400
@pytest.mark.unit
class TestGetUserLogs:
def test_returns_paginated_logs(self, app):
from application.api.user.analytics.routes import GetUserLogs
log_id = ObjectId()
mock_cursor = Mock()
mock_cursor.sort.return_value.skip.return_value.limit.return_value = [
{
"_id": log_id,
"action": "query",
"level": "info",
"user": "user1",
"question": "test?",
"sources": [],
"retriever_params": {},
"timestamp": datetime.datetime(2024, 6, 1),
}
]
mock_user_logs = Mock()
mock_user_logs.find.return_value = mock_cursor
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.user_logs_collection",
mock_user_logs,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_user_logs",
method="POST",
json={"page": 1, "page_size": 10},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetUserLogs().post()
assert response.status_code == 200
assert response.json["success"] is True
assert response.json["page"] == 1
assert len(response.json["logs"]) == 1
assert response.json["has_more"] is False
def test_detects_has_more(self, app):
from application.api.user.analytics.routes import GetUserLogs
items = [
{"_id": ObjectId(), "action": f"q{i}", "level": "info"}
for i in range(3)
]
mock_cursor = Mock()
mock_cursor.sort.return_value.skip.return_value.limit.return_value = items
mock_user_logs = Mock()
mock_user_logs.find.return_value = mock_cursor
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.user_logs_collection",
mock_user_logs,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_user_logs",
method="POST",
json={"page": 1, "page_size": 2},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetUserLogs().post()
assert response.status_code == 200
assert response.json["has_more"] is True
assert len(response.json["logs"]) == 2
def test_returns_401_unauthenticated(self, app):
from application.api.user.analytics.routes import GetUserLogs
with app.test_request_context(
"/api/get_user_logs",
method="POST",
json={"page": 1},
):
from flask import request
request.decoded_token = None
response = GetUserLogs().post()
assert response.status_code == 401
@pytest.mark.unit
class TestGetTokenAnalyticsAdditional:
"""Additional tests for GetTokenAnalytics covering missing lines."""
def test_returns_401_unauthenticated(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = None
response = GetTokenAnalytics().post()
assert response.status_code == 401
def test_last_hour_filter(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_token_usage = Mock()
mock_token_usage.aggregate.return_value = [
{"_id": {"minute": "2024-06-01 12:00:00"}, "total_tokens": 500}
]
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 200
assert response.json["success"] is True
assert "token_usage" in response.json
def test_last_24_hour_filter(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_token_usage = Mock()
mock_token_usage.aggregate.return_value = [
{"_id": {"hour": "2024-06-01 12:00"}, "total_tokens": 800}
]
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_24_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 200
assert response.json["success"] is True
def test_filters_by_api_key(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
agent_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {
"_id": agent_id,
"key": "token_api_key",
}
mock_token_usage = Mock()
mock_token_usage.aggregate.return_value = []
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={
"filter_option": "last_7_days",
"api_key_id": str(agent_id),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 200
pipeline = mock_token_usage.aggregate.call_args[0][0]
assert pipeline[0]["$match"].get("api_key") == "token_api_key"
def test_api_key_error_returns_400(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_agents = Mock()
mock_agents.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={
"filter_option": "last_30_days",
"api_key_id": str(ObjectId()),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 400
def test_aggregate_error_returns_400(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
mock_token_usage = Mock()
mock_token_usage.aggregate.side_effect = Exception("aggregate error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 400
def test_last_15_days_filter(self, app):
from application.api.user.analytics.routes import GetTokenAnalytics
mock_token_usage = Mock()
mock_token_usage.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.token_usage_collection",
mock_token_usage,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_token_analytics",
method="POST",
json={"filter_option": "last_15_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetTokenAnalytics().post()
assert response.status_code == 200
@pytest.mark.unit
class TestGetFeedbackAnalyticsAdditional:
"""Additional tests for GetFeedbackAnalytics covering missing lines."""
def test_returns_401_unauthenticated(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = None
response = GetFeedbackAnalytics().post()
assert response.status_code == 401
def test_last_hour_filter(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "last_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 200
def test_last_24_hour_filter(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "last_24_hour"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 200
def test_filters_by_api_key(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
agent_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {
"_id": agent_id,
"key": "fb_api_key",
}
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={
"filter_option": "last_7_days",
"api_key_id": str(agent_id),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 200
pipeline = mock_conversations.aggregate.call_args[0][0]
assert pipeline[0]["$match"].get("api_key") == "fb_api_key"
def test_api_key_error_returns_400(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_agents = Mock()
mock_agents.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={
"filter_option": "last_30_days",
"api_key_id": str(ObjectId()),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 400
def test_aggregate_error_returns_400(self, app):
from application.api.user.analytics.routes import GetFeedbackAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
mock_conversations = Mock()
mock_conversations.aggregate.side_effect = Exception("aggregate error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/get_feedback_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetFeedbackAnalytics().post()
assert response.status_code == 400
@pytest.mark.unit
class TestGetMessageAnalyticsAdditional:
"""Additional tests for GetMessageAnalytics covering error paths."""
def test_api_key_error_returns_400(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_agents = Mock()
mock_agents.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={
"filter_option": "last_30_days",
"api_key_id": str(ObjectId()),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 400
def test_aggregate_error_returns_400(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_agents = Mock()
mock_agents.find_one.return_value = None
mock_conversations = Mock()
mock_conversations.aggregate.side_effect = Exception("aggregate error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_30_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 400
def test_last_15_days_filter(self, app):
from application.api.user.analytics.routes import GetMessageAnalytics
mock_conversations = Mock()
mock_conversations.aggregate.return_value = []
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.analytics.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_message_analytics",
method="POST",
json={"filter_option": "last_15_days"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetMessageAnalytics().post()
assert response.status_code == 200
@pytest.mark.unit
class TestGetUserLogsAdditional:
"""Additional tests for GetUserLogs covering api_key filtering and errors."""
def test_filters_by_api_key(self, app):
from application.api.user.analytics.routes import GetUserLogs
agent_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {
"_id": agent_id,
"key": "logs_api_key",
}
mock_cursor = Mock()
mock_cursor.sort.return_value.skip.return_value.limit.return_value = []
mock_user_logs = Mock()
mock_user_logs.find.return_value = mock_cursor
with patch(
"application.api.user.analytics.routes.user_logs_collection",
mock_user_logs,
), patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_user_logs",
method="POST",
json={
"page": 1,
"page_size": 10,
"api_key_id": str(agent_id),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetUserLogs().post()
assert response.status_code == 200
query_arg = mock_user_logs.find.call_args[0][0]
assert query_arg == {"api_key": "logs_api_key"}
def test_api_key_error_returns_400(self, app):
from application.api.user.analytics.routes import GetUserLogs
mock_agents = Mock()
mock_agents.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.analytics.routes.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/get_user_logs",
method="POST",
json={
"page": 1,
"api_key_id": str(ObjectId()),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetUserLogs().post()
assert response.status_code == 400

View File

@@ -0,0 +1,360 @@
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestDeleteConversation:
def test_deletes_conversation(self, app):
from application.api.user.conversations.routes import DeleteConversation
conv_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context(f"/api/delete_conversation?id={conv_id}"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = DeleteConversation().post()
assert response.status_code == 200
assert response.json["success"] is True
mock_collection.delete_one.assert_called_once_with(
{"_id": conv_id, "user": "user1"}
)
def test_returns_401_unauthenticated(self, app):
from application.api.user.conversations.routes import DeleteConversation
with app.test_request_context("/api/delete_conversation?id=abc"):
from flask import request
request.decoded_token = None
response = DeleteConversation().post()
assert response.status_code == 401
def test_returns_400_missing_id(self, app):
from application.api.user.conversations.routes import DeleteConversation
with app.test_request_context("/api/delete_conversation"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = DeleteConversation().post()
assert response.status_code == 400
@pytest.mark.unit
class TestDeleteAllConversations:
def test_deletes_all_for_user(self, app):
from application.api.user.conversations.routes import DeleteAllConversations
mock_collection = Mock()
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context("/api/delete_all_conversations"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = DeleteAllConversations().get()
assert response.status_code == 200
mock_collection.delete_many.assert_called_once_with({"user": "user1"})
def test_returns_401_unauthenticated(self, app):
from application.api.user.conversations.routes import DeleteAllConversations
with app.test_request_context("/api/delete_all_conversations"):
from flask import request
request.decoded_token = None
response = DeleteAllConversations().get()
assert response.status_code == 401
@pytest.mark.unit
class TestGetConversations:
def test_returns_conversations(self, app):
from application.api.user.conversations.routes import GetConversations
conv_id = ObjectId()
mock_cursor = Mock()
mock_cursor.sort.return_value.limit.return_value = [
{
"_id": conv_id,
"name": "Test Chat",
"agent_id": "agent1",
"is_shared_usage": False,
}
]
mock_collection = Mock()
mock_collection.find.return_value = mock_cursor
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context("/api/get_conversations"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetConversations().get()
assert response.status_code == 200
data = response.json
assert len(data) == 1
assert data[0]["id"] == str(conv_id)
assert data[0]["name"] == "Test Chat"
def test_returns_401_unauthenticated(self, app):
from application.api.user.conversations.routes import GetConversations
with app.test_request_context("/api/get_conversations"):
from flask import request
request.decoded_token = None
response = GetConversations().get()
assert response.status_code == 401
@pytest.mark.unit
class TestGetSingleConversation:
def test_returns_conversation(self, app):
from application.api.user.conversations.routes import GetSingleConversation
conv_id = ObjectId()
mock_conv_collection = Mock()
mock_conv_collection.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "hi", "response": "hello"}],
"agent_id": "agent1",
}
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_conv_collection,
):
with app.test_request_context(
f"/api/get_single_conversation?id={conv_id}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSingleConversation().get()
assert response.status_code == 200
assert response.json["queries"] == [{"prompt": "hi", "response": "hello"}]
assert response.json["agent_id"] == "agent1"
def test_returns_404_not_found(self, app):
from application.api.user.conversations.routes import GetSingleConversation
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context(
f"/api/get_single_conversation?id={ObjectId()}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSingleConversation().get()
assert response.status_code == 404
def test_returns_400_missing_id(self, app):
from application.api.user.conversations.routes import GetSingleConversation
with app.test_request_context("/api/get_single_conversation"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSingleConversation().get()
assert response.status_code == 400
def test_resolves_attachments(self, app):
from application.api.user.conversations.routes import GetSingleConversation
conv_id = ObjectId()
att_id = ObjectId()
mock_conv_collection = Mock()
mock_conv_collection.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [
{"prompt": "hi", "response": "hello", "attachments": [str(att_id)]}
],
}
mock_att_collection = Mock()
mock_att_collection.find_one.return_value = {
"_id": att_id,
"filename": "doc.pdf",
}
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_conv_collection,
), patch(
"application.api.user.conversations.routes.attachments_collection",
mock_att_collection,
):
with app.test_request_context(
f"/api/get_single_conversation?id={conv_id}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSingleConversation().get()
assert response.status_code == 200
attachments = response.json["queries"][0]["attachments"]
assert len(attachments) == 1
assert attachments[0]["fileName"] == "doc.pdf"
@pytest.mark.unit
class TestUpdateConversationName:
def test_updates_name(self, app):
from application.api.user.conversations.routes import UpdateConversationName
conv_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context(
"/api/update_conversation_name",
method="POST",
json={"id": str(conv_id), "name": "New Name"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = UpdateConversationName().post()
assert response.status_code == 200
assert response.json["success"] is True
mock_collection.update_one.assert_called_once()
def test_returns_400_missing_fields(self, app):
from application.api.user.conversations.routes import UpdateConversationName
with app.test_request_context(
"/api/update_conversation_name",
method="POST",
json={"id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = UpdateConversationName().post()
assert response.status_code == 400
@pytest.mark.unit
class TestSubmitFeedback:
def test_submits_positive_feedback(self, app):
from application.api.user.conversations.routes import SubmitFeedback
conv_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context(
"/api/feedback",
method="POST",
json={
"feedback": "LIKE",
"conversation_id": str(conv_id),
"question_index": 0,
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SubmitFeedback().post()
assert response.status_code == 200
assert response.json["success"] is True
call_args = mock_collection.update_one.call_args
assert "$set" in call_args[0][1]
def test_removes_feedback_when_null(self, app):
from application.api.user.conversations.routes import SubmitFeedback
conv_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.conversations.routes.conversations_collection",
mock_collection,
):
with app.test_request_context(
"/api/feedback",
method="POST",
json={
"feedback": None,
"conversation_id": str(conv_id),
"question_index": 0,
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SubmitFeedback().post()
assert response.status_code == 200
call_args = mock_collection.update_one.call_args
assert "$unset" in call_args[0][1]
def test_returns_400_missing_fields(self, app):
from application.api.user.conversations.routes import SubmitFeedback
with app.test_request_context(
"/api/feedback",
method="POST",
json={"feedback": "LIKE"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = SubmitFeedback().post()
assert response.status_code == 400

View File

@@ -0,0 +1,622 @@
import datetime
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestAgentFoldersGet:
def test_returns_folders(self, app):
from application.api.user.agents.folders import AgentFolders
now = datetime.datetime(2024, 6, 15, tzinfo=datetime.timezone.utc)
folder_id = ObjectId()
mock_collection = Mock()
mock_collection.find.return_value = [
{
"_id": folder_id,
"name": "My Folder",
"parent_id": None,
"created_at": now,
"updated_at": now,
}
]
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context("/api/agents/folders/", method="GET"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolders().get()
assert response.status_code == 200
folders = response.json["folders"]
assert len(folders) == 1
assert folders[0]["id"] == str(folder_id)
assert folders[0]["name"] == "My Folder"
def test_returns_401_unauthenticated(self, app):
from application.api.user.agents.folders import AgentFolders
with app.test_request_context("/api/agents/folders/", method="GET"):
from flask import request
request.decoded_token = None
response = AgentFolders().get()
assert response.status_code == 401
@pytest.mark.unit
class TestAgentFoldersCreate:
def test_creates_folder(self, app):
from application.api.user.agents.folders import AgentFolders
inserted_id = ObjectId()
mock_collection = Mock()
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context(
"/api/agents/folders/",
method="POST",
json={"name": "New Folder"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolders().post()
assert response.status_code == 201
assert response.json["id"] == str(inserted_id)
assert response.json["name"] == "New Folder"
def test_returns_400_missing_name(self, app):
from application.api.user.agents.folders import AgentFolders
with app.test_request_context(
"/api/agents/folders/",
method="POST",
json={},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolders().post()
assert response.status_code == 400
def test_validates_parent_folder_exists(self, app):
from application.api.user.agents.folders import AgentFolders
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context(
"/api/agents/folders/",
method="POST",
json={"name": "Sub", "parent_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolders().post()
assert response.status_code == 404
@pytest.mark.unit
class TestAgentFolderGet:
def test_returns_folder_with_agents_and_subfolders(self, app):
from application.api.user.agents.folders import AgentFolder
folder_id = ObjectId()
agent_id = ObjectId()
subfolder_id = ObjectId()
mock_folders = Mock()
mock_folders.find_one.return_value = {
"_id": folder_id,
"name": "Folder",
"parent_id": None,
}
mock_folders.find.return_value = [
{"_id": subfolder_id, "name": "Subfolder"}
]
mock_agents = Mock()
mock_agents.find.return_value = [
{"_id": agent_id, "name": "Agent 1", "description": "Desc"}
]
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
), patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
f"/api/agents/folders/{folder_id}", method="GET"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().get(str(folder_id))
assert response.status_code == 200
assert response.json["name"] == "Folder"
assert len(response.json["agents"]) == 1
assert len(response.json["subfolders"]) == 1
def test_returns_404_not_found(self, app):
from application.api.user.agents.folders import AgentFolder
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context(
f"/api/agents/folders/{ObjectId()}", method="GET"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().get(str(ObjectId()))
assert response.status_code == 404
@pytest.mark.unit
class TestAgentFolderUpdate:
def test_updates_folder_name(self, app):
from application.api.user.agents.folders import AgentFolder
folder_id = ObjectId()
mock_collection = Mock()
mock_collection.update_one.return_value = Mock(matched_count=1)
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context(
f"/api/agents/folders/{folder_id}",
method="PUT",
json={"name": "Renamed"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().put(str(folder_id))
assert response.status_code == 200
assert response.json["success"] is True
def test_prevents_self_parent(self, app):
from application.api.user.agents.folders import AgentFolder
folder_id = str(ObjectId())
with app.test_request_context(
f"/api/agents/folders/{folder_id}",
method="PUT",
json={"parent_id": folder_id},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().put(folder_id)
assert response.status_code == 400
assert "own parent" in response.json["message"]
def test_returns_404_when_not_found(self, app):
from application.api.user.agents.folders import AgentFolder
mock_collection = Mock()
mock_collection.update_one.return_value = Mock(matched_count=0)
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_collection,
):
with app.test_request_context(
f"/api/agents/folders/{ObjectId()}",
method="PUT",
json={"name": "X"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().put(str(ObjectId()))
assert response.status_code == 404
@pytest.mark.unit
class TestAgentFolderDelete:
def test_deletes_folder_and_unsets_references(self, app):
from application.api.user.agents.folders import AgentFolder
folder_id = str(ObjectId())
mock_folders = Mock()
mock_folders.delete_one.return_value = Mock(deleted_count=1)
mock_agents = Mock()
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
), patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
f"/api/agents/folders/{folder_id}", method="DELETE"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().delete(folder_id)
assert response.status_code == 200
mock_agents.update_many.assert_called_once()
mock_folders.update_many.assert_called_once()
mock_folders.delete_one.assert_called_once()
def test_returns_404_not_found(self, app):
from application.api.user.agents.folders import AgentFolder
mock_folders = Mock()
mock_folders.delete_one.return_value = Mock(deleted_count=0)
mock_agents = Mock()
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
), patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
f"/api/agents/folders/{ObjectId()}", method="DELETE"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().delete(str(ObjectId()))
assert response.status_code == 404
@pytest.mark.unit
class TestMoveAgentToFolder:
def test_moves_agent_to_folder(self, app):
from application.api.user.agents.folders import MoveAgentToFolder
agent_id = ObjectId()
folder_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {"_id": agent_id, "user": "user1"}
mock_folders = Mock()
mock_folders.find_one.return_value = {"_id": folder_id}
with patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
), patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
):
with app.test_request_context(
"/api/agents/folders/move_agent",
method="POST",
json={
"agent_id": str(agent_id),
"folder_id": str(folder_id),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = MoveAgentToFolder().post()
assert response.status_code == 200
mock_agents.update_one.assert_called_once()
def test_removes_agent_from_folder(self, app):
from application.api.user.agents.folders import MoveAgentToFolder
agent_id = ObjectId()
mock_agents = Mock()
mock_agents.find_one.return_value = {"_id": agent_id, "user": "user1"}
with patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/agents/folders/move_agent",
method="POST",
json={"agent_id": str(agent_id), "folder_id": None},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = MoveAgentToFolder().post()
assert response.status_code == 200
call_args = mock_agents.update_one.call_args
assert "$unset" in call_args[0][1]
def test_returns_404_agent_not_found(self, app):
from application.api.user.agents.folders import MoveAgentToFolder
mock_agents = Mock()
mock_agents.find_one.return_value = None
with patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/agents/folders/move_agent",
method="POST",
json={"agent_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = MoveAgentToFolder().post()
assert response.status_code == 404
def test_returns_400_missing_agent_id(self, app):
from application.api.user.agents.folders import MoveAgentToFolder
with app.test_request_context(
"/api/agents/folders/move_agent",
method="POST",
json={},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = MoveAgentToFolder().post()
assert response.status_code == 400
@pytest.mark.unit
class TestBulkMoveAgents:
def test_bulk_moves_to_folder(self, app):
from application.api.user.agents.folders import BulkMoveAgents
folder_id = ObjectId()
agent_ids = [str(ObjectId()), str(ObjectId())]
mock_agents = Mock()
mock_folders = Mock()
mock_folders.find_one.return_value = {"_id": folder_id}
with patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
), patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
):
with app.test_request_context(
"/api/agents/folders/bulk_move",
method="POST",
json={"agent_ids": agent_ids, "folder_id": str(folder_id)},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = BulkMoveAgents().post()
assert response.status_code == 200
mock_agents.update_many.assert_called_once()
def test_bulk_removes_from_folders(self, app):
from application.api.user.agents.folders import BulkMoveAgents
agent_ids = [str(ObjectId())]
mock_agents = Mock()
with patch(
"application.api.user.agents.folders.agents_collection",
mock_agents,
):
with app.test_request_context(
"/api/agents/folders/bulk_move",
method="POST",
json={"agent_ids": agent_ids},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = BulkMoveAgents().post()
assert response.status_code == 200
call_args = mock_agents.update_many.call_args
assert "$unset" in call_args[0][1]
def test_returns_400_missing_agent_ids(self, app):
from application.api.user.agents.folders import BulkMoveAgents
with app.test_request_context(
"/api/agents/folders/bulk_move",
method="POST",
json={},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = BulkMoveAgents().post()
assert response.status_code == 400
def test_returns_404_folder_not_found(self, app):
from application.api.user.agents.folders import BulkMoveAgents
mock_folders = Mock()
mock_folders.find_one.return_value = None
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
):
with app.test_request_context(
"/api/agents/folders/bulk_move",
method="POST",
json={
"agent_ids": [str(ObjectId())],
"folder_id": str(ObjectId()),
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = BulkMoveAgents().post()
assert response.status_code == 404
# =====================================================================
# Coverage gap tests (lines 64, 90-91, 100, 125-126, 132, 136)
# =====================================================================
@pytest.mark.unit
class TestAgentFoldersGaps:
def test_create_folder_no_auth(self, app):
"""Cover line 64: post returns 401 when no decoded_token."""
from application.api.user.agents.folders import AgentFolders
with app.test_request_context(
"/api/agents/folders/",
method="POST",
json={"name": "Test"},
):
from flask import request
request.decoded_token = None
response = AgentFolders().post()
assert response.status_code == 401
def test_create_folder_exception(self, app):
"""Cover lines 90-91: exception during insert_one returns 400."""
from application.api.user.agents.folders import AgentFolders
mock_folders = Mock()
mock_folders.find_one.return_value = None
mock_folders.insert_one.side_effect = Exception("db error")
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
):
with app.test_request_context(
"/api/agents/folders/",
method="POST",
json={"name": "Test"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolders().post()
assert response.status_code == 400
def test_get_folder_no_auth(self, app):
"""Cover line 100: get specific folder returns 401 when no auth."""
from application.api.user.agents.folders import AgentFolder
with app.test_request_context(
"/api/agents/folders/abc",
method="GET",
):
from flask import request
request.decoded_token = None
response = AgentFolder().get("abc")
assert response.status_code == 401
def test_get_folder_exception(self, app):
"""Cover lines 125-126: exception during find returns 400."""
from application.api.user.agents.folders import AgentFolder
mock_folders = Mock()
mock_folders.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.agents.folders.agent_folders_collection",
mock_folders,
):
with app.test_request_context(
"/api/agents/folders/" + str(ObjectId()),
method="GET",
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().get(str(ObjectId()))
assert response.status_code == 400
def test_update_folder_no_auth(self, app):
"""Cover line 132: put returns 401 when no decoded_token."""
from application.api.user.agents.folders import AgentFolder
with app.test_request_context(
"/api/agents/folders/abc",
method="PUT",
json={"name": "Updated"},
):
from flask import request
request.decoded_token = None
response = AgentFolder().put("abc")
assert response.status_code == 401
def test_update_folder_no_data(self, app):
"""Cover line 136: put with no data returns 400."""
from application.api.user.agents.folders import AgentFolder
with app.test_request_context(
"/api/agents/folders/abc",
method="PUT",
content_type="application/json",
data="null",
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentFolder().put("abc")
assert response.status_code == 400

View File

@@ -0,0 +1,70 @@
from unittest.mock import Mock, patch
import pytest
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestModelsListResource:
def test_returns_models(self, app):
from application.api.user.models.routes import ModelsListResource
mock_model = Mock()
mock_model.to_dict.return_value = {
"id": "gpt-4",
"name": "GPT-4",
"provider": "openai",
}
mock_registry = Mock()
mock_registry.get_enabled_models.return_value = [mock_model]
mock_registry.default_model_id = "gpt-4"
with patch(
"application.api.user.models.routes.ModelRegistry.get_instance",
return_value=mock_registry,
):
with app.test_request_context("/api/models"):
response = ModelsListResource().get()
assert response.status_code == 200
assert response.json["count"] == 1
assert response.json["default_model_id"] == "gpt-4"
assert response.json["models"][0]["id"] == "gpt-4"
def test_returns_empty_models(self, app):
from application.api.user.models.routes import ModelsListResource
mock_registry = Mock()
mock_registry.get_enabled_models.return_value = []
mock_registry.default_model_id = None
with patch(
"application.api.user.models.routes.ModelRegistry.get_instance",
return_value=mock_registry,
):
with app.test_request_context("/api/models"):
response = ModelsListResource().get()
assert response.status_code == 200
assert response.json["count"] == 0
assert response.json["models"] == []
def test_returns_500_on_error(self, app):
from application.api.user.models.routes import ModelsListResource
with patch(
"application.api.user.models.routes.ModelRegistry.get_instance",
side_effect=Exception("Registry error"),
):
with app.test_request_context("/api/models"):
response = ModelsListResource().get()
assert response.status_code == 500

View File

@@ -0,0 +1,288 @@
from unittest.mock import Mock, mock_open, patch
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestCreatePrompt:
def test_creates_prompt(self, app):
from application.api.user.prompts.routes import CreatePrompt
mock_collection = Mock()
inserted_id = ObjectId()
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
):
with app.test_request_context(
"/api/create_prompt",
method="POST",
json={"name": "My Prompt", "content": "You are helpful."},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = CreatePrompt().post()
assert response.status_code == 200
assert response.json["id"] == str(inserted_id)
mock_collection.insert_one.assert_called_once()
doc = mock_collection.insert_one.call_args[0][0]
assert doc["name"] == "My Prompt"
assert doc["user"] == "user1"
def test_returns_401_unauthenticated(self, app):
from application.api.user.prompts.routes import CreatePrompt
with app.test_request_context(
"/api/create_prompt",
method="POST",
json={"name": "P", "content": "C"},
):
from flask import request
request.decoded_token = None
response = CreatePrompt().post()
assert response.status_code == 401
def test_returns_400_missing_fields(self, app):
from application.api.user.prompts.routes import CreatePrompt
with app.test_request_context(
"/api/create_prompt",
method="POST",
json={"name": "P"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = CreatePrompt().post()
assert response.status_code == 400
@pytest.mark.unit
class TestGetPrompts:
def test_returns_prompts_with_defaults(self, app):
from application.api.user.prompts.routes import GetPrompts
user_prompt_id = ObjectId()
mock_collection = Mock()
mock_collection.find.return_value = [
{"_id": user_prompt_id, "name": "Custom Prompt"}
]
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
):
with app.test_request_context("/api/get_prompts"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetPrompts().get()
assert response.status_code == 200
data = response.json
public_names = [p["name"] for p in data if p["type"] == "public"]
assert "default" in public_names
assert "creative" in public_names
assert "strict" in public_names
private = [p for p in data if p["type"] == "private"]
assert len(private) == 1
assert private[0]["name"] == "Custom Prompt"
def test_returns_401_unauthenticated(self, app):
from application.api.user.prompts.routes import GetPrompts
with app.test_request_context("/api/get_prompts"):
from flask import request
request.decoded_token = None
response = GetPrompts().get()
assert response.status_code == 401
@pytest.mark.unit
class TestGetSinglePrompt:
def test_returns_default_prompt(self, app):
from application.api.user.prompts.routes import GetSinglePrompt
with patch("builtins.open", mock_open(read_data="Default prompt content")):
with app.test_request_context("/api/get_single_prompt?id=default"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSinglePrompt().get()
assert response.status_code == 200
assert response.json["content"] == "Default prompt content"
def test_returns_creative_prompt(self, app):
from application.api.user.prompts.routes import GetSinglePrompt
with patch("builtins.open", mock_open(read_data="Creative content")):
with app.test_request_context("/api/get_single_prompt?id=creative"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSinglePrompt().get()
assert response.status_code == 200
assert response.json["content"] == "Creative content"
def test_returns_strict_prompt(self, app):
from application.api.user.prompts.routes import GetSinglePrompt
with patch("builtins.open", mock_open(read_data="Strict content")):
with app.test_request_context("/api/get_single_prompt?id=strict"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSinglePrompt().get()
assert response.status_code == 200
assert response.json["content"] == "Strict content"
def test_returns_custom_prompt(self, app):
from application.api.user.prompts.routes import GetSinglePrompt
prompt_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": prompt_id,
"content": "Custom content",
}
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
):
with app.test_request_context(
f"/api/get_single_prompt?id={prompt_id}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSinglePrompt().get()
assert response.status_code == 200
assert response.json["content"] == "Custom content"
def test_returns_400_missing_id(self, app):
from application.api.user.prompts.routes import GetSinglePrompt
with app.test_request_context("/api/get_single_prompt"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = GetSinglePrompt().get()
assert response.status_code == 400
@pytest.mark.unit
class TestDeletePrompt:
def test_deletes_prompt(self, app):
from application.api.user.prompts.routes import DeletePrompt
prompt_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
):
with app.test_request_context(
"/api/delete_prompt",
method="POST",
json={"id": str(prompt_id)},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = DeletePrompt().post()
assert response.status_code == 200
assert response.json["success"] is True
mock_collection.delete_one.assert_called_once_with(
{"_id": prompt_id, "user": "user1"}
)
def test_returns_400_missing_id(self, app):
from application.api.user.prompts.routes import DeletePrompt
with app.test_request_context(
"/api/delete_prompt",
method="POST",
json={},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = DeletePrompt().post()
assert response.status_code == 400
@pytest.mark.unit
class TestUpdatePrompt:
def test_updates_prompt(self, app):
from application.api.user.prompts.routes import UpdatePrompt
prompt_id = ObjectId()
mock_collection = Mock()
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
):
with app.test_request_context(
"/api/update_prompt",
method="POST",
json={
"id": str(prompt_id),
"name": "Updated",
"content": "New content",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = UpdatePrompt().post()
assert response.status_code == 200
assert response.json["success"] is True
mock_collection.update_one.assert_called_once()
def test_returns_400_missing_fields(self, app):
from application.api.user.prompts.routes import UpdatePrompt
with app.test_request_context(
"/api/update_prompt",
method="POST",
json={"id": str(ObjectId()), "name": "Updated"},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = UpdatePrompt().post()
assert response.status_code == 400

View File

@@ -0,0 +1,803 @@
import uuid
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from bson.binary import Binary, UuidRepresentation
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestShareConversation:
def test_shares_non_promptable_conversation(self, app):
from application.api.user.sharing.routes import ShareConversation
conv_id = ObjectId()
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Test Chat",
"queries": [{"prompt": "hi"}],
}
mock_shared = Mock()
mock_shared.find_one.return_value = None
mock_shared.insert_one.return_value = Mock()
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share?isPromptable=false",
method="POST",
json={"conversation_id": str(conv_id)},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 201
assert response.json["success"] is True
assert "identifier" in response.json
mock_shared.insert_one.assert_called_once()
def test_returns_existing_shared_link(self, app):
from application.api.user.sharing.routes import ShareConversation
conv_id = ObjectId()
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Test Chat",
"queries": [{"prompt": "hi"}],
}
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": conv_id,
}
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share?isPromptable=false",
method="POST",
json={"conversation_id": str(conv_id)},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 200
assert response.json["identifier"] == str(test_uuid)
def test_returns_401_unauthenticated(self, app):
from application.api.user.sharing.routes import ShareConversation
with app.test_request_context(
"/api/share?isPromptable=false",
method="POST",
json={"conversation_id": str(ObjectId())},
):
from flask import request
request.decoded_token = None
response = ShareConversation().post()
assert response.status_code == 401
def test_returns_400_missing_conversation_id(self, app):
from application.api.user.sharing.routes import ShareConversation
with app.test_request_context(
"/api/share?isPromptable=false",
method="POST",
json={},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 400
def test_returns_400_missing_isPromptable(self, app):
from application.api.user.sharing.routes import ShareConversation
with app.test_request_context(
"/api/share",
method="POST",
json={"conversation_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 400
assert "isPromptable" in response.json["message"]
def test_returns_404_conversation_not_found(self, app):
from application.api.user.sharing.routes import ShareConversation
mock_conversations = Mock()
mock_conversations.find_one.return_value = None
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/share?isPromptable=false",
method="POST",
json={"conversation_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 404
@pytest.mark.unit
class TestGetPubliclySharedConversations:
def test_returns_shared_conversation(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": conv_id,
"first_n_queries": 2,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Shared Chat",
"queries": [
{"prompt": "q1", "response": "a1"},
{"prompt": "q2", "response": "a2"},
{"prompt": "q3", "response": "a3"},
],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
assert response.json["success"] is True
assert response.json["title"] == "Shared Chat"
assert len(response.json["queries"]) == 2
def test_returns_404_not_found(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
mock_shared = Mock()
mock_shared.find_one.return_value = None
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 404
def test_returns_404_conversation_deleted(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": conv_id,
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = None
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 404
def test_includes_api_key_when_promptable(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": conv_id,
"first_n_queries": 1,
"isPromptable": True,
"api_key": "shared_api_key",
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
assert response.json["api_key"] == "shared_api_key"
def test_handles_dbref_conversation_id(self, app):
from bson.dbref import DBRef
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": DBRef("conversations", conv_id),
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
mock_conversations.find_one.assert_called_once_with({"_id": conv_id})
def test_handles_dict_oid_conversation_id(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": {"$id": {"$oid": str(conv_id)}},
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
def test_handles_dict_id_string_conversation_id(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": {"$id": str(conv_id)},
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
def test_handles_dict_underscore_id_conversation_id(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": {"_id": str(conv_id)},
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
def test_handles_string_conversation_id(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": str(conv_id),
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [{"prompt": "q1", "response": "a1"}],
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
def test_resolves_attachments_in_shared(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
conv_id = ObjectId()
att_id = ObjectId()
mock_shared = Mock()
mock_shared.find_one.return_value = {
"uuid": binary_uuid,
"conversation_id": conv_id,
"first_n_queries": 1,
"isPromptable": False,
}
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Chat",
"queries": [
{"prompt": "q1", "response": "a1", "attachments": [str(att_id)]}
],
}
mock_attachments = Mock()
mock_attachments.find_one.return_value = {
"_id": att_id,
"filename": "file.pdf",
}
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
), patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.attachments_collection",
mock_attachments,
):
with app.test_request_context(
f"/api/shared_conversation/{test_uuid}"
):
response = GetPubliclySharedConversations().get(str(test_uuid))
assert response.status_code == 200
assert response.json["queries"][0]["attachments"][0]["fileName"] == "file.pdf"
def test_handles_general_exception(self, app):
from application.api.user.sharing.routes import (
GetPubliclySharedConversations,
)
mock_shared = Mock()
mock_shared.find_one.side_effect = Exception("DB error")
with patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
f"/api/shared_conversation/{uuid.uuid4()}"
):
response = GetPubliclySharedConversations().get(str(uuid.uuid4()))
assert response.status_code == 400
@pytest.mark.unit
class TestShareConversationPromptable:
def test_promptable_with_existing_api_key_and_existing_share(self, app):
from application.api.user.sharing.routes import ShareConversation
conv_id = ObjectId()
test_uuid = uuid.uuid4()
binary_uuid = Binary.from_uuid(test_uuid, UuidRepresentation.STANDARD)
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Test Chat",
"queries": [{"prompt": "hi"}],
}
mock_agents = Mock()
mock_agents.find_one.return_value = {"key": "existing_api_uuid"}
mock_shared = Mock()
mock_shared.find_one.return_value = {"uuid": binary_uuid}
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share?isPromptable=true",
method="POST",
json={
"conversation_id": str(conv_id),
"prompt_id": "default",
"chunks": "3",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 200
assert response.json["identifier"] == str(test_uuid)
def test_promptable_with_existing_api_key_new_share(self, app):
from application.api.user.sharing.routes import ShareConversation
conv_id = ObjectId()
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Test Chat",
"queries": [{"prompt": "hi"}],
}
mock_agents = Mock()
mock_agents.find_one.return_value = {"key": "existing_api_uuid"}
mock_shared = Mock()
mock_shared.find_one.return_value = None
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share?isPromptable=true",
method="POST",
json={
"conversation_id": str(conv_id),
"source": str(ObjectId()),
"retriever": "classic",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 201
mock_shared.insert_one.assert_called_once()
def test_promptable_creates_new_api_key(self, app):
from application.api.user.sharing.routes import ShareConversation
conv_id = ObjectId()
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": conv_id,
"name": "Test Chat",
"queries": [{"prompt": "hi"}],
}
mock_agents = Mock()
mock_agents.find_one.return_value = None
mock_shared = Mock()
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.agents_collection",
mock_agents,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share?isPromptable=true",
method="POST",
json={
"conversation_id": str(conv_id),
"source": str(ObjectId()),
"retriever": "classic",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 201
mock_agents.insert_one.assert_called_once()
mock_shared.insert_one.assert_called_once()
# =====================================================================
# Coverage gap tests (lines 201-205)
# =====================================================================
@pytest.mark.unit
class TestShareConversationExceptionGap:
def test_share_conversation_exception_returns_400(self):
"""Cover lines 201-205: exception during sharing returns 400."""
from application.api.user.sharing.routes import ShareConversation
from unittest.mock import Mock, patch
app = Flask(__name__)
mock_conversations = Mock()
mock_conversations.find_one.side_effect = Exception("db error")
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/share",
method="POST",
json={
"conversation_id": str(ObjectId()),
"source": str(ObjectId()),
"retriever": "classic",
},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 400
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 201-205
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestShareConversationErrorPath:
def test_share_conversation_exception_returns_400(self, app):
"""Cover lines 201-205: exception during sharing returns 400."""
from application.api.user.sharing.routes import ShareConversation
mock_conversations = Mock()
mock_conversations.find_one.side_effect = Exception("DB error")
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
):
with app.test_request_context(
"/api/share",
method="POST",
json={"conversation_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 400
# ---------------------------------------------------------------------------
# Additional coverage for sharing/routes.py
# Lines: 201-205: exception in try block (different entry point)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestShareConversationInsertException:
"""Cover lines 201-205: exception during insert_one."""
def test_insert_one_exception_returns_400(self, app):
from application.api.user.sharing.routes import ShareConversation
mock_conversations = Mock()
mock_conversations.find_one.return_value = {
"_id": ObjectId(),
"user": "user1",
"queries": [],
}
mock_shared = Mock()
mock_shared.find_one.return_value = None
mock_shared.insert_one.side_effect = Exception("Insert failed")
with patch(
"application.api.user.sharing.routes.conversations_collection",
mock_conversations,
), patch(
"application.api.user.sharing.routes.shared_conversations_collections",
mock_shared,
):
with app.test_request_context(
"/api/share",
method="POST",
json={"conversation_id": str(ObjectId())},
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = ShareConversation().post()
assert response.status_code == 400

View File

@@ -0,0 +1,240 @@
from datetime import timedelta
from unittest.mock import ANY, MagicMock, patch
import pytest
class TestIngestTask:
@pytest.mark.unit
@patch("application.api.user.tasks.ingest_worker")
def test_calls_ingest_worker(self, mock_worker):
from application.api.user.tasks import ingest
mock_worker.return_value = {"status": "ok"}
result = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
mock_worker.assert_called_once_with(
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
file_name_map=None,
)
assert result == {"status": "ok"}
@pytest.mark.unit
@patch("application.api.user.tasks.ingest_worker")
def test_passes_file_name_map(self, mock_worker):
from application.api.user.tasks import ingest
mock_worker.return_value = {"status": "ok"}
name_map = {"a.pdf": "b.pdf"}
ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
file_name_map=name_map)
mock_worker.assert_called_once_with(
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
file_name_map=name_map,
)
class TestIngestRemoteTask:
@pytest.mark.unit
@patch("application.api.user.tasks.remote_worker")
def test_calls_remote_worker(self, mock_worker):
from application.api.user.tasks import ingest_remote
mock_worker.return_value = {"status": "ok"}
result = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
mock_worker.assert_called_once_with(
ANY, {"url": "http://x"}, "job1", "user1", "web"
)
assert result == {"status": "ok"}
class TestReingestSourceTask:
@pytest.mark.unit
@patch("application.worker.reingest_source_worker")
def test_calls_reingest_worker(self, mock_worker):
from application.api.user.tasks import reingest_source_task
mock_worker.return_value = {"status": "ok"}
result = reingest_source_task("source123", "user1")
mock_worker.assert_called_once_with(ANY, "source123", "user1")
assert result == {"status": "ok"}
class TestScheduleSyncsTask:
@pytest.mark.unit
@patch("application.api.user.tasks.sync_worker")
def test_calls_sync_worker(self, mock_worker):
from application.api.user.tasks import schedule_syncs
mock_worker.return_value = {"status": "ok"}
result = schedule_syncs("daily")
mock_worker.assert_called_once_with(ANY, "daily")
assert result == {"status": "ok"}
class TestSyncSourceTask:
@pytest.mark.unit
@patch("application.api.user.tasks.sync")
def test_calls_sync(self, mock_sync):
from application.api.user.tasks import sync_source
mock_sync.return_value = {"status": "ok"}
result = sync_source(
{"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
)
mock_sync.assert_called_once_with(
ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
)
assert result == {"status": "ok"}
class TestStoreAttachmentTask:
@pytest.mark.unit
@patch("application.api.user.tasks.attachment_worker")
def test_calls_attachment_worker(self, mock_worker):
from application.api.user.tasks import store_attachment
mock_worker.return_value = {"status": "ok"}
result = store_attachment({"file": "info"}, "user1")
mock_worker.assert_called_once_with(ANY, {"file": "info"}, "user1")
assert result == {"status": "ok"}
class TestProcessAgentWebhookTask:
@pytest.mark.unit
@patch("application.api.user.tasks.agent_webhook_worker")
def test_calls_agent_webhook_worker(self, mock_worker):
from application.api.user.tasks import process_agent_webhook
mock_worker.return_value = {"status": "ok"}
result = process_agent_webhook("agent123", {"event": "test"})
mock_worker.assert_called_once_with(ANY, "agent123", {"event": "test"})
assert result == {"status": "ok"}
class TestIngestConnectorTask:
@pytest.mark.unit
@patch("application.worker.ingest_connector")
def test_calls_ingest_connector_defaults(self, mock_worker):
from application.api.user.tasks import ingest_connector_task
mock_worker.return_value = {"status": "ok"}
result = ingest_connector_task("job1", "user1", "gdrive")
mock_worker.assert_called_once_with(
ANY,
"job1",
"user1",
"gdrive",
session_token=None,
file_ids=None,
folder_ids=None,
recursive=True,
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never",
)
assert result == {"status": "ok"}
@pytest.mark.unit
@patch("application.worker.ingest_connector")
def test_calls_ingest_connector_custom(self, mock_worker):
from application.api.user.tasks import ingest_connector_task
mock_worker.return_value = {"status": "ok"}
result = ingest_connector_task(
"job1",
"user1",
"sharepoint",
session_token="tok",
file_ids=["f1"],
folder_ids=["d1"],
recursive=False,
retriever="duckdb",
operation_mode="sync",
doc_id="doc1",
sync_frequency="daily",
)
mock_worker.assert_called_once_with(
ANY,
"job1",
"user1",
"sharepoint",
session_token="tok",
file_ids=["f1"],
folder_ids=["d1"],
recursive=False,
retriever="duckdb",
operation_mode="sync",
doc_id="doc1",
sync_frequency="daily",
)
assert result == {"status": "ok"}
class TestSetupPeriodicTasks:
@pytest.mark.unit
def test_registers_periodic_tasks(self):
from application.api.user.tasks import setup_periodic_tasks
sender = MagicMock()
setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 3
calls = sender.add_periodic_task.call_args_list
# daily
assert calls[0][0][0] == timedelta(days=1)
# weekly
assert calls[1][0][0] == timedelta(weeks=1)
# monthly
assert calls[2][0][0] == timedelta(days=30)
class TestMcpOauthTask:
@pytest.mark.unit
@patch("application.api.user.tasks.mcp_oauth")
def test_calls_mcp_oauth(self, mock_worker):
from application.api.user.tasks import mcp_oauth_task
mock_worker.return_value = {"url": "http://auth"}
result = mcp_oauth_task({"server": "mcp"}, "user1")
mock_worker.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
assert result == {"url": "http://auth"}
class TestMcpOauthStatusTask:
@pytest.mark.unit
@patch("application.api.user.tasks.mcp_oauth_status")
def test_calls_mcp_oauth_status(self, mock_worker):
from application.api.user.tasks import mcp_oauth_status_task
mock_worker.return_value = {"status": "authorized"}
result = mcp_oauth_status_task("task123")
mock_worker.assert_called_once_with(ANY, "task123")
assert result == {"status": "authorized"}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,411 @@
from unittest.mock import Mock
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestGetUserId:
def test_returns_user_id_from_decoded_token(self, app):
from application.api.user.utils import get_user_id
with app.test_request_context():
from flask import request
request.decoded_token = {"sub": "user_123"}
assert get_user_id() == "user_123"
def test_returns_none_when_no_decoded_token(self, app):
from application.api.user.utils import get_user_id
with app.test_request_context():
assert get_user_id() is None
def test_returns_none_when_decoded_token_has_no_sub(self, app):
from application.api.user.utils import get_user_id
with app.test_request_context():
from flask import request
request.decoded_token = {}
assert get_user_id() is None
@pytest.mark.unit
class TestRequireAuth:
def test_allows_authenticated_request(self, app):
from application.api.user.utils import require_auth
@require_auth
def protected():
return "ok"
with app.test_request_context():
from flask import request
request.decoded_token = {"sub": "user_123"}
assert protected() == "ok"
def test_returns_401_when_unauthenticated(self, app):
from application.api.user.utils import require_auth
@require_auth
def protected():
return "ok"
with app.test_request_context():
result = protected()
assert result.status_code == 401
@pytest.mark.unit
class TestSuccessResponse:
def test_default_success_response(self, app):
from application.api.user.utils import success_response
with app.app_context():
resp = success_response()
assert resp.status_code == 200
assert resp.json["success"] is True
def test_success_response_with_data(self, app):
from application.api.user.utils import success_response
with app.app_context():
resp = success_response({"items": [1, 2], "total": 2})
assert resp.status_code == 200
assert resp.json["success"] is True
assert resp.json["items"] == [1, 2]
assert resp.json["total"] == 2
def test_success_response_custom_status(self, app):
from application.api.user.utils import success_response
with app.app_context():
resp = success_response({"id": "new"}, 201)
assert resp.status_code == 201
@pytest.mark.unit
class TestErrorResponse:
def test_default_error_response(self, app):
from application.api.user.utils import error_response
with app.app_context():
resp = error_response("Something went wrong")
assert resp.status_code == 400
assert resp.json["success"] is False
assert resp.json["message"] == "Something went wrong"
def test_error_response_custom_status(self, app):
from application.api.user.utils import error_response
with app.app_context():
resp = error_response("Not found", 404)
assert resp.status_code == 404
def test_error_response_extra_kwargs(self, app):
from application.api.user.utils import error_response
with app.app_context():
resp = error_response("Bad", 400, errors=["field1", "field2"])
assert resp.json["errors"] == ["field1", "field2"]
@pytest.mark.unit
class TestValidateObjectId:
def test_valid_object_id(self, app):
from application.api.user.utils import validate_object_id
with app.app_context():
oid = ObjectId()
result, error = validate_object_id(str(oid))
assert result == oid
assert error is None
def test_invalid_object_id(self, app):
from application.api.user.utils import validate_object_id
with app.app_context():
result, error = validate_object_id("not-a-valid-id")
assert result is None
assert error.status_code == 400
assert "Invalid" in error.json["message"]
def test_custom_resource_name(self, app):
from application.api.user.utils import validate_object_id
with app.app_context():
_, error = validate_object_id("bad", "Workflow")
assert "Workflow" in error.json["message"]
@pytest.mark.unit
class TestValidatePagination:
def test_default_pagination(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/?limit=10&skip=0"):
limit, skip, error = validate_pagination()
assert limit == 10
assert skip == 0
assert error is None
def test_uses_defaults_when_no_params(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/"):
limit, skip, error = validate_pagination()
assert limit == 20
assert skip == 0
assert error is None
def test_enforces_max_limit(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/?limit=500"):
limit, _, _ = validate_pagination(max_limit=100)
assert limit == 100
def test_invalid_limit(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/?limit=-1"):
_, _, error = validate_pagination()
assert error is not None
assert error.status_code == 400
def test_invalid_skip(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/?skip=-1"):
_, _, error = validate_pagination()
assert error is not None
def test_non_numeric_values(self, app):
from application.api.user.utils import validate_pagination
with app.test_request_context("/?limit=abc"):
_, _, error = validate_pagination()
assert error is not None
@pytest.mark.unit
class TestCheckResourceOwnership:
def test_returns_resource_when_owned(self, app):
from application.api.user.utils import check_resource_ownership
with app.app_context():
collection = Mock()
oid = ObjectId()
doc = {"_id": oid, "user": "user1", "name": "test"}
collection.find_one.return_value = doc
resource, error = check_resource_ownership(collection, oid, "user1")
assert resource == doc
assert error is None
def test_returns_404_when_not_found(self, app):
from application.api.user.utils import check_resource_ownership
with app.app_context():
collection = Mock()
collection.find_one.return_value = None
resource, error = check_resource_ownership(
collection, ObjectId(), "user1", "Workflow"
)
assert resource is None
assert error.status_code == 404
assert "Workflow" in error.json["message"]
@pytest.mark.unit
class TestSerializeObjectId:
def test_converts_id_to_string(self):
from application.api.user.utils import serialize_object_id
oid = ObjectId()
obj = {"_id": oid, "name": "test"}
result = serialize_object_id(obj)
assert result["id"] == str(oid)
assert "_id" not in result
def test_custom_field_names(self):
from application.api.user.utils import serialize_object_id
oid = ObjectId()
obj = {"custom_id": oid}
result = serialize_object_id(obj, id_field="custom_id", new_field="uid")
assert result["uid"] == str(oid)
assert "custom_id" not in result
def test_no_id_field_present(self):
from application.api.user.utils import serialize_object_id
obj = {"name": "test"}
result = serialize_object_id(obj)
assert "id" not in result
@pytest.mark.unit
class TestSerializeList:
def test_applies_serializer_to_all_items(self):
from application.api.user.utils import serialize_list
items = [{"_id": ObjectId()}, {"_id": ObjectId()}]
def serializer(item):
return {"id": str(item["_id"])}
result = serialize_list(items, serializer)
assert len(result) == 2
assert all("id" in r for r in result)
def test_empty_list(self):
from application.api.user.utils import serialize_list
assert serialize_list([], lambda x: x) == []
@pytest.mark.unit
class TestRequireFields:
def test_allows_valid_request(self, app):
from application.api.user.utils import require_fields
@require_fields(["name", "email"])
def handler():
return "ok"
with app.test_request_context(
"/", method="POST", json={"name": "Alice", "email": "a@b.com"}
):
assert handler() == "ok"
def test_rejects_missing_fields(self, app):
from application.api.user.utils import require_fields
@require_fields(["name", "email"])
def handler():
return "ok"
with app.test_request_context("/", method="POST", json={"name": "Alice"}):
result = handler()
assert result.status_code == 400
assert "email" in result.json["message"]
def test_rejects_empty_body(self, app):
from application.api.user.utils import require_fields
@require_fields(["name"])
def handler():
return "ok"
with app.test_request_context(
"/", method="POST", json={}
):
result = handler()
assert result.status_code == 400
@pytest.mark.unit
class TestSafeDbOperation:
def test_returns_result_on_success(self, app):
from application.api.user.utils import safe_db_operation
with app.app_context():
result, error = safe_db_operation(lambda: {"inserted": True})
assert result == {"inserted": True}
assert error is None
def test_returns_error_on_exception(self, app):
from application.api.user.utils import safe_db_operation
with app.app_context():
result, error = safe_db_operation(
lambda: (_ for _ in ()).throw(RuntimeError("db error")),
"Operation failed",
)
assert result is None
assert error.status_code == 400
assert error.json["message"] == "Operation failed"
def test_hides_exception_details(self, app):
from application.api.user.utils import safe_db_operation
with app.app_context():
_, error = safe_db_operation(
lambda: (_ for _ in ()).throw(RuntimeError("secret credentials")),
"Failed",
)
assert "credentials" not in error.json["message"]
@pytest.mark.unit
class TestValidateEnum:
def test_valid_value(self, app):
from application.api.user.utils import validate_enum
with app.app_context():
assert validate_enum("draft", ["draft", "published"], "status") is None
def test_invalid_value(self, app):
from application.api.user.utils import validate_enum
with app.app_context():
error = validate_enum("unknown", ["draft", "published"], "status")
assert error.status_code == 400
assert "status" in error.json["message"]
@pytest.mark.unit
class TestExtractSortParams:
def test_defaults(self, app):
from application.api.user.utils import extract_sort_params
with app.test_request_context("/"):
field, order = extract_sort_params()
assert field == "created_at"
assert order == -1
def test_custom_params(self, app):
from application.api.user.utils import extract_sort_params
with app.test_request_context("/?sort=name&order=asc"):
field, order = extract_sort_params()
assert field == "name"
assert order == 1
def test_enforces_allowed_fields(self, app):
from application.api.user.utils import extract_sort_params
with app.test_request_context("/?sort=forbidden_field"):
field, _ = extract_sort_params(allowed_fields=["name", "date"])
assert field == "created_at"
def test_desc_order(self, app):
from application.api.user.utils import extract_sort_params
with app.test_request_context("/?order=desc"):
_, order = extract_sort_params()
assert order == -1

View File

@@ -0,0 +1,225 @@
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from flask import Flask
@pytest.fixture
def app():
app = Flask(__name__)
return app
@pytest.mark.unit
class TestAgentWebhook:
def test_returns_existing_webhook_url(self, app):
from application.api.user.agents.webhooks import AgentWebhook
agent_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": agent_id,
"user": "user1",
"incoming_webhook_token": "existing_token",
}
with patch(
"application.api.user.agents.webhooks.agents_collection",
mock_collection,
), patch(
"application.api.user.agents.webhooks.settings",
Mock(API_URL="https://api.example.com"),
):
with app.test_request_context(
f"/api/agent_webhook?id={agent_id}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentWebhook().get()
assert response.status_code == 200
assert response.json["success"] is True
assert "existing_token" in response.json["webhook_url"]
mock_collection.update_one.assert_not_called()
def test_generates_new_webhook_token(self, app):
from application.api.user.agents.webhooks import AgentWebhook
agent_id = ObjectId()
mock_collection = Mock()
mock_collection.find_one.return_value = {
"_id": agent_id,
"user": "user1",
"incoming_webhook_token": None,
}
with patch(
"application.api.user.agents.webhooks.agents_collection",
mock_collection,
), patch(
"application.api.user.agents.webhooks.settings",
Mock(API_URL="https://api.example.com"),
), patch(
"application.api.user.agents.webhooks.secrets.token_urlsafe",
return_value="new_generated_token",
):
with app.test_request_context(
f"/api/agent_webhook?id={agent_id}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentWebhook().get()
assert response.status_code == 200
assert "new_generated_token" in response.json["webhook_url"]
mock_collection.update_one.assert_called_once()
def test_returns_401_unauthenticated(self, app):
from application.api.user.agents.webhooks import AgentWebhook
with app.test_request_context(f"/api/agent_webhook?id={ObjectId()}"):
from flask import request
request.decoded_token = None
response = AgentWebhook().get()
assert response.status_code == 401
def test_returns_400_missing_id(self, app):
from application.api.user.agents.webhooks import AgentWebhook
with app.test_request_context("/api/agent_webhook"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentWebhook().get()
assert response.status_code == 400
def test_returns_404_agent_not_found(self, app):
from application.api.user.agents.webhooks import AgentWebhook
mock_collection = Mock()
mock_collection.find_one.return_value = None
with patch(
"application.api.user.agents.webhooks.agents_collection",
mock_collection,
):
with app.test_request_context(
f"/api/agent_webhook?id={ObjectId()}"
):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AgentWebhook().get()
assert response.status_code == 404
@pytest.mark.unit
class TestAgentWebhookListenerPost:
def test_enqueues_task_on_valid_post(self, app):
from application.api.user.agents.webhooks import AgentWebhookListener
mock_task = Mock()
mock_task.id = "task_abc"
with patch(
"application.api.user.agents.webhooks.process_agent_webhook"
) as mock_process:
mock_process.delay.return_value = mock_task
with app.test_request_context(
"/api/webhooks/agents/tok",
method="POST",
json={"event": "new_message"},
):
listener = AgentWebhookListener()
response = listener._enqueue_webhook_task(
"agent123", {"event": "new_message"}, "POST"
)
assert response.status_code == 200
assert response.json["task_id"] == "task_abc"
mock_process.delay.assert_called_once_with(
agent_id="agent123", payload={"event": "new_message"}
)
def test_returns_400_on_missing_json(self, app):
from application.api.user.agents.webhooks import AgentWebhookListener
with app.test_request_context(
"/api/webhooks/agents/tok",
method="POST",
json=None,
content_type="application/json",
data="",
):
from flask import request as flask_request
# Force get_json to return None (simulating empty/missing body)
with patch.object(
flask_request, "get_json", return_value=None
):
listener = AgentWebhookListener()
response = listener.post(
webhook_token="tok",
agent={"_id": ObjectId()},
agent_id_str="agent123",
)
assert response.status_code == 400
def test_handles_enqueue_error(self, app):
from application.api.user.agents.webhooks import AgentWebhookListener
with patch(
"application.api.user.agents.webhooks.process_agent_webhook"
) as mock_process:
mock_process.delay.side_effect = Exception("Queue down")
with app.test_request_context(
"/api/webhooks/agents/tok",
method="POST",
json={"event": "test"},
):
listener = AgentWebhookListener()
response = listener._enqueue_webhook_task(
"agent123", {"event": "test"}, "POST"
)
assert response.status_code == 500
@pytest.mark.unit
class TestAgentWebhookListenerGet:
def test_uses_query_params_as_payload(self, app):
from application.api.user.agents.webhooks import AgentWebhookListener
mock_task = Mock()
mock_task.id = "task_xyz"
with patch(
"application.api.user.agents.webhooks.process_agent_webhook"
) as mock_process:
mock_process.delay.return_value = mock_task
with app.test_request_context(
"/api/webhooks/agents/tok?event=ping&source=test",
method="GET",
):
listener = AgentWebhookListener()
response = listener.get(
webhook_token="tok",
agent={"_id": ObjectId()},
agent_id_str="agent456",
)
assert response.status_code == 200
call_kwargs = mock_process.delay.call_args[1]
assert call_kwargs["payload"]["event"] == "ping"
assert call_kwargs["payload"]["source"] == "test"

File diff suppressed because it is too large Load Diff

0
tests/core/__init__.py Normal file
View File

View File

@@ -0,0 +1,803 @@
"""Tests for application/core/model_settings.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
class TestModelProvider:
@pytest.mark.unit
def test_all_providers_exist(self):
assert ModelProvider.OPENAI == "openai"
assert ModelProvider.ANTHROPIC == "anthropic"
assert ModelProvider.GOOGLE == "google"
assert ModelProvider.GROQ == "groq"
assert ModelProvider.DOCSGPT == "docsgpt"
assert ModelProvider.HUGGINGFACE == "huggingface"
assert ModelProvider.NOVITA == "novita"
assert ModelProvider.OPENROUTER == "openrouter"
assert ModelProvider.SAGEMAKER == "sagemaker"
assert ModelProvider.PREMAI == "premai"
assert ModelProvider.LLAMA_CPP == "llama.cpp"
assert ModelProvider.AZURE_OPENAI == "azure_openai"
class TestModelCapabilities:
@pytest.mark.unit
def test_defaults(self):
caps = ModelCapabilities()
assert caps.supports_tools is False
assert caps.supports_structured_output is False
assert caps.supports_streaming is True
assert caps.supported_attachment_types == []
assert caps.context_window == 128000
assert caps.input_cost_per_token is None
assert caps.output_cost_per_token is None
@pytest.mark.unit
def test_custom_values(self):
caps = ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
context_window=32000,
input_cost_per_token=0.001,
)
assert caps.supports_tools is True
assert caps.context_window == 32000
class TestAvailableModel:
@pytest.mark.unit
def test_to_dict_basic(self):
model = AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="OpenAI GPT-4",
)
d = model.to_dict()
assert d["id"] == "gpt-4"
assert d["provider"] == "openai"
assert d["display_name"] == "GPT-4"
assert d["enabled"] is True
assert "base_url" not in d
@pytest.mark.unit
def test_to_dict_with_base_url(self):
model = AvailableModel(
id="local-model",
provider=ModelProvider.OPENAI,
display_name="Local",
base_url="http://localhost:11434",
)
d = model.to_dict()
assert d["base_url"] == "http://localhost:11434"
@pytest.mark.unit
def test_to_dict_includes_capabilities(self):
caps = ModelCapabilities(supports_tools=True, context_window=64000)
model = AvailableModel(
id="m1",
provider=ModelProvider.ANTHROPIC,
display_name="M1",
capabilities=caps,
)
d = model.to_dict()
assert d["supports_tools"] is True
assert d["context_window"] == 64000
class TestModelRegistry:
@pytest.fixture(autouse=True)
def _reset_singleton(self):
"""Reset singleton between tests."""
ModelRegistry._instance = None
ModelRegistry._initialized = False
yield
ModelRegistry._instance = None
ModelRegistry._initialized = False
@pytest.mark.unit
def test_singleton(self):
with patch.object(ModelRegistry, "_load_models"):
r1 = ModelRegistry()
r2 = ModelRegistry()
assert r1 is r2
@pytest.mark.unit
def test_get_instance(self):
with patch.object(ModelRegistry, "_load_models"):
r = ModelRegistry.get_instance()
assert isinstance(r, ModelRegistry)
@pytest.mark.unit
def test_get_model(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
reg.models["test"] = model
assert reg.get_model("test") is model
assert reg.get_model("nonexistent") is None
@pytest.mark.unit
def test_get_all_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
assert len(reg.get_all_models()) == 2
@pytest.mark.unit
def test_get_enabled_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
enabled = reg.get_enabled_models()
assert len(enabled) == 1
assert enabled[0].id == "m1"
@pytest.mark.unit
def test_model_exists(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
assert reg.model_exists("m1") is True
assert reg.model_exists("m2") is False
@pytest.mark.unit
def test_parse_model_names(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
assert reg._parse_model_names("single") == ["single"]
assert reg._parse_model_names("") == []
assert reg._parse_model_names(None) == []
@pytest.mark.unit
def test_add_docsgpt_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_docsgpt_models(mock_settings)
assert "docsgpt-local" in reg.models
@pytest.mark.unit
def test_add_huggingface_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_huggingface_models(mock_settings)
assert "huggingface-local" in reg.models
@pytest.mark.unit
def test_load_models_with_openai_key(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert len(reg.models) > 0
@pytest.mark.unit
def test_load_models_custom_openai_base_url(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = "llama3,gemma"
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert "llama3" in reg.models
assert "gemma" in reg.models
@pytest.mark.unit
def test_default_model_selection_from_llm_name(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
reg.default_model_id = "gpt-4"
assert reg.default_model_id == "gpt-4"
@pytest.mark.unit
def test_add_anthropic_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_google_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = "google-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_groq_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = "groq-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_groq_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_openrouter_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = "or-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_novita_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = "novita-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_novita_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_azure_openai_models_specific(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
mock_settings.LLM_NAME = "nonexistent-model"
reg._add_azure_openai_models(mock_settings)
# Falls through to adding all azure models
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_anthropic_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.LLM_PROVIDER = "anthropic"
mock_settings.LLM_NAME = "nonexistent"
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_default_model_fallback_to_first(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = None
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
# Should have at least docsgpt-local
assert reg.default_model_id is not None
@pytest.mark.unit
def test_default_model_from_provider_fallback(self):
"""When LLM_NAME is not set but LLM_PROVIDER and API_KEY are,
default should be first model of that provider."""
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = None
mock_settings.API_KEY = "sk-test"
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert reg.default_model_id is not None
@pytest.mark.unit
def test_add_google_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = None
mock_settings.LLM_PROVIDER = "google"
mock_settings.LLM_NAME = "nonexistent"
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_groq_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = None
mock_settings.LLM_PROVIDER = "groq"
mock_settings.LLM_NAME = "nonexistent"
reg._add_groq_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_openrouter_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.LLM_PROVIDER = "openrouter"
mock_settings.LLM_NAME = "nonexistent"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_novita_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = None
mock_settings.LLM_PROVIDER = "novita"
mock_settings.LLM_NAME = "nonexistent"
reg._add_novita_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_to_dict_disabled_model(self):
model = AvailableModel(
id="disabled",
provider=ModelProvider.OPENAI,
display_name="Disabled",
enabled=False,
)
d = model.to_dict()
assert d["enabled"] is False
@pytest.mark.unit
def test_to_dict_with_attachment_types(self):
caps = ModelCapabilities(
supported_attachment_types=["image/png", "application/pdf"],
)
model = AvailableModel(
id="vision",
provider=ModelProvider.OPENAI,
display_name="Vision",
capabilities=caps,
)
d = model.to_dict()
assert d["supported_attachment_types"] == ["image/png", "application/pdf"]
# ----------------------------------------------------------------
# Coverage for _add_* methods with matching LLM_NAME
# Lines: 100, 105, 147, 171, 179, 186, 199-201, 204, 210, 213,
# 218, 229, 233, 241, 250
# ----------------------------------------------------------------
@pytest.mark.unit
def test_add_azure_openai_models_with_matching_name(self):
"""Cover line 186: azure model matching LLM_NAME returns early."""
from application.core.model_configs import AZURE_OPENAI_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
if AZURE_OPENAI_MODELS:
mock_settings.LLM_NAME = AZURE_OPENAI_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_azure_openai_models(mock_settings)
# Should have added at least one model
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_anthropic_no_key_no_provider_fallthrough(self):
"""Cover lines 199-204: no key, provider set but name not found -> add all."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.LLM_PROVIDER = "anthropic"
mock_settings.LLM_NAME = "nonexistent-model"
reg._add_anthropic_models(mock_settings)
# Falls through to add all anthropic models
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_google_no_key_matching_name(self):
"""Cover lines 213-218: Google fallback with matching name."""
from application.core.model_configs import GOOGLE_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = None
mock_settings.LLM_PROVIDER = "google"
if GOOGLE_MODELS:
mock_settings.LLM_NAME = GOOGLE_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_google_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_groq_no_key_matching_name(self):
"""Cover lines 229-233: Groq fallback with matching name."""
from application.core.model_configs import GROQ_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = None
mock_settings.LLM_PROVIDER = "groq"
if GROQ_MODELS:
mock_settings.LLM_NAME = GROQ_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_groq_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_openrouter_no_key_matching_name(self):
"""Cover lines 241-250: OpenRouter fallback with matching name."""
from application.core.model_configs import OPENROUTER_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.LLM_PROVIDER = "openrouter"
if OPENROUTER_MODELS:
mock_settings.LLM_NAME = OPENROUTER_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_novita_no_key_matching_name(self):
"""Cover novita fallback with matching name."""
from application.core.model_configs import NOVITA_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = None
mock_settings.LLM_PROVIDER = "novita"
if NOVITA_MODELS:
mock_settings.LLM_NAME = NOVITA_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_novita_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_load_models_default_from_llm_name_exact_match(self):
"""Cover line 136/147: exact LLM_NAME match for default model."""
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.API_KEY = None
from application.core.model_configs import OPENAI_MODELS
if OPENAI_MODELS:
mock_settings.LLM_NAME = OPENAI_MODELS[0].id
else:
mock_settings.LLM_NAME = "gpt-4o"
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert reg.default_model_id is not None
@pytest.mark.unit
def test_add_openai_models_local_endpoint_no_name(self):
"""Cover line 171: local endpoint without LLM_NAME adds nothing."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.LLM_NAME = None
reg._add_openai_models(mock_settings)
assert len(reg.models) == 0
@pytest.mark.unit
def test_add_openai_standard_no_api_key(self):
"""Cover line 179: standard OpenAI without API key adds nothing."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = None
reg._add_openai_models(mock_settings)
assert len(reg.models) == 0
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 100, 105, 147, 171, 179, 186, 250
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestModelRegistryAdditionalCoverage:
def test_add_azure_openai_models_specific_name(self):
"""Cover line 186: azure_openai with specific LLM_NAME match."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
mock_settings.LLM_NAME = "gpt-4o"
# Create a fake model that matches
fake_model = MagicMock()
fake_model.id = "gpt-4o"
with patch(
"application.core.model_configs.AZURE_OPENAI_MODELS",
[fake_model],
):
reg._add_azure_openai_models(mock_settings)
assert "gpt-4o" in reg.models
def test_add_anthropic_models_with_api_key(self):
"""Cover line 100: anthropic with API key."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = "sk-test"
mock_settings.LLM_PROVIDER = "anthropic"
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
def test_add_google_models_with_api_key(self):
"""Cover line 105: google with API key."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = "test-key"
mock_settings.LLM_PROVIDER = "google"
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
def test_default_model_from_provider(self):
"""Cover line 147: default model selected from provider."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
fake_model = MagicMock()
fake_model.provider = MagicMock()
fake_model.provider.value = "openai"
reg.models["gpt-4o"] = fake_model
mock_settings = MagicMock()
mock_settings.LLM_NAME = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.API_KEY = "key"
# Simulate the default selection logic
if not reg.default_model_id:
for model_id, model in reg.models.items():
if model.provider.value == mock_settings.LLM_PROVIDER:
reg.default_model_id = model_id
break
assert reg.default_model_id == "gpt-4o"
def test_add_openai_local_endpoint_with_llm_name(self):
"""Cover line 171: local endpoint registers custom models from LLM_NAME."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.LLM_NAME = "llama3,phi3"
reg._add_openai_models(mock_settings)
assert "llama3" in reg.models
assert "phi3" in reg.models
def test_add_openai_standard_with_api_key(self):
"""Cover line 179: standard OpenAI with API key adds models."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-real-key"
reg._add_openai_models(mock_settings)
assert len(reg.models) > 0
def test_add_openrouter_models(self):
"""Cover line 250: openrouter models added."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = "or-key"
mock_settings.LLM_PROVIDER = "openrouter"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
# ---------------------------------------------------------------------------
# Additional coverage for model_settings.py
# Lines: 135-136 (backward compat LLM_NAME), 138-143 (provider fallback),
# 145-146 (first model as default)
# ---------------------------------------------------------------------------
# Imports already at the top of the file; no additional imports needed
@pytest.mark.unit
class TestDefaultModelSelectionBackwardCompat:
"""Cover lines 135-136: backward compat exact match on LLM_NAME."""
def test_llm_name_exact_match_as_default(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
# Add a model with composite ID
model = AvailableModel(
id="my-composite-model",
provider=ModelProvider.OPENAI,
display_name="Composite",
description="test",
capabilities=ModelCapabilities(),
)
reg.models["my-composite-model"] = model
# Simulate _parse_model_names returning something different
# so that the first for-loop doesn't match
mock_settings = MagicMock()
mock_settings.LLM_NAME = "my-composite-model"
mock_settings.LLM_PROVIDER = None
mock_settings.API_KEY = None
# Call the logic directly
model_names = reg._parse_model_names(mock_settings.LLM_NAME)
for mn in model_names:
if mn in reg.models:
reg.default_model_id = mn
break
assert reg.default_model_id == "my-composite-model"
@pytest.mark.unit
class TestDefaultModelSelectionByProvider:
"""Cover lines 138-143: default model by provider when LLM_NAME doesn't match."""
def test_default_by_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
model = AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="test",
capabilities=ModelCapabilities(),
)
reg.models["gpt-4"] = model
# Simulate: LLM_NAME doesn't exist/match, but LLM_PROVIDER + API_KEY set
if not reg.default_model_id:
for model_id, m in reg.models.items():
if m.provider.value == "openai":
reg.default_model_id = model_id
break
assert reg.default_model_id == "gpt-4"
@pytest.mark.unit
class TestDefaultModelSelectionFirstModel:
"""Cover lines 145-146: first model as default when nothing else matches."""
def test_first_model_as_default(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
model = AvailableModel(
id="fallback-model",
provider=ModelProvider.OPENAI,
display_name="Fallback",
description="test",
capabilities=ModelCapabilities(),
)
reg.models["fallback-model"] = model
if not reg.default_model_id and reg.models:
reg.default_model_id = next(iter(reg.models.keys()))
assert reg.default_model_id == "fallback-model"

View File

@@ -0,0 +1,382 @@
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
@pytest.fixture(autouse=True)
def _reset_registry():
"""Reset ModelRegistry singleton between tests."""
ModelRegistry._instance = None
ModelRegistry._initialized = False
yield
ModelRegistry._instance = None
ModelRegistry._initialized = False
def _make_model(
model_id="test-model",
provider=ModelProvider.OPENAI,
display_name="Test Model",
context_window=128000,
supports_tools=True,
supports_structured_output=False,
supported_attachment_types=None,
enabled=True,
base_url=None,
):
return AvailableModel(
id=model_id,
provider=provider,
display_name=display_name,
capabilities=ModelCapabilities(
supports_tools=supports_tools,
supports_structured_output=supports_structured_output,
supported_attachment_types=supported_attachment_types or [],
context_window=context_window,
),
enabled=enabled,
base_url=base_url,
)
# ── get_api_key_for_provider ─────────────────────────────────────────────────
class TestGetApiKeyForProvider:
"""settings is lazily imported inside the function body, so we patch
at application.core.settings.settings (the actual module attribute)."""
@pytest.mark.unit
def test_openai_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.OPENAI_API_KEY = "sk-openai"
mock_settings.API_KEY = "sk-fallback"
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("openai") == "sk-openai"
@pytest.mark.unit
def test_anthropic_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.ANTHROPIC_API_KEY = "sk-anthropic"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("anthropic") == "sk-anthropic"
@pytest.mark.unit
def test_google_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.GOOGLE_API_KEY = "sk-google"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("google") == "sk-google"
@pytest.mark.unit
def test_groq_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.GROQ_API_KEY = "sk-groq"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("groq") == "sk-groq"
@pytest.mark.unit
def test_openrouter_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.OPEN_ROUTER_API_KEY = "sk-or"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("openrouter") == "sk-or"
@pytest.mark.unit
def test_novita_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.NOVITA_API_KEY = "sk-novita"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("novita") == "sk-novita"
@pytest.mark.unit
def test_huggingface_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.HUGGINGFACE_API_KEY = "hf-key"
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("huggingface") == "hf-key"
@pytest.mark.unit
def test_docsgpt_returns_fallback(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("docsgpt") == "sk-fallback"
@pytest.mark.unit
def test_llama_cpp_returns_fallback(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("llama.cpp") == "sk-fallback"
@pytest.mark.unit
def test_unknown_provider_returns_fallback(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.API_KEY = "sk-fallback"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("unknown_provider") == "sk-fallback"
@pytest.mark.unit
def test_azure_openai_key(self):
with patch("application.core.settings.settings") as mock_settings:
mock_settings.API_KEY = "sk-azure"
from application.core.model_utils import get_api_key_for_provider
assert get_api_key_for_provider("azure_openai") == "sk-azure"
# ── get_all_available_models ─────────────────────────────────────────────────
class TestGetAllAvailableModels:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_returns_enabled_models_as_dict(self, mock_get_instance):
model_a = _make_model("model-a", display_name="Model A")
model_b = _make_model("model-b", display_name="Model B")
mock_registry = MagicMock()
mock_registry.get_enabled_models.return_value = [model_a, model_b]
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_all_available_models
result = get_all_available_models()
assert "model-a" in result
assert "model-b" in result
assert result["model-a"]["display_name"] == "Model A"
assert result["model-b"]["display_name"] == "Model B"
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_empty_registry(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.get_enabled_models.return_value = []
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_all_available_models
assert get_all_available_models() == {}
# ── validate_model_id ────────────────────────────────────────────────────────
class TestValidateModelId:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_exists(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.model_exists.return_value = True
mock_get_instance.return_value = mock_registry
from application.core.model_utils import validate_model_id
assert validate_model_id("gpt-4") is True
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_not_exists(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.model_exists.return_value = False
mock_get_instance.return_value = mock_registry
from application.core.model_utils import validate_model_id
assert validate_model_id("nonexistent") is False
# ── get_model_capabilities ───────────────────────────────────────────────────
class TestGetModelCapabilities:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_found(self, mock_get_instance):
model = _make_model(
"gpt-4",
context_window=8192,
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=["image/png"],
)
mock_registry = MagicMock()
mock_registry.get_model.return_value = model
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_model_capabilities
caps = get_model_capabilities("gpt-4")
assert caps is not None
assert caps["supported_attachment_types"] == ["image/png"]
assert caps["supports_tools"] is True
assert caps["supports_structured_output"] is True
assert caps["context_window"] == 8192
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_not_found(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.get_model.return_value = None
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_model_capabilities
assert get_model_capabilities("nonexistent") is None
# ── get_default_model_id ─────────────────────────────────────────────────────
class TestGetDefaultModelId:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_returns_default(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.default_model_id = "gpt-4"
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_default_model_id
assert get_default_model_id() == "gpt-4"
# ── get_provider_from_model_id ───────────────────────────────────────────────
class TestGetProviderFromModelId:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_found(self, mock_get_instance):
model = _make_model("gpt-4", provider=ModelProvider.OPENAI)
mock_registry = MagicMock()
mock_registry.get_model.return_value = model
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_provider_from_model_id
assert get_provider_from_model_id("gpt-4") == "openai"
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_not_found(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.get_model.return_value = None
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_provider_from_model_id
assert get_provider_from_model_id("nonexistent") is None
# ── get_token_limit ──────────────────────────────────────────────────────────
class TestGetTokenLimit:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_found(self, mock_get_instance):
model = _make_model("gpt-4", context_window=8192)
mock_registry = MagicMock()
mock_registry.get_model.return_value = model
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_token_limit
assert get_token_limit("gpt-4") == 8192
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_not_found_returns_default(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.get_model.return_value = None
mock_get_instance.return_value = mock_registry
with patch("application.core.settings.settings") as mock_settings:
mock_settings.DEFAULT_LLM_TOKEN_LIMIT = 128000
from application.core.model_utils import get_token_limit
assert get_token_limit("nonexistent") == 128000
# ── get_base_url_for_model ───────────────────────────────────────────────────
class TestGetBaseUrlForModel:
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_with_base_url(self, mock_get_instance):
model = _make_model("custom-model", base_url="http://localhost:8080")
mock_registry = MagicMock()
mock_registry.get_model.return_value = model
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_base_url_for_model
assert get_base_url_for_model("custom-model") == "http://localhost:8080"
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_without_base_url(self, mock_get_instance):
model = _make_model("gpt-4", base_url=None)
mock_registry = MagicMock()
mock_registry.get_model.return_value = model
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_base_url_for_model
assert get_base_url_for_model("gpt-4") is None
@pytest.mark.unit
@patch("application.core.model_utils.ModelRegistry.get_instance")
def test_model_not_found(self, mock_get_instance):
mock_registry = MagicMock()
mock_registry.get_model.return_value = None
mock_get_instance.return_value = mock_registry
from application.core.model_utils import get_base_url_for_model
assert get_base_url_for_model("nonexistent") is None

View File

@@ -195,3 +195,67 @@ class TestValidateUrlSafe:
is_valid, url, error = validate_url_safe("http://192.168.1.1")
assert is_valid is False
assert "private" in error.lower() or "internal" in error.lower()
def test_adds_scheme_when_missing(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "93.184.216.34"
is_valid, url, error = validate_url_safe("example.com")
assert is_valid is True
assert url == "http://example.com"
class TestIsPrivateIPExtended:
"""Additional edge cases for IP classification."""
def test_multicast_ip(self):
assert is_private_ip("224.0.0.1") is True
def test_unspecified_ip(self):
assert is_private_ip("0.0.0.0") is True
def test_ipv6_loopback(self):
assert is_private_ip("::1") is True
def test_ipv6_private(self):
assert is_private_ip("fc00::1") is True
def test_ipv6_public(self):
assert is_private_ip("2607:f8b0:4004:800::200e") is False
def test_reserved_ip(self):
# 240.0.0.0/4 is reserved (future use), Python's ipaddress marks it as such
assert is_private_ip("240.0.0.1") is True
class TestValidateUrlExtended:
"""Additional URL validation tests."""
def test_blocks_metadata_hostname(self):
with pytest.raises(SSRFError):
validate_url("http://metadata")
def test_allows_localhost_with_flag(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "192.168.1.1"
result = validate_url(
"http://internal.local", allow_localhost=True
)
assert result == "http://internal.local"
def test_blocks_aws_ecs_metadata_ip(self):
with pytest.raises(SSRFError, match="metadata"):
validate_url("http://169.254.170.2")
def test_blocks_aws_ipv6_metadata(self):
with pytest.raises(SSRFError, match="metadata"):
validate_url("http://[fd00:ec2::254]")
def test_blocks_hostname_resolving_to_loopback(self):
with patch("application.core.url_validation.resolve_hostname") as mock_resolve:
mock_resolve.return_value = "127.0.0.1"
with pytest.raises(SSRFError):
validate_url("http://sneaky.example.com")
def test_allows_localhost_ip_with_flag(self):
result = validate_url("http://10.0.0.1", allow_localhost=True)
assert result == "http://10.0.0.1"

View File

@@ -0,0 +1,577 @@
#!/usr/bin/env python3
"""
Integration tests for DocsGPT workflow management endpoints.
Uses Flask test client with real MongoDB (must be running).
Endpoints tested:
- /api/workflows (POST) - Create workflow
- /api/workflows/<id> (GET) - Get workflow
- /api/workflows/<id> (PUT) - Update workflow
- /api/workflows/<id> (DELETE) - Delete workflow
Run:
pytest tests/integration/test_workflows.py -v
"""
import time
import pytest
from jose import jwt
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def app():
"""Create the real Flask app (connects to real MongoDB)."""
from application.app import app as flask_app
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(scope="module")
def client(app):
"""Flask test client.
When AUTH_TYPE is set to simple_jwt/session_jwt a Bearer token is
injected; otherwise the backend already returns {"sub": "local"}
for every request so no token is needed.
"""
from application.core.settings import settings
c = app.test_client()
if settings.AUTH_TYPE in ("simple_jwt", "session_jwt"):
secret = settings.JWT_SECRET_KEY
if not secret:
pytest.skip("JWT_SECRET_KEY not configured")
payload = {"sub": f"test_workflow_integration_{int(time.time())}"}
token = jwt.encode(payload, secret, algorithm="HS256")
c.environ_base["HTTP_AUTHORIZATION"] = f"Bearer {token}"
return c
@pytest.fixture(scope="module")
def created_ids():
"""Accumulator for workflow IDs to clean up after all tests."""
return []
@pytest.fixture(autouse=True, scope="module")
def cleanup(client, created_ids):
"""Delete all test-created workflows after the module finishes."""
yield
for wf_id in created_ids:
try:
client.delete(f"/api/workflows/{wf_id}")
except Exception:
pass
# ---------------------------------------------------------------------------
# Payload helpers
# ---------------------------------------------------------------------------
def simple_workflow(suffix=""):
"""Start -> End."""
return {
"name": f"Simple WF {int(time.time())}{suffix}",
"description": "integration test",
"nodes": [
{"id": "start_1", "type": "start", "title": "Start",
"position": {"x": 0, "y": 0}, "data": {}},
{"id": "end_1", "type": "end", "title": "End",
"position": {"x": 400, "y": 0}, "data": {}},
],
"edges": [
{"id": "edge_1", "source": "start_1", "target": "end_1"},
],
}
def linear_workflow(suffix=""):
"""Start -> Agent -> End."""
return {
"name": f"Linear WF {int(time.time())}{suffix}",
"description": "integration test",
"nodes": [
{"id": "start_1", "type": "start", "title": "Start",
"position": {"x": 0, "y": 0}, "data": {}},
{"id": "agent_1", "type": "agent", "title": "Agent",
"position": {"x": 200, "y": 0}, "data": {
"agent_type": "classic",
"system_prompt": "You are helpful.",
"prompt_template": "",
"stream_to_user": False,
}},
{"id": "end_1", "type": "end", "title": "End",
"position": {"x": 400, "y": 0}, "data": {}},
],
"edges": [
{"id": "edge_1", "source": "start_1", "target": "agent_1"},
{"id": "edge_2", "source": "agent_1", "target": "end_1"},
],
}
def multi_input_end_workflow(suffix=""):
"""Condition branches into two agents, both converging on one end node.
Graph:
start -> condition --(case_1)--> agent_a --\
--(else)----> agent_b ---+--> end
"""
return {
"name": f"Multi-Input End {int(time.time())}{suffix}",
"description": "end node with multiple inputs",
"nodes": [
{"id": "start_1", "type": "start", "title": "Start",
"position": {"x": 0, "y": 100}, "data": {}},
{"id": "cond_1", "type": "condition", "title": "Branch",
"position": {"x": 200, "y": 100}, "data": {
"mode": "simple",
"cases": [
{"name": "Case 1", "expression": "true",
"sourceHandle": "case_1"},
],
}},
{"id": "agent_a", "type": "agent", "title": "Agent A",
"position": {"x": 400, "y": 0}, "data": {
"agent_type": "classic",
"system_prompt": "Branch A",
"prompt_template": "",
"stream_to_user": False,
}},
{"id": "agent_b", "type": "agent", "title": "Agent B",
"position": {"x": 400, "y": 200}, "data": {
"agent_type": "classic",
"system_prompt": "Branch B",
"prompt_template": "",
"stream_to_user": False,
}},
{"id": "end_1", "type": "end", "title": "End",
"position": {"x": 600, "y": 100}, "data": {}},
],
"edges": [
{"id": "e1", "source": "start_1", "target": "cond_1"},
{"id": "e2", "source": "cond_1", "target": "agent_a",
"sourceHandle": "case_1"},
{"id": "e3", "source": "cond_1", "target": "agent_b",
"sourceHandle": "else"},
# Both agents feed into the SAME end node
{"id": "e4", "source": "agent_a", "target": "end_1"},
{"id": "e5", "source": "agent_b", "target": "end_1"},
],
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_id(resp):
"""Pull workflow id from create/update response."""
body = resp.get_json()
data = body.get("data") or body
return data.get("id")
def _get_graph(client, wf_id):
"""Fetch workflow and return (nodes, edges)."""
resp = client.get(f"/api/workflows/{wf_id}")
assert resp.status_code == 200, resp.get_data(as_text=True)
body = resp.get_json()
data = body.get("data") or body
return data.get("nodes", []), data.get("edges", [])
# ===========================================================================
# CRUD tests
# ===========================================================================
class TestWorkflowCRUD:
def test_create_simple_workflow(self, client, created_ids):
resp = client.post("/api/workflows", json=simple_workflow())
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
def test_create_linear_workflow(self, client, created_ids):
resp = client.post("/api/workflows", json=linear_workflow())
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
def test_get_workflow_returns_nodes_and_edges(self, client, created_ids):
resp = client.post("/api/workflows", json=simple_workflow(" get"))
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, edges = _get_graph(client, wf_id)
assert len(nodes) == 2
assert len(edges) == 1
def test_update_workflow(self, client, created_ids):
resp = client.post("/api/workflows", json=simple_workflow(" upd"))
wf_id = _extract_id(resp)
created_ids.append(wf_id)
update_resp = client.put(
f"/api/workflows/{wf_id}", json=linear_workflow(" updated")
)
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
nodes, edges = _get_graph(client, wf_id)
assert len(nodes) == 3 # start, agent, end
assert len(edges) == 2
def test_delete_workflow(self, client):
resp = client.post("/api/workflows", json=simple_workflow(" del"))
wf_id = _extract_id(resp)
del_resp = client.delete(f"/api/workflows/{wf_id}")
assert del_resp.status_code == 200
get_resp = client.get(f"/api/workflows/{wf_id}")
assert get_resp.status_code in (400, 404)
def test_reject_workflow_without_end_node(self, client):
payload = {
"name": "No End",
"nodes": [
{"id": "s", "type": "start", "title": "Start",
"position": {"x": 0, "y": 0}, "data": {}},
],
"edges": [],
}
resp = client.post("/api/workflows", json=payload)
assert resp.status_code == 400, resp.get_data(as_text=True)
# ===========================================================================
# Multi-input end node tests
# ===========================================================================
class TestMultiInputEndNode:
"""Verify that an end node can receive edges from multiple source nodes."""
def test_create_multi_input_end_workflow_accepted(self, client, created_ids):
"""Backend must accept a workflow where two edges target the same end node."""
resp = client.post("/api/workflows", json=multi_input_end_workflow())
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
def test_multi_input_end_all_edges_persisted(self, client, created_ids):
"""After round-trip, both edges into the end node must still be present."""
resp = client.post(
"/api/workflows", json=multi_input_end_workflow(" persist")
)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, edges = _get_graph(client, wf_id)
# Locate end node
end_ids = {n["id"] for n in nodes if n["type"] == "end"}
assert end_ids, "no end node in response"
# Count edges targeting any end node
edges_to_end = [e for e in edges if e["target"] in end_ids]
assert len(edges_to_end) >= 2, (
f"Expected >=2 edges to end, got {len(edges_to_end)}: {edges_to_end}"
)
def test_multi_input_end_total_edge_count(self, client, created_ids):
"""All 5 edges of the multi-input graph must survive persistence."""
resp = client.post(
"/api/workflows", json=multi_input_end_workflow(" count")
)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
_, edges = _get_graph(client, wf_id)
assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"
def test_update_to_multi_input_end_preserves_edges(self, client, created_ids):
"""Updating a simple workflow to multi-input end keeps all edges."""
# Create simple
resp = client.post("/api/workflows", json=simple_workflow(" pre"))
wf_id = _extract_id(resp)
created_ids.append(wf_id)
# Update to multi-input end
update_resp = client.put(
f"/api/workflows/{wf_id}",
json=multi_input_end_workflow(" post"),
)
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
nodes, edges = _get_graph(client, wf_id)
end_ids = {n["id"] for n in nodes if n["type"] == "end"}
edges_to_end = [e for e in edges if e["target"] in end_ids]
assert len(edges_to_end) >= 2, (
f"Expected >=2 edges to end after update, got {len(edges_to_end)}"
)
# ---------------------------------------------------------------------------
# Source-aware payload helpers
# ---------------------------------------------------------------------------
def workflow_with_sources(sources, suffix=""):
"""Start -> Agent (with sources) -> End."""
return {
"name": f"Source WF {int(time.time())}{suffix}",
"description": "integration test with sources",
"nodes": [
{"id": "start_1", "type": "start", "title": "Start",
"position": {"x": 0, "y": 0}, "data": {}},
{"id": "agent_1", "type": "agent", "title": "Agent",
"position": {"x": 200, "y": 0}, "data": {
"agent_type": "classic",
"system_prompt": "You are helpful.",
"prompt_template": "",
"stream_to_user": False,
"sources": sources,
"tools": [],
}},
{"id": "end_1", "type": "end", "title": "End",
"position": {"x": 400, "y": 0}, "data": {}},
],
"edges": [
{"id": "edge_1", "source": "start_1", "target": "agent_1"},
{"id": "edge_2", "source": "agent_1", "target": "end_1"},
],
}
def workflow_multi_agent_sources(suffix=""):
"""Start -> Agent A (sources A) -> Agent B (sources B) -> End."""
return {
"name": f"Multi-Agent Sources {int(time.time())}{suffix}",
"description": "two agents with different sources",
"nodes": [
{"id": "start_1", "type": "start", "title": "Start",
"position": {"x": 0, "y": 0}, "data": {}},
{"id": "agent_a", "type": "agent", "title": "Agent A",
"position": {"x": 200, "y": 0}, "data": {
"agent_type": "agentic",
"system_prompt": "Agent A prompt",
"prompt_template": "",
"stream_to_user": False,
"sources": ["src_alpha", "src_beta"],
"tools": [],
}},
{"id": "agent_b", "type": "agent", "title": "Agent B",
"position": {"x": 400, "y": 0}, "data": {
"agent_type": "classic",
"system_prompt": "Agent B prompt",
"prompt_template": "",
"stream_to_user": True,
"sources": ["src_gamma"],
"tools": [],
}},
{"id": "end_1", "type": "end", "title": "End",
"position": {"x": 600, "y": 0}, "data": {}},
],
"edges": [
{"id": "e1", "source": "start_1", "target": "agent_a"},
{"id": "e2", "source": "agent_a", "target": "agent_b"},
{"id": "e3", "source": "agent_b", "target": "end_1"},
],
}
def _find_agent_node(nodes, node_id):
"""Find a specific node by id."""
return next((n for n in nodes if n["id"] == node_id), None)
# ===========================================================================
# Workflow integration tests
# ===========================================================================
class TestWorkflowIntegration:
"""Verify end-to-end workflow create → get → update → get round-trips."""
def test_linear_workflow_round_trip(self, client, created_ids):
"""Create a linear workflow and verify all nodes/edges survive the round-trip."""
payload = linear_workflow(" round-trip")
resp = client.post("/api/workflows", json=payload)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
nodes, edges = _get_graph(client, wf_id)
assert len(nodes) == 3
assert len(edges) == 2
# Verify node types
types = {n["id"]: n["type"] for n in nodes}
assert types["start_1"] == "start"
assert types["agent_1"] == "agent"
assert types["end_1"] == "end"
def test_agent_config_persisted(self, client, created_ids):
"""Agent node config (type, prompts, stream_to_user) round-trips correctly."""
payload = linear_workflow(" config")
resp = client.post("/api/workflows", json=payload)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent is not None
assert agent["data"]["agent_type"] == "classic"
assert agent["data"]["system_prompt"] == "You are helpful."
assert agent["data"]["stream_to_user"] is False
def test_update_workflow_replaces_graph(self, client, created_ids):
"""Updating a workflow fully replaces nodes and edges."""
resp = client.post("/api/workflows", json=simple_workflow(" replace"))
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, edges = _get_graph(client, wf_id)
assert len(nodes) == 2
# Update to linear
update_resp = client.put(
f"/api/workflows/{wf_id}", json=linear_workflow(" replaced")
)
assert update_resp.status_code == 200
nodes, edges = _get_graph(client, wf_id)
assert len(nodes) == 3
assert len(edges) == 2
# ===========================================================================
# Source-specific integration tests
# ===========================================================================
class TestWorkflowSources:
"""Verify that agent node sources are persisted and retrieved correctly."""
def test_create_workflow_with_single_source(self, client, created_ids):
"""A workflow with one source on an agent node persists it."""
payload = workflow_with_sources(["default"])
resp = client.post("/api/workflows", json=payload)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent is not None, "Agent node not found"
assert agent["data"].get("sources") == ["default"], (
f"Expected sources=['default'], got {agent['data'].get('sources')}"
)
def test_create_workflow_with_multiple_sources(self, client, created_ids):
"""Multiple sources on an agent node are all persisted."""
sources = ["src_1", "src_2", "src_3"]
payload = workflow_with_sources(sources)
resp = client.post("/api/workflows", json=payload)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent is not None
assert agent["data"].get("sources") == sources
def test_create_workflow_with_empty_sources(self, client, created_ids):
"""An agent node with empty sources list is accepted and persisted."""
payload = workflow_with_sources([])
resp = client.post("/api/workflows", json=payload)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
assert wf_id
created_ids.append(wf_id)
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent is not None
assert agent["data"].get("sources") == []
def test_update_workflow_sources(self, client, created_ids):
"""Updating a workflow replaces agent sources."""
# Create with original sources
payload = workflow_with_sources(["old_src"])
resp = client.post("/api/workflows", json=payload)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
# Update with new sources
updated_payload = workflow_with_sources(["new_src_1", "new_src_2"], " upd")
update_resp = client.put(f"/api/workflows/{wf_id}", json=updated_payload)
assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent is not None
assert agent["data"].get("sources") == ["new_src_1", "new_src_2"]
def test_multi_agent_independent_sources(self, client, created_ids):
"""Each agent node keeps its own distinct sources list."""
payload = workflow_multi_agent_sources()
resp = client.post("/api/workflows", json=payload)
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
nodes, _ = _get_graph(client, wf_id)
agent_a = _find_agent_node(nodes, "agent_a")
agent_b = _find_agent_node(nodes, "agent_b")
assert agent_a is not None, "Agent A not found"
assert agent_b is not None, "Agent B not found"
assert agent_a["data"].get("sources") == ["src_alpha", "src_beta"]
assert agent_b["data"].get("sources") == ["src_gamma"]
def test_sources_survive_workflow_update(self, client, created_ids):
"""Sources survive when a workflow is updated without changing sources."""
payload = workflow_with_sources(["persistent_src"])
resp = client.post("/api/workflows", json=payload)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
# Update keeping same sources
update_resp = client.put(f"/api/workflows/{wf_id}", json=payload)
assert update_resp.status_code == 200
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent["data"].get("sources") == ["persistent_src"]
def test_remove_sources_on_update(self, client, created_ids):
"""Clearing sources on update results in empty list."""
payload = workflow_with_sources(["will_be_removed"])
resp = client.post("/api/workflows", json=payload)
wf_id = _extract_id(resp)
created_ids.append(wf_id)
# Update with no sources
cleared_payload = workflow_with_sources([], " cleared")
update_resp = client.put(f"/api/workflows/{wf_id}", json=cleared_payload)
assert update_resp.status_code == 200
nodes, _ = _get_graph(client, wf_id)
agent = _find_agent_node(nodes, "agent_1")
assert agent["data"].get("sources") == []

0
tests/llm/__init__.py Normal file
View File

View File

@@ -71,6 +71,7 @@ class TestLLMHandlerCreator:
expected_handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler,
"default": OpenAILLMHandler,
}

File diff suppressed because it is too large Load Diff

323
tests/llm/test_anthropic.py Normal file
View File

@@ -0,0 +1,323 @@
"""Unit tests for application/llm/anthropic.py — AnthropicLLM.
Extends coverage beyond test_anthropic_llm.py:
- Constructor: api_key priority, base_url support
- get_supported_attachment_types
- prepare_messages_with_attachments: various scenarios
- _get_base64_image: error paths
- _raw_gen_stream: close called on response
"""
import sys
import types
import pytest
# ---------------------------------------------------------------------------
# Fake anthropic module
# ---------------------------------------------------------------------------
class _FakeCompletion:
def __init__(self, text):
self.completion = text
class _FakeCompletions:
def __init__(self):
self.last_kwargs = None
self._stream_items = [_FakeCompletion("s1"), _FakeCompletion("s2")]
def create(self, **kwargs):
self.last_kwargs = kwargs
if kwargs.get("stream"):
return self._stream_items
return _FakeCompletion("final")
class _FakeAnthropic:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url
self.completions = _FakeCompletions()
@pytest.fixture(autouse=True)
def patch_anthropic(monkeypatch):
fake = types.ModuleType("anthropic")
fake.Anthropic = _FakeAnthropic
fake.HUMAN_PROMPT = "<HUMAN>"
fake.AI_PROMPT = "<AI>"
modules_to_remove = [key for key in sys.modules if key.startswith("anthropic")]
for key in modules_to_remove:
sys.modules.pop(key, None)
sys.modules["anthropic"] = fake
if "application.llm.anthropic" in sys.modules:
del sys.modules["application.llm.anthropic"]
yield
sys.modules.pop("anthropic", None)
if "application.llm.anthropic" in sys.modules:
del sys.modules["application.llm.anthropic"]
@pytest.fixture
def llm():
from application.llm.anthropic import AnthropicLLM
instance = AnthropicLLM(api_key="test-key")
instance.storage = types.SimpleNamespace(
get_file=lambda path: _ctx_manager(b"img_bytes"),
)
return instance
def _ctx_manager(data):
"""Create a simple context manager returning an object with .read()."""
import contextlib
@contextlib.contextmanager
def cm():
yield types.SimpleNamespace(read=lambda: data)
return cm()
# ---------------------------------------------------------------------------
# Constructor
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestAnthropicConstructor:
def test_api_key_set(self):
from application.llm.anthropic import AnthropicLLM
instance = AnthropicLLM(api_key="custom-key")
assert instance.api_key == "custom-key"
def test_base_url_passed(self):
from application.llm.anthropic import AnthropicLLM
instance = AnthropicLLM(api_key="k", base_url="https://custom.api")
assert instance.anthropic.base_url == "https://custom.api"
def test_no_base_url(self):
from application.llm.anthropic import AnthropicLLM
instance = AnthropicLLM(api_key="k")
assert instance.anthropic.base_url is None
def test_human_and_ai_prompts_set(self):
from application.llm.anthropic import AnthropicLLM
instance = AnthropicLLM(api_key="k")
assert instance.HUMAN_PROMPT == "<HUMAN>"
assert instance.AI_PROMPT == "<AI>"
# ---------------------------------------------------------------------------
# _raw_gen
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRawGen:
def test_returns_completion(self, llm):
msgs = [{"content": "context"}, {"content": "question"}]
result = llm._raw_gen(llm, model="claude-2", messages=msgs)
assert result == "final"
def test_prompt_contains_context_and_question(self, llm):
msgs = [{"content": "my context"}, {"content": "my question"}]
llm._raw_gen(llm, model="claude-2", messages=msgs)
prompt = llm.anthropic.completions.last_kwargs["prompt"]
assert "my context" in prompt
assert "my question" in prompt
def test_max_tokens_passed(self, llm):
msgs = [{"content": "c"}, {"content": "q"}]
llm._raw_gen(llm, model="claude-2", messages=msgs, max_tokens=200)
assert llm.anthropic.completions.last_kwargs["max_tokens_to_sample"] == 200
# ---------------------------------------------------------------------------
# _raw_gen_stream
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRawGenStream:
def test_yields_all_completions(self, llm):
msgs = [{"content": "c"}, {"content": "q"}]
chunks = list(
llm._raw_gen_stream(llm, model="claude", messages=msgs, max_tokens=10)
)
assert chunks == ["s1", "s2"]
def test_calls_close_on_response(self, llm):
closed = {"called": False}
original = llm.anthropic.completions._stream_items
class ClosableList(list):
def close(self):
closed["called"] = True
closable = ClosableList(original)
llm.anthropic.completions._stream_items = closable
llm.anthropic.completions.create = lambda **kw: closable
msgs = [{"content": "c"}, {"content": "q"}]
list(llm._raw_gen_stream(llm, model="claude", messages=msgs))
assert closed["called"]
def test_prompt_format(self, llm):
msgs = [{"content": "ctx"}, {"content": "q"}]
list(llm._raw_gen_stream(llm, model="claude", messages=msgs))
prompt = llm.anthropic.completions.last_kwargs["prompt"]
assert prompt.startswith("<HUMAN>")
assert prompt.endswith("<AI>")
# ---------------------------------------------------------------------------
# get_supported_attachment_types
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetSupportedAttachmentTypes:
def test_returns_image_types(self, llm):
result = llm.get_supported_attachment_types()
assert "image/png" in result
assert "image/jpeg" in result
assert "image/webp" in result
assert "image/gif" in result
def test_no_pdf_support(self, llm):
result = llm.get_supported_attachment_types()
assert "application/pdf" not in result
# ---------------------------------------------------------------------------
# prepare_messages_with_attachments
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestPrepareMessagesWithAttachments:
def test_no_attachments_returns_same(self, llm):
msgs = [{"role": "user", "content": "hi"}]
result = llm.prepare_messages_with_attachments(msgs)
assert result == msgs
def test_empty_attachments_returns_same(self, llm):
msgs = [{"role": "user", "content": "hi"}]
result = llm.prepare_messages_with_attachments(msgs, [])
assert result == msgs
def test_image_with_preconverted_data(self, llm):
msgs = [{"role": "user", "content": "look"}]
attachments = [{"mime_type": "image/png", "data": "AABBCC"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
img_part = next(
p for p in user_msg["content"] if p.get("type") == "image"
)
assert img_part["source"]["data"] == "AABBCC"
assert img_part["source"]["type"] == "base64"
assert img_part["source"]["media_type"] == "image/png"
def test_image_from_storage(self, llm):
llm.storage = types.SimpleNamespace(
get_file=lambda p: _ctx_manager(b"raw_image_bytes"),
)
msgs = [{"role": "user", "content": "look"}]
attachments = [{"mime_type": "image/jpeg", "path": "/tmp/img.jpg"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
img_part = next(
p for p in user_msg["content"] if p.get("type") == "image"
)
assert img_part["source"]["media_type"] == "image/jpeg"
assert len(img_part["source"]["data"]) > 0
def test_no_user_message_creates_one(self, llm):
msgs = [{"role": "system", "content": "sys"}]
attachments = [{"mime_type": "image/png", "data": "AAA"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msgs = [m for m in result if m["role"] == "user"]
assert len(user_msgs) == 1
def test_image_error_adds_text_fallback(self, llm):
def bad_storage(path):
raise Exception("storage error")
llm.storage = types.SimpleNamespace(get_file=bad_storage)
msgs = [{"role": "user", "content": "look"}]
attachments = [
{"mime_type": "image/png", "path": "/bad.png", "content": "fb"},
]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
text_parts = [
p for p in user_msg["content"]
if p.get("type") == "text" and "could not" in p.get("text", "").lower()
]
assert len(text_parts) == 1
def test_non_image_attachment_ignored(self, llm):
msgs = [{"role": "user", "content": "look"}]
attachments = [{"mime_type": "application/pdf"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
# content becomes list with just original text
assert isinstance(user_msg["content"], list)
assert len(user_msg["content"]) == 1
def test_content_not_list_becomes_empty(self, llm):
msgs = [{"role": "user", "content": 999}]
attachments = [{"mime_type": "image/png", "data": "AAA"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
assert isinstance(user_msg["content"], list)
# ---------------------------------------------------------------------------
# _get_base64_image
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetBase64Image:
def test_raises_for_no_path(self, llm):
with pytest.raises(ValueError, match="No file path"):
llm._get_base64_image({})
def test_raises_for_file_not_found(self, llm):
import contextlib
@contextlib.contextmanager
def bad_file(path):
raise FileNotFoundError("not found")
llm.storage = types.SimpleNamespace(get_file=bad_file)
with pytest.raises(FileNotFoundError):
llm._get_base64_image({"path": "/nonexistent"})
def test_returns_base64_encoded(self, llm):
import base64
llm.storage = types.SimpleNamespace(
get_file=lambda p: _ctx_manager(b"test_data"),
)
result = llm._get_base64_image({"path": "/tmp/img.png"})
decoded = base64.b64decode(result)
assert decoded == b"test_data"

Some files were not shown because too many files have changed in this diff Show More