Compare commits

...

28 Commits

Author SHA1 Message Date
Alex
126fa01b14 chore: utils tests 2026-03-29 11:49:35 +01:00
Alex
e06debad5f chore: connector tests 2026-03-29 10:32:48 +01:00
Alex
6492852f7d fix: lint on seeder test 2026-03-29 09:48:46 +01:00
Alex
00a621f33a tests: retriever and seeder 2026-03-29 09:46:07 +01:00
Alex
e92ffc6fdc Merge pull request #2337 from arc53/tests-api
chore: api and tool tests
2026-03-28 22:55:10 +00:00
Alex
fe185e5b8d chore: api and tool tests 2026-03-28 21:51:47 +00:00
Alex
9f3d9ab860 docs: agent types 2026-03-28 18:50:50 +00:00
Alex
1c0adde380 chore: docs update 2026-03-28 17:04:06 +00:00
Alex
3c56bd0d0b docs: fix conflicts 2026-03-28 16:16:15 +00:00
Alex
86664ebda2 Merge pull request #2320 from Alex-wuhu/novita-integration
feat: complete Novita AI provider integration
2026-03-28 13:22:02 +00:00
Alex
db18b743d1 fix tests 2 2026-03-28 13:10:41 +00:00
Alex
9e85cc9065 fix: test errors 2026-03-28 13:04:21 +00:00
Alex
aaaa6f002d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-03-28 12:03:27 +00:00
Alex
47dcbcb74b fix: tests and sources on workflow agent 2026-03-28 12:03:16 +00:00
Alex
ddbfd94193 Adjust demo GIF height in README
Update the height of the demo GIF in README.
2026-03-28 11:09:05 +00:00
Alex
8dec60ab8b Update demo GIF in README
Replaced demo video GIF in README with a new example.
2026-03-28 11:08:10 +00:00
Alex
84b2e4bab4 fix: end node multi input 2026-03-28 10:00:01 +00:00
Alex
2afdd7f026 fix: enable tools in workflow agents 2026-03-27 13:43:09 +00:00
Alex
f364475f64 Merge pull request #2335 from arc53/tests-worker
tests: worker coverage
2026-03-27 12:14:56 +00:00
Alex
b254de6ed6 tests: worker coverage 2026-03-27 12:03:55 +00:00
Alex
08dedcaf95 Merge pull request #2334 from arc53/tests-vectors
tests: vectors
2026-03-26 19:11:20 +00:00
Alex
c726eb8ebd fix: ruff 2026-03-26 18:48:53 +00:00
Alex
5f0d39e5f1 tests: vectors 2026-03-26 18:42:59 +00:00
Alex
8c82fc5495 Merge pull request #2333 from arc53/chore/bump-npm-v0.6.3
chore: bump npm libraries to v0.6.3
2026-03-26 14:17:09 +00:00
github-actions[bot]
6d81a15e97 chore: bump npm libraries to v0.6.3 2026-03-26 14:16:32 +00:00
Alex
5478e4234c chore: bump npm again 2026-03-26 14:06:05 +00:00
Alex
4056278fef chore: bump npm 2026-03-26 13:48:05 +00:00
Alex-wuhu
eaf39bb15b feat: add Novita AI as LLM provider
Add Novita AI (https://novita.ai) as a new LLM provider option.
Novita offers OpenAI-compatible API endpoints with competitive pricing.
2026-03-23 10:52:26 +08:00
86 changed files with 16195 additions and 132 deletions

View File

@@ -3,6 +3,14 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=

View File

@@ -29,7 +29,7 @@
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
</div>
<h3 align="left">
<strong>Key Features:</strong>

View File

@@ -185,7 +185,10 @@ class ToolExecutor:
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(tool_data, tool_id, action_name)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
@@ -238,7 +241,10 @@ class ToolExecutor:
return result, call_id
def _get_or_load_tool(self, tool_data: Dict, tool_id: str, action_name: str):
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
@@ -251,8 +257,8 @@ class ToolExecutor:
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": {},
"query_params": {},
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(

View File

@@ -27,6 +27,8 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
@@ -193,6 +195,46 @@ OPENROUTER_MODELS = [
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",

View File

@@ -114,6 +114,10 @@ class ModelRegistry:
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
@@ -245,6 +249,21 @@ class ModelRegistry:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(

View File

@@ -10,6 +10,7 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -5,9 +5,7 @@ from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Settings(BaseSettings):
@@ -15,15 +13,11 @@ class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -45,9 +39,7 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
@@ -55,12 +47,8 @@ class Settings(BaseSettings):
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
@@ -72,7 +60,7 @@ class Settings(BaseSettings):
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -90,16 +78,13 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
NOVITA_API_KEY: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -141,9 +126,7 @@ class Settings(BaseSettings):
# LanceDB vectorstore config
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -180,6 +163,7 @@ class Settings(BaseSettings):
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"NOVITA_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",

View File

@@ -7,6 +7,7 @@ class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
"default": OpenAILLMHandler,
}

View File

@@ -1,13 +1,13 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
NOVITA_BASE_URL = "https://api.novita.ai/openai"
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.API_KEY,
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,

View File

@@ -44,36 +44,40 @@ The main set of instructions or system [prompt](/Guides/Customising-prompts) tha
## Understanding Agent Types
DocsGPT allows for different "types" of agents, each with a distinct way of processing information and generating responses. The code for these agent types can be found in the `application/agents/` directory.
DocsGPT supports several agent types, each with a distinct way of processing information. The code for these can be found in the `application/agents/` directory.
### 1. Classic Agent (`classic_agent.py`)
### 1. Classic Agent
**How it works:** The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach.
1. **Retrieve:** When a query is made, it first searches the selected Source documents for relevant information.
2. **Augment:** This retrieved data is then added to the context, along with the main Prompt and the user's query.
3. **Generate:** The LLM generates a response based on this augmented context. It can also utilize any configured tools if the LLM decides they are necessary.
The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach: it retrieves relevant document chunks, augments the prompt context with them, and generates a response. It can also use configured tools if the LLM decides they are necessary.
**Best for:**
* Direct question-answering over a specific set of documents.
* Tasks where the primary goal is to extract and synthesize information from the provided sources.
* Simpler tool integrations where the decision to use a tool is straightforward.
**Best for:** Direct question-answering over a specific set of documents and straightforward tool use.
### 2. ReAct Agent (`react_agent.py`)
### 2. Agentic Agent
**How it works:** The ReAct Agent employs a more sophisticated "Reason and Act" framework. This involves a multi-step process:
1. **Plan (Thought):** Based on the query, its prompt, and available tools/sources, the LLM first generates a plan or a sequence of thoughts on how to approach the problem. You might see this output as a "thought" process during generation.
2. **Act:** The agent then executes actions based on this plan. This might involve querying its sources, using a tool, or performing internal reasoning.
3. **Observe:** It gathers observations from the results of its actions (e.g., data from a tool, snippets from documents).
4. **Repeat (if necessary):** Steps 2 and 3 can be repeated as the agent refines its approach or gathers more information.
5. **Conclude:** Finally, it generates the final answer based on the initial query and all accumulated observations.
Unlike Classic which pre-fetches documents into the prompt, the Agentic Agent gives the LLM an `internal_search` tool so it can decide **when, what, and whether** to search. This means the LLM controls its own retrieval — it can search multiple times, refine queries, or skip retrieval entirely if the question doesn't need it.
**Best for:**
* More complex tasks that require multi-step reasoning or problem-solving.
* Scenarios where the agent needs to dynamically decide which tools to use and in what order, based on intermediate results.
* Interactive tasks where the agent needs to "think" through a problem.
**Best for:** Tasks where the agent needs to dynamically decide how to gather information, use multiple tools in sequence, or combine retrieval with external tool calls.
### 3. Research Agent
A multi-phase agent designed for in-depth research tasks:
1. **Clarification** — Determines if the question needs clarification before proceeding.
2. **Planning** — Decomposes the question into research steps with adaptive depth based on complexity.
3. **Research** — Executes each step, calling tools and refining queries as needed.
4. **Synthesis** — Compiles findings into a final cited report.
Includes budget controls for max steps, timeout, and token limits to keep research bounded.
**Best for:** Complex questions that require multi-step investigation, gathering information from multiple sources, and producing structured reports with citations.
### 4. Workflow Agent
Executes predefined workflows composed of connected nodes (AI Agent, Set State, Condition). See the [Workflow Nodes](/Agents/nodes) page for details on building workflows.
**Best for:** Structured, multi-step processes with branching logic and shared state between steps.
<Callout type="info">
Developers looking to introduce new agent architectures can explore the `application/agents/` directory. `classic_agent.py` and `react_agent.py` serve as excellent starting points, demonstrating how to inherit from `BaseAgent` and structure agent logic.
The legacy "ReAct" agent type is still accepted for backwards compatibility but maps to the Classic Agent internally. New agents should use Classic, Agentic, or Research instead.
</Callout>
## Navigating and Managing Agents in DocsGPT

View File

@@ -70,9 +70,9 @@ Inside the DocsGPT folder create a `.env` file and copy the contents of `.env_sa
Make sure your `.env` file looks like this:
```
OPENAI_API_KEY=(Your OpenAI API key)
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```
To save the file, press CTRL+X, then Y, and then ENTER.

View File

@@ -104,7 +104,7 @@ DocsGPT can transcribe audio in two places:
- Voice input in the chat.
- Audio file ingestion. Uploaded `.wav`, `.mp3`, `.m4a`, `.ogg`, and `.webm` files are transcribed first and then passed through the normal parser, chunking, embedding, and indexing pipeline.
For an end-to-end walkthrough, see the [Speech and Audio Guide](/Guides/speech-and-audio).
The settings below control speech-to-text behaviour for both voice input and audio file ingestion.
| Setting | Purpose | Typical values |
| --- | --- | --- |
@@ -214,6 +214,31 @@ If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt
}}
/>
## S3 Storage Backend
By default DocsGPT stores files locally. Set `STORAGE_TYPE=s3` to use Amazon S3 instead.
| Setting | Description | Default |
| --- | --- | --- |
| `STORAGE_TYPE` | `local` or `s3` | `local` |
| `S3_BUCKET_NAME` | S3 bucket name | `docsgpt-test-bucket` |
| `SAGEMAKER_ACCESS_KEY` | AWS access key ID | — |
| `SAGEMAKER_SECRET_KEY` | AWS secret access key | — |
| `SAGEMAKER_REGION` | AWS region | — |
| `URL_STRATEGY` | `backend` (proxy through API) or `s3` (direct S3 URLs) | `backend` |
The S3 credentials use `SAGEMAKER_*` variable names because they are shared with the SageMaker integration.
```env
STORAGE_TYPE=s3
S3_BUCKET_NAME=your-bucket-name
SAGEMAKER_ACCESS_KEY=your-aws-access-key-id
SAGEMAKER_SECRET_KEY=your-aws-secret-access-key
SAGEMAKER_REGION=us-east-1
```
Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.
## Exploring More Settings
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:

View File

@@ -86,13 +86,9 @@ Make sure your `.env` file looks like this:
```
OPENAI_API_KEY=(Your OpenAI API key)
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```

View File

@@ -11,18 +11,18 @@ DocsGPT API keys are essential for developers and users who wish to integrate th
After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the Agents option, and press 'Create New' to generate a new agent (which includes an API key).
- **API Call:** Alternatively, you can use the `/api/create_agent` endpoint to create a new agent. An API key is automatically generated for each agent. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
## Understanding Key Variables
Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
Upon creating your agent, you will encounter several key variables. Each serves a specific purpose:
- **Name:** Assign a name to your API key for easy identification.
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
- **Name:** Assign a name to your agent for easy identification.
- **Source:** Indicates the source document(s) linked to your agent, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your agent. You can view this by making a call to `/api/get_agents`.
- **Key:** The API key for the agent, which will be used in your application to authenticate API requests.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the agent, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
Congratulations on taking the first step towards enhancing your applications with DocsGPT!

View File

@@ -64,7 +64,7 @@ flowchart LR
* **Technology:** Supports multiple vector databases.
* **Responsibility:** Vector Stores are used to store and retrieve vector embeddings of document chunks. This enables semantic search and retrieval of relevant document snippets in response to user queries.
* **Key Features:**
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, and LanceDB.
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, MongoDB Atlas Vector Search, and pgvector.
* Provides storage and indexing of high-dimensional vector embeddings.
* Enables editing and updating of vector indexes including specific chunks.

View File

@@ -16,7 +16,7 @@ Training on other documentation sources can greatly enhance the versatility and
Make sure you have the document you want to train on ready on the device you are using. You can also use links to the documentation to train on.
<Callout type="warning" emoji="⚠️">
Note: The document should be either of the given file formats .pdf, .txt, .rst, .docx, .md, .zip and limited to 25mb.You can also train using the link of the documentation.
Note: Supported file formats include .pdf, .txt, .rst, .docx, .md, .mdx, .csv, .epub, .html, .json, .xlsx, .pptx, .png, .jpg, .jpeg, and audio files (.wav, .mp3, .m4a, .ogg, .webm). You can also train using the link of the documentation.
</Callout>

View File

@@ -35,8 +35,34 @@ Choose the LLM of your choice.
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
### Step 2
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
</Steps>
## Fallback LLM
DocsGPT can automatically switch to a fallback LLM when the primary model fails, including mid-stream. This works with both streaming and non-streaming requests.
**Fallback order:**
1. Per-agent backup models (other models configured on the same agent)
2. Global fallback (`FALLBACK_LLM_*` env vars below)
3. Error returned if all fail
| Setting | Description | Default |
| --- | --- | --- |
| `FALLBACK_LLM_PROVIDER` | Provider name (e.g., `openai`, `anthropic`, `google`) | — |
| `FALLBACK_LLM_NAME` | Model name (e.g., `gpt-4o`, `claude-sonnet-4-20250514`) | — |
| `FALLBACK_LLM_API_KEY` | API key for the fallback provider | Falls back to `API_KEY` |
All three (`FALLBACK_LLM_PROVIDER`, `FALLBACK_LLM_NAME`, and an API key) must resolve for the global fallback to activate.
```env
FALLBACK_LLM_PROVIDER=anthropic
FALLBACK_LLM_NAME=claude-sonnet-4-20250514
FALLBACK_LLM_API_KEY=sk-ant-your-anthropic-key
```
<Callout type="info">
For maximum resilience, use a fallback provider from a different cloud than your primary. Each agent can also have multiple models configured — the other models are tried first before the global fallback.
</Callout>

View File

@@ -2,5 +2,13 @@ export default {
"google-drive-connector": {
"title": "🔗 Google Drive",
"href": "/Guides/Integrations/google-drive-connector"
},
"sharepoint-connector": {
"title": "🔗 SharePoint / OneDrive",
"href": "/Guides/Integrations/sharepoint-connector"
},
"mcp-tool-integration": {
"title": "🔗 MCP Tools",
"href": "/Guides/Integrations/mcp-tool-integration"
}
}

View File

@@ -0,0 +1,66 @@
---
title: MCP Tool Integration
description: Connect external tools to DocsGPT agents using the Model Context Protocol (MCP) standard.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# MCP Tool Integration
The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) integration lets you connect external tool servers to DocsGPT. Your agents can then discover and call tools provided by those servers during conversations — for example, querying a CRM, running code, or accessing a database.
## Setup
<Steps>
### Step 1: Configure Environment Variables (Optional)
Only needed if your MCP servers use OAuth authentication:
```env
MCP_OAUTH_REDIRECT_URI=https://yourdomain.com/api/mcp_server/callback
```
If not set, falls back to `API_URL/api/mcp_server/callback`.
### Step 2: Add an MCP Server
Go to **Settings** > **Tools** > **Add Tool** > **MCP Server**. Enter the server URL, select an auth type, and click **Test Connection** to verify, then **Save**.
### Step 3: Enable for Your Agent
In your agent configuration, enable the MCP tools you want the agent to use.
</Steps>
## Authentication Types
| Auth Type | Config Fields |
|-----------|---------------|
| **None** | — |
| **Bearer** | `bearer_token` |
| **API Key** | `api_key`, `api_key_header` (default: `X-API-Key`) |
| **Basic** | `username`, `password` |
| **OAuth** | `oauth_scopes` (optional) |
<Callout type="warning">
For OAuth in production, `MCP_OAUTH_REDIRECT_URI` must be a publicly accessible URL pointing to your DocsGPT backend.
</Callout>
## API Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/api/mcp_server/test` | POST | Test a connection without saving |
| `/api/mcp_server/save` | POST | Save or update a server configuration |
| `/api/mcp_server/callback` | GET | OAuth callback handler |
| `/api/mcp_server/oauth_status/<task_id>` | GET | Poll OAuth flow status |
| `/api/mcp_server/auth_status` | GET | Batch check auth status for all MCP tools |
## Troubleshooting
- **Connection refused** — Verify the URL and that the server is reachable from your backend.
- **403 Forbidden** — Check credentials and permissions.
- **Timed out** — Default is 30s; increase timeout in tool config (max 300s).
- **OAuth "needs_auth" persists** — Verify `MCP_OAUTH_REDIRECT_URI` is correct and Redis is running.

View File

@@ -0,0 +1,63 @@
---
title: SharePoint / OneDrive Connector
description: Connect your Microsoft SharePoint or OneDrive as an external knowledge base to upload and process files directly.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# SharePoint / OneDrive Connector
Connect your SharePoint or OneDrive account to upload and process files directly as an external knowledge base. Supports Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Microsoft Entra ID (Azure AD) with automatic token refresh.
## Setup
<Steps>
### Step 1: Create an App Registration in Azure
1. Go to the [Azure Portal](https://portal.azure.com/) > **Microsoft Entra ID** > **App registrations** > **New registration**
2. Set **Redirect URI** (Web) to:
- Local: `http://localhost:7091/api/connectors/callback?provider=share_point`
- Production: `https://yourdomain.com/api/connectors/callback?provider=share_point`
### Step 2: Configure API Permissions
In your App Registration, go to **API permissions** > **Add a permission** > **Microsoft Graph** > **Delegated permissions** and add: `Files.Read`, `Files.Read.All`, `Sites.Read.All`. Grant admin consent if possible.
### Step 3: Create a Client Secret
Go to **Certificates & secrets** > **New client secret**. Copy the secret value immediately (it won't be shown again).
### Step 4: Configure Environment Variables
Add to your `.env` file:
```env
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
```
| Variable | Description | Required | Default |
|----------|-------------|----------|---------|
| `MICROSOFT_CLIENT_ID` | Application (client) ID from App Registration overview | Yes | — |
| `MICROSOFT_CLIENT_SECRET` | Client secret value | Yes | — |
| `MICROSOFT_TENANT_ID` | Directory (tenant) ID | No | `common` |
| `MICROSOFT_AUTHORITY` | Login endpoint override | No | Auto-constructed |
<Callout type="warning">
`MICROSOFT_TENANT_ID=common` (the default) allows any Microsoft account to authenticate. Set this to your specific tenant ID in production.
</Callout>
### Step 5: Restart and Use
Restart your application, then go to the upload section in DocsGPT and select **SharePoint / OneDrive** as the source. You'll be redirected to Microsoft to sign in, then can browse and select files to process.
</Steps>
## Troubleshooting
- **Option not appearing** — Verify `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` are set, then restart.
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=share_point`.
- **Permission denied** — Ensure admin consent is granted and the user has access to the target files.

View File

@@ -7,20 +7,10 @@ description:
If your AI uses external knowledge and is not explicit enough, that is okay, because we try to make DocsGPT friendly.
But if you want to adjust it, here is a simple way:-
- Got to `application/prompts/chat_combine_prompt.txt`
- And change it to
But if you want to adjust it, prompts are now managed through the UI and API using a template-based system. See the [Customising Prompts](/Guides/Customising-prompts) guide for details.
To make the AI stricter about staying on-topic, edit your active prompt template (via **Sidebar → Settings → Active Prompt**) to include instructions like:
```
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples, if possible.
Write an answer for the question below based on the provided context.
If the context provides insufficient information, reply "I cannot answer".
You have access to chat history and can use it to help answer the question.
----------------
{summaries}
```

View File

@@ -29,7 +29,7 @@ export default {
"title": "OCR",
"href": "/Guides/ocr"
},
"Integrations": {
"Integrations": {
"title": "🔗 Integrations"
}
}

View File

@@ -70,7 +70,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
To stop DocsGPT, simply open a new terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
docker compose -f deployment/docker-compose-hub.yaml down
```
(or the specific `docker compose` command shown at the end of the `setup.sh` execution, which may include optional compose files depending on your choices).

View File

@@ -1,12 +1,12 @@
{
"name": "docsgpt",
"version": "0.5.1",
"version": "0.6.3",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "docsgpt",
"version": "0.5.1",
"version": "0.6.3",
"license": "Apache-2.0",
"dependencies": {
"@babel/plugin-transform-flow-strip-types": "^7.23.3",

View File

@@ -1,6 +1,6 @@
{
"name": "docsgpt",
"version": "0.6.0",
"version": "0.6.3",
"private": false,
"description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.",
"source": "./src/index.html",

View File

@@ -1,6 +1,6 @@
aiohttp>=3,<4
certifi==2024.7.4
h11==0.16.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7

View File

@@ -40,9 +40,12 @@ import {
} from '@/components/ui/select';
import { Sheet, SheetContent } from '@/components/ui/sheet';
import { useSelector } from 'react-redux';
import modelService from '../api/services/modelService';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import { selectToken } from '../preferences/preferenceSlice';
import { WorkflowNode } from './types/workflow';
import {
AgentNode,
@@ -77,6 +80,7 @@ interface UserTool {
function WorkflowBuilderInner() {
const navigate = useNavigate();
const token = useSelector(selectToken);
const { agentId } = useParams<{ agentId?: string }>();
const [searchParams] = useSearchParams();
const folderId = searchParams.get('folder_id');
@@ -304,7 +308,7 @@ function WorkflowBuilderInner() {
setAvailableModels(modelService.transformModels(modelsData.models));
}
const toolsResponse = await userService.getUserTools(null);
const toolsResponse = await userService.getUserTools(token);
if (toolsResponse.ok) {
const toolsData = await toolsResponse.json();
setAvailableTools(toolsData.tools);

View File

@@ -54,7 +54,10 @@ import { FileUpload } from '../../components/FileUpload';
import AgentDetailsModal from '../../modals/AgentDetailsModal';
import ConfirmationModal from '../../modals/ConfirmationModal';
import { ActiveState } from '../../models/misc';
import { selectToken } from '../../preferences/preferenceSlice';
import {
selectSourceDocs,
selectToken,
} from '../../preferences/preferenceSlice';
import { getToolDisplayName } from '../../utils/toolUtils';
import { Agent } from '../types';
import { ConditionCase, WorkflowNode } from '../types/workflow';
@@ -300,6 +303,7 @@ function createWorkflowPayload(
function WorkflowBuilderInner() {
const navigate = useNavigate();
const token = useSelector(selectToken);
const sourceDocs = useSelector(selectSourceDocs);
const { agentId } = useParams<{ agentId?: string }>();
const [searchParams] = useSearchParams();
const folderId = searchParams.get('folder_id');
@@ -341,6 +345,14 @@ function WorkflowBuilderInner() {
const [availableModels, setAvailableModels] = useState<Model[]>([]);
const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
const sourceOptions = useMemo(
() =>
(sourceDocs ?? []).map((doc) => ({
value: doc.id ?? 'default',
label: doc.name,
})),
[sourceDocs],
);
const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
Record<string, string>
>({});
@@ -387,31 +399,39 @@ function WorkflowBuilderInner() {
[],
);
const onConnect = useCallback((params: Connection) => {
setEdges((eds) => {
const exists = eds.some(
(e) =>
e.source === params.source &&
e.sourceHandle === params.sourceHandle &&
e.target === params.target &&
e.targetHandle === params.targetHandle,
);
if (exists) return eds;
const filtered = eds.filter(
(e) =>
!(
const onConnect = useCallback(
(params: Connection) => {
setEdges((eds) => {
const exists = eds.some(
(e) =>
e.source === params.source &&
e.sourceHandle === (params.sourceHandle ?? null)
) &&
!(
e.sourceHandle === params.sourceHandle &&
e.target === params.target &&
e.targetHandle === (params.targetHandle ?? null)
),
);
return addEdge(params, filtered);
});
}, []);
e.targetHandle === params.targetHandle,
);
if (exists) return eds;
const targetNode = nodes.find((n) => n.id === params.target);
const isEndNode = targetNode?.type === 'end';
const filtered = eds.filter(
(e) =>
!(
e.source === params.source &&
e.sourceHandle === (params.sourceHandle ?? null)
) &&
// End nodes accept multiple incoming edges
(isEndNode ||
!(
e.target === params.target &&
e.targetHandle === (params.targetHandle ?? null)
)),
);
return addEdge(params, filtered);
});
},
[nodes],
);
const onEdgeClick = useCallback((_event: React.MouseEvent, edge: Edge) => {
setEdges((eds) => eds.filter((e) => e.id !== edge.id));
@@ -701,7 +721,7 @@ function WorkflowBuilderInner() {
setDefaultAgentModelId(preferredDefaultModel);
}
const toolsResponse = await userService.getUserTools(null);
const toolsResponse = await userService.getUserTools(token);
if (toolsResponse.ok) {
const toolsData = await toolsResponse.json();
setAvailableTools(toolsData.tools);
@@ -1271,8 +1291,8 @@ function WorkflowBuilderInner() {
const handlePrimaryAction = useCallback(() => {
if (isPrimaryActionDisabled) return;
void persistWorkflow(!canManageAgent);
}, [isPrimaryActionDisabled, persistWorkflow, canManageAgent]);
void persistWorkflow(false);
}, [isPrimaryActionDisabled, persistWorkflow]);
const agentForDetails = useMemo<Agent>(
() => ({
@@ -1910,6 +1930,28 @@ function WorkflowBuilderInner() {
emptyText="No tools available"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
Sources
</label>
<MultiSelect
options={sourceOptions}
selected={
selectedNode.data.config?.sources || []
}
onChange={(newSources) =>
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
sources: newSources,
},
})
}
placeholder="Select sources..."
searchPlaceholder="Search sources..."
emptyText="No sources available"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
Structured Output (JSON Schema)

View File

@@ -70,12 +70,12 @@ export function MultiSelect({
role="combobox"
aria-expanded={open}
className={cn(
'w-full justify-between border-[#E5E5E5] bg-white hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
'h-auto min-h-[2.5rem] w-full justify-between border-[#E5E5E5] bg-white py-1.5 hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
!selected.length && 'text-gray-500 dark:text-gray-400',
className,
)}
>
<div className="flex flex-wrap gap-1">
<div className="flex min-w-0 flex-wrap gap-1">
{selected.length === 0 ? (
placeholder
) : (
@@ -85,9 +85,9 @@ export function MultiSelect({
return (
<span
key={option?.value || label}
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex max-w-[calc(100%-1rem)] items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
>
{label}
<span className="truncate">{label}</span>
<span
role="button"
tabIndex={0}

View File

@@ -3,7 +3,7 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
addopts =
-v
--strict-markers
--tb=short
@@ -11,6 +11,7 @@ addopts =
--cov-report=html
--cov-report=term-missing
--cov-report=xml
--ignore=tests/integration
markers =
unit: Unit tests
integration: Integration tests

View File

@@ -977,8 +977,8 @@ function Connect-CloudAPIProvider {
}
"7" { # Novita
$script:provider_name = "Novita"
$script:llm_name = "novita"
$script:model_name = "deepseek/deepseek-r1"
$script:llm_provider = "novita"
$script:model_name = "moonshotai/kimi-k2.5"
Get-APIKey
break
}

View File

@@ -704,7 +704,7 @@ connect_cloud_api_provider() {
7) # Novita
provider_name="Novita"
llm_provider="novita"
model_name="deepseek/deepseek-r1"
model_name="moonshotai/kimi-k2.5"
get_api_key
break ;;
b|B) clear; return 1 ;; # Clear screen and Back to Main Menu

View File

@@ -0,0 +1,202 @@
"""Tests for application/agents/tools/api_body_serializer.py"""
import json
import pytest
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
@pytest.mark.unit
class TestContentTypeEnum:
    """Each ContentType member compares equal to its MIME-type string."""

    def test_json_value(self):
        expected = "application/json"
        assert ContentType.JSON == expected
    def test_form_urlencoded_value(self):
        expected = "application/x-www-form-urlencoded"
        assert ContentType.FORM_URLENCODED == expected
    def test_multipart_value(self):
        expected = "multipart/form-data"
        assert ContentType.MULTIPART_FORM_DATA == expected
    def test_text_plain_value(self):
        expected = "text/plain"
        assert ContentType.TEXT_PLAIN == expected
    def test_xml_value(self):
        expected = "application/xml"
        assert ContentType.XML == expected
    def test_octet_stream_value(self):
        expected = "application/octet-stream"
        assert ContentType.OCTET_STREAM == expected
@pytest.mark.unit
class TestSerializeJson:
    """JSON serialization: round-trips dicts, handles empty/None bodies,
    and falls back to JSON for unknown or parameterized content types."""

    def test_basic_json(self):
        payload, hdrs = RequestBodySerializer.serialize(
            {"key": "value"}, ContentType.JSON
        )
        assert json.loads(payload) == {"key": "value"}
        assert hdrs["Content-Type"] == "application/json"
    def test_nested_json(self):
        original = {"user": {"name": "Alice", "age": 30}}
        payload, hdrs = RequestBodySerializer.serialize(original, ContentType.JSON)
        assert json.loads(payload) == original
    def test_empty_body_returns_none(self):
        payload, hdrs = RequestBodySerializer.serialize({}, ContentType.JSON)
        assert payload is None
        assert hdrs == {}
    def test_none_body(self):
        payload, _ = RequestBodySerializer.serialize(None, ContentType.JSON)
        assert payload is None
    def test_unknown_content_type_falls_back_to_json(self):
        payload, _ = RequestBodySerializer.serialize(
            {"k": "v"}, "application/vnd.custom+json"
        )
        assert json.loads(payload) == {"k": "v"}
    def test_content_type_with_charset_suffix(self):
        payload, _ = RequestBodySerializer.serialize(
            {"k": "v"}, "application/json; charset=utf-8"
        )
        assert json.loads(payload) == {"k": "v"}
@pytest.mark.unit
class TestSerializeFormUrlencoded:
    """Form-urlencoded serialization: key=value pairs, None skipping,
    and list/dict handling driven by per-field ``encoding_rules``."""

    def test_basic_form(self):
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
        )
        assert "name=Alice" in body
        assert "age=30" in body
        assert headers["Content-Type"] == "application/x-www-form-urlencoded"
    def test_none_values_skipped(self):
        # Fields whose value is None must be dropped from the encoded body.
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
        )
        assert "name=Alice" in body
        assert "skip" not in body
    def test_list_explode_true(self):
        # explode=True: one key=value pair per list element.
        body, headers = RequestBodySerializer.serialize(
            {"tags": ["a", "b"]},
            ContentType.FORM_URLENCODED,
            encoding_rules={"tags": {"style": "form", "explode": True}},
        )
        assert "tags=a" in body
        assert "tags=b" in body
    def test_list_explode_false(self):
        # explode=False: elements collapse into a single value for the key.
        body, headers = RequestBodySerializer.serialize(
            {"tags": ["a", "b"]},
            ContentType.FORM_URLENCODED,
            encoding_rules={"tags": {"style": "form", "explode": False}},
        )
        # Value is percent-encoded by _serialize_form_value then urlencoded again
        assert "tags=" in body
        assert "a" in body and "b" in body
    def test_dict_value_json_content_type(self):
        # A nested dict with contentType=application/json is embedded as a value.
        body, headers = RequestBodySerializer.serialize(
            {"metadata": {"key": "val"}},
            ContentType.FORM_URLENCODED,
            encoding_rules={"metadata": {"contentType": "application/json"}},
        )
        assert "metadata" in body
@pytest.mark.unit
class TestSerializeTextPlain:
    """text/plain serialization: bare value for one field, 'key: value'
    lines for several."""

    def test_single_value(self):
        payload, hdrs = RequestBodySerializer.serialize(
            {"message": "hello"}, ContentType.TEXT_PLAIN
        )
        assert payload == "hello"
        assert hdrs["Content-Type"] == "text/plain"
    def test_multiple_values(self):
        payload, _ = RequestBodySerializer.serialize(
            {"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
        )
        for fragment in ("name: Alice", "age: 30"):
            assert fragment in payload
@pytest.mark.unit
class TestSerializeXml:
    """XML serialization: declaration header, nested element tags, and
    escaping of markup-significant characters."""

    def test_basic_xml(self):
        payload, hdrs = RequestBodySerializer.serialize(
            {"name": "Alice"}, ContentType.XML
        )
        assert hdrs["Content-Type"] == "application/xml"
        assert '<?xml version="1.0"' in payload
        assert "<name>Alice</name>" in payload
    def test_nested_xml(self):
        payload, _ = RequestBodySerializer.serialize(
            {"user": {"name": "Alice"}}, ContentType.XML
        )
        for fragment in ("<user>", "<name>Alice</name>"):
            assert fragment in payload
    def test_xml_escapes_special_chars(self):
        payload, _ = RequestBodySerializer.serialize(
            {"data": "<script>alert('xss')</script>"}, ContentType.XML
        )
        assert "&lt;script&gt;" in payload
@pytest.mark.unit
class TestSerializeOctetStream:
    """Octet-stream serialization yields raw bytes with the binary MIME type."""

    def test_dict_body(self):
        payload, hdrs = RequestBodySerializer.serialize(
            {"key": "val"}, ContentType.OCTET_STREAM
        )
        assert hdrs["Content-Type"] == "application/octet-stream"
        assert isinstance(payload, bytes)
@pytest.mark.unit
class TestSerializeMultipartFormData:
    """Multipart serialization: bytes body, boundary in the Content-Type
    header, and None-valued fields omitted."""

    def test_basic_multipart(self):
        body, headers = RequestBodySerializer.serialize(
            {"field": "value"}, ContentType.MULTIPART_FORM_DATA
        )
        assert isinstance(body, bytes)
        assert "multipart/form-data" in headers["Content-Type"]
        # The generated boundary must be advertised in the header.
        assert "boundary=" in headers["Content-Type"]
    def test_none_values_skipped(self):
        body, headers = RequestBodySerializer.serialize(
            {"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
        )
        body_str = body.decode("utf-8", errors="replace")
        assert "field" in body_str
        assert "empty" not in body_str
@pytest.mark.unit
class TestHelpers:
    """Private helpers: percent-encoding, XML entity escaping, and
    list-to-<item> expansion in _dict_to_xml."""

    def test_percent_encode(self):
        assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
        assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
        # Characters listed in safe_chars are left un-encoded.
        assert RequestBodySerializer._percent_encode("safe", safe_chars="/") == "safe"
    def test_escape_xml(self):
        assert "&amp;" in RequestBodySerializer._escape_xml("&")
        assert "&lt;" in RequestBodySerializer._escape_xml("<")
        assert "&gt;" in RequestBodySerializer._escape_xml(">")
        assert "&quot;" in RequestBodySerializer._escape_xml('"')
        assert "&apos;" in RequestBodySerializer._escape_xml("'")
    def test_dict_to_xml_list(self):
        # List values become repeated <item> elements.
        xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
        assert "<item>1</item>" in xml
        assert "<item>2</item>" in xml

View File

@@ -0,0 +1,282 @@
"""Tests for application/agents/tools/api_tool.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.agents.tools.api_tool import APITool
@pytest.fixture
def tool():
    """GET-configured APITool pointed at a dummy endpoint."""
    config = {
        "url": "https://api.example.com/data",
        "method": "GET",
        "headers": {"Accept": "application/json"},
        "query_params": {},
    }
    return APITool(config=config)
@pytest.fixture
def post_tool():
    """POST-configured APITool pointed at a dummy endpoint."""
    config = {
        "url": "https://api.example.com/items",
        "method": "POST",
        "headers": {},
        "query_params": {},
    }
    return APITool(config=config)
@pytest.mark.unit
class TestAPIToolInit:
    """An empty config yields the documented defaults on the instance."""

    def test_default_values(self):
        instance = APITool(config={})
        defaults = {
            "url": "",
            "method": "GET",
            "headers": {},
            "query_params": {},
        }
        for attr, expected in defaults.items():
            assert getattr(instance, attr) == expected
@pytest.mark.unit
class TestMakeApiCall:
    """execute_action across all HTTP verbs and failure modes.

    Every test patches ``validate_url`` (so no real SSRF/network check runs)
    and the matching ``requests.<verb>`` function. @patch decorators apply
    bottom-up, so the innermost patch binds to the first mock argument.
    """

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_successful_get(self, mock_get, mock_validate, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"result": "ok"}
        mock_resp.content = b'{"result":"ok"}'
        mock_get.return_value = mock_resp
        result = tool.execute_action("any_action")
        assert result["status_code"] == 200
        assert result["data"] == {"result": "ok"}
        assert result["message"] == "API call successful."
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.post")
    def test_successful_post(self, mock_post, mock_validate, post_tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 201
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"id": 1}
        mock_resp.content = b'{"id":1}'
        mock_post.return_value = mock_resp
        result = post_tool.execute_action("create", name="test")
        assert result["status_code"] == 201
    @patch("application.agents.tools.api_tool.validate_url")
    def test_ssrf_blocked(self, mock_validate, tool):
        # When URL validation raises SSRFError, no request is made and the
        # result carries a validation-error message with no status code.
        from application.core.url_validation import SSRFError
        mock_validate.side_effect = SSRFError("blocked")
        result = tool.execute_action("any")
        assert result["status_code"] is None
        assert "URL validation error" in result["message"]
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_timeout_error(self, mock_get, mock_validate, tool):
        mock_get.side_effect = requests.exceptions.Timeout()
        result = tool.execute_action("any")
        assert result["status_code"] is None
        assert "timeout" in result["message"].lower()
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_connection_error(self, mock_get, mock_validate, tool):
        mock_get.side_effect = requests.exceptions.ConnectionError("refused")
        result = tool.execute_action("any")
        assert result["status_code"] is None
        assert "Connection error" in result["message"]
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_http_error(self, mock_get, mock_validate, tool):
        # A 404 with a non-JSON body: the tool reports the status code and
        # an "HTTP Error" message instead of raising.
        mock_resp = MagicMock()
        mock_resp.status_code = 404
        mock_resp.text = "Not Found"
        mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
        mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
            response=mock_resp
        )
        mock_get.return_value = mock_resp
        result = tool.execute_action("any")
        assert result["status_code"] == 404
        assert "HTTP Error" in result["message"]
    @patch("application.agents.tools.api_tool.validate_url")
    def test_unsupported_method(self, mock_validate):
        # A verb outside the supported set is rejected without a request.
        tool = APITool(
            config={"url": "https://example.com", "method": "CUSTOM"}
        )
        result = tool.execute_action("any")
        assert result["status_code"] is None
        assert "Unsupported" in result["message"]
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.put")
    def test_put_method(self, mock_put, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {}
        mock_resp.content = b'{}'
        mock_put.return_value = mock_resp
        result = tool.execute_action("update", name="new")
        assert result["status_code"] == 200
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.delete")
    def test_delete_method(self, mock_delete, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
        mock_resp = MagicMock()
        mock_resp.status_code = 204
        mock_resp.headers = {"Content-Type": "text/plain"}
        mock_resp.content = b''
        mock_delete.return_value = mock_resp
        result = tool.execute_action("delete")
        assert result["status_code"] == 204
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.patch")
    def test_patch_method(self, mock_patch, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"patched": True}
        mock_resp.content = b'{"patched":true}'
        mock_patch.return_value = mock_resp
        result = tool.execute_action("patch", field="val")
        assert result["status_code"] == 200
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.head")
    def test_head_method(self, mock_head, mock_validate):
        tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "text/html"}
        mock_resp.content = b''
        mock_head.return_value = mock_resp
        result = tool.execute_action("check")
        assert result["status_code"] == 200
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.options")
    def test_options_method(self, mock_options, mock_validate):
        tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "text/plain"}
        mock_resp.content = b''
        mock_options.return_value = mock_resp
        result = tool.execute_action("options")
        assert result["status_code"] == 200
@pytest.mark.unit
class TestPathParamSubstitution:
    """Query params matching {placeholders} in the URL are substituted into
    the request path before the call is made."""

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_path_params_substituted(self, mock_get, mock_validate):
        tool = APITool(
            config={
                "url": "https://api.example.com/users/{user_id}/posts/{post_id}",
                "method": "GET",
                "query_params": {"user_id": "42", "post_id": "7"},
            }
        )
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = []
        mock_resp.content = b'[]'
        mock_get.return_value = mock_resp
        tool.execute_action("get")
        # Inspect the URL requests.get was actually called with.
        called_url = mock_get.call_args[0][0]
        assert "/users/42/posts/7" in called_url
        assert "{user_id}" not in called_url
@pytest.mark.unit
class TestParseResponse:
    """_parse_response dispatches on the Content-Type header: JSON becomes a
    dict, text/XML/HTML stay as text, and an empty body yields None."""

    def test_json_response(self, tool):
        mock_resp = MagicMock()
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"key": "val"}
        mock_resp.content = b'{"key":"val"}'
        result = tool._parse_response(mock_resp)
        assert result == {"key": "val"}
    def test_text_response(self, tool):
        mock_resp = MagicMock()
        mock_resp.headers = {"Content-Type": "text/plain"}
        mock_resp.text = "plain text"
        mock_resp.content = b"plain text"
        result = tool._parse_response(mock_resp)
        assert result == "plain text"
    def test_xml_response(self, tool):
        mock_resp = MagicMock()
        mock_resp.headers = {"Content-Type": "application/xml"}
        mock_resp.text = "<root><item>1</item></root>"
        mock_resp.content = b"<root><item>1</item></root>"
        result = tool._parse_response(mock_resp)
        assert "<root>" in result
    def test_empty_content(self, tool):
        # Zero-length content short-circuits to None regardless of type.
        mock_resp = MagicMock()
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.content = b""
        result = tool._parse_response(mock_resp)
        assert result is None
    def test_html_response(self, tool):
        mock_resp = MagicMock()
        mock_resp.headers = {"Content-Type": "text/html"}
        mock_resp.text = "<html><body>Hi</body></html>"
        mock_resp.content = b"<html><body>Hi</body></html>"
        result = tool._parse_response(mock_resp)
        assert "<html>" in result
@pytest.mark.unit
class TestAPIToolMetadata:
    """APITool advertises no actions metadata and no config requirements."""

    def test_actions_metadata_empty(self, tool):
        metadata = tool.get_actions_metadata()
        assert metadata == []
    def test_config_requirements_empty(self, tool):
        requirements = tool.get_config_requirements()
        assert requirements == {}

View File

@@ -0,0 +1,132 @@
"""Tests for application/agents/tools/brave.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.brave import BraveSearchTool
@pytest.fixture
def tool():
    """BraveSearchTool configured with a dummy API token."""
    config = {"token": "test_api_key"}
    return BraveSearchTool(config=config)
@pytest.mark.unit
class TestBraveExecuteAction:
    """execute_action for Brave web/image search: success and failure paths,
    count/offset clamping, and pass-through of optional query params.
    requests.get is patched so no network traffic occurs."""

    def test_unknown_action_raises(self, tool):
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid")
    @patch("application.agents.tools.brave.requests.get")
    def test_web_search_success(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"web": {"results": [{"title": "Result"}]}}
        mock_get.return_value = mock_resp
        result = tool.execute_action("brave_web_search", query="python")
        assert result["status_code"] == 200
        assert "results" in result
        assert "successfully" in result["message"]
        # The configured token must be sent as the subscription header.
        call_kwargs = mock_get.call_args
        assert call_kwargs[1]["headers"]["X-Subscription-Token"] == "test_api_key"
    @patch("application.agents.tools.brave.requests.get")
    def test_web_search_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 429
        mock_get.return_value = mock_resp
        result = tool.execute_action("brave_web_search", query="test")
        assert result["status_code"] == 429
        assert "failed" in result["message"].lower()
    @patch("application.agents.tools.brave.requests.get")
    def test_image_search_success(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"results": [{"url": "https://img.com/1.jpg"}]}
        mock_get.return_value = mock_resp
        result = tool.execute_action("brave_image_search", query="cats")
        assert result["status_code"] == 200
        assert "results" in result
    @patch("application.agents.tools.brave.requests.get")
    def test_image_search_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_get.return_value = mock_resp
        result = tool.execute_action("brave_image_search", query="cats")
        assert result["status_code"] == 500
    @patch("application.agents.tools.brave.requests.get")
    def test_count_capped_at_20(self, mock_get, tool):
        # Web search clamps count to the API maximum of 20.
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp
        tool.execute_action("brave_web_search", query="test", count=100)
        params = mock_get.call_args[1]["params"]
        assert params["count"] == 20
    @patch("application.agents.tools.brave.requests.get")
    def test_image_count_capped_at_100(self, mock_get, tool):
        # Image search allows a larger cap of 100.
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp
        tool.execute_action("brave_image_search", query="test", count=500)
        params = mock_get.call_args[1]["params"]
        assert params["count"] == 100
    @patch("application.agents.tools.brave.requests.get")
    def test_freshness_param(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp
        tool.execute_action("brave_web_search", query="news", freshness="pd")
        params = mock_get.call_args[1]["params"]
        assert params["freshness"] == "pd"
    @patch("application.agents.tools.brave.requests.get")
    def test_offset_capped(self, mock_get, tool):
        # offset is clamped to 9 (the maximum results page).
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp
        tool.execute_action("brave_web_search", query="test", offset=100)
        params = mock_get.call_args[1]["params"]
        assert params["offset"] == 9
@pytest.mark.unit
class TestBraveMetadata:
    """Metadata contract: two named actions and a required secret token."""

    def test_actions_metadata(self, tool):
        actions = tool.get_actions_metadata()
        assert len(actions) == 2
        advertised = {entry["name"] for entry in actions}
        assert {"brave_web_search", "brave_image_search"} <= advertised
    def test_config_requirements(self, tool):
        requirements = tool.get_config_requirements()
        assert "token" in requirements
        token_spec = requirements["token"]
        assert token_spec["secret"] is True
        assert token_spec["required"] is True

View File

@@ -0,0 +1,169 @@
"""Tests for application/agents/workflows/cel_evaluator.py"""
import pytest
from application.agents.workflows.cel_evaluator import (
CelEvaluationError,
_convert_value,
build_activation,
cel_to_python,
evaluate_cel,
)
import celpy.celtypes
class TestConvertValue:
    """_convert_value maps Python primitives onto celpy CEL types; None
    becomes BoolType(False) and unknown objects are stringified."""

    @pytest.mark.unit
    def test_bool_true(self):
        result = _convert_value(True)
        assert isinstance(result, celpy.celtypes.BoolType)
        assert bool(result) is True
    @pytest.mark.unit
    def test_bool_false(self):
        result = _convert_value(False)
        assert isinstance(result, celpy.celtypes.BoolType)
        assert bool(result) is False
    @pytest.mark.unit
    def test_int(self):
        result = _convert_value(42)
        assert isinstance(result, celpy.celtypes.IntType)
        assert int(result) == 42
    @pytest.mark.unit
    def test_float(self):
        result = _convert_value(3.14)
        assert isinstance(result, celpy.celtypes.DoubleType)
        assert float(result) == pytest.approx(3.14)
    @pytest.mark.unit
    def test_string(self):
        result = _convert_value("hello")
        assert isinstance(result, celpy.celtypes.StringType)
        assert str(result) == "hello"
    @pytest.mark.unit
    def test_list(self):
        result = _convert_value([1, "two", 3.0])
        assert isinstance(result, celpy.celtypes.ListType)
    @pytest.mark.unit
    def test_dict(self):
        result = _convert_value({"key": "value"})
        assert isinstance(result, celpy.celtypes.MapType)
    @pytest.mark.unit
    def test_none(self):
        # None is represented as a falsy CEL bool rather than a null type.
        result = _convert_value(None)
        assert isinstance(result, celpy.celtypes.BoolType)
        assert bool(result) is False
    @pytest.mark.unit
    def test_other_type_converts_to_string(self):
        # Arbitrary objects fall back to their string representation.
        result = _convert_value(object())
        assert isinstance(result, celpy.celtypes.StringType)
class TestBuildActivation:
    """build_activation mirrors every key of the input state dict."""

    @pytest.mark.unit
    def test_converts_dict_values(self):
        source = {"name": "Alice", "age": 30, "active": True}
        activation = build_activation(source)
        assert all(key in activation for key in ("name", "age", "active"))
    @pytest.mark.unit
    def test_empty_state(self):
        assert build_activation({}) == {}
class TestEvaluateCel:
    """evaluate_cel: comparisons, arithmetic, boolean logic, and the
    CelEvaluationError paths (empty, invalid, or unresolvable expressions)."""

    @pytest.mark.unit
    def test_simple_comparison(self):
        assert evaluate_cel("x > 5", {"x": 10}) is True
        assert evaluate_cel("x > 5", {"x": 3}) is False
    @pytest.mark.unit
    def test_string_comparison(self):
        assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
        assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False
    @pytest.mark.unit
    def test_arithmetic(self):
        assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7
    @pytest.mark.unit
    def test_boolean_logic(self):
        assert evaluate_cel("a && b", {"a": True, "b": True}) is True
        assert evaluate_cel("a && b", {"a": True, "b": False}) is False
        assert evaluate_cel("a || b", {"a": False, "b": True}) is True
    @pytest.mark.unit
    def test_empty_expression_raises(self):
        with pytest.raises(CelEvaluationError, match="Empty expression"):
            evaluate_cel("", {})
    @pytest.mark.unit
    def test_whitespace_expression_raises(self):
        # Whitespace-only input is treated the same as an empty expression.
        with pytest.raises(CelEvaluationError, match="Empty expression"):
            evaluate_cel(" ", {})
    @pytest.mark.unit
    def test_invalid_expression_raises(self):
        with pytest.raises(CelEvaluationError):
            evaluate_cel("invalid!!!", {})
    @pytest.mark.unit
    def test_missing_variable_raises(self):
        # Referencing a variable absent from the state is an evaluation error.
        with pytest.raises(CelEvaluationError):
            evaluate_cel("undefined_var > 5", {})
class TestCelToPython:
    """cel_to_python converts celpy CEL values back to native Python types;
    values of unrecognized type pass through unchanged."""

    @pytest.mark.unit
    def test_bool(self):
        result = cel_to_python(celpy.celtypes.BoolType(True))
        assert result is True
    @pytest.mark.unit
    def test_int(self):
        result = cel_to_python(celpy.celtypes.IntType(42))
        assert result == 42
    @pytest.mark.unit
    def test_double(self):
        result = cel_to_python(celpy.celtypes.DoubleType(3.14))
        assert result == pytest.approx(3.14)
    @pytest.mark.unit
    def test_string(self):
        result = cel_to_python(celpy.celtypes.StringType("hello"))
        assert result == "hello"
    @pytest.mark.unit
    def test_list(self):
        # Conversion recurses into list elements.
        cel_list = celpy.celtypes.ListType([
            celpy.celtypes.IntType(1),
            celpy.celtypes.IntType(2),
        ])
        result = cel_to_python(cel_list)
        assert result == [1, 2]
    @pytest.mark.unit
    def test_map(self):
        # Conversion recurses into map keys and values.
        cel_map = celpy.celtypes.MapType({
            celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
        })
        result = cel_to_python(cel_map)
        assert result == {"key": "value"}
    @pytest.mark.unit
    def test_unknown_type_passthrough(self):
        result = cel_to_python("raw_value")
        assert result == "raw_value"

View File

@@ -0,0 +1,85 @@
"""Tests for application/agents/tools/cryptoprice.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.cryptoprice import CryptoPriceTool
@pytest.fixture
def tool():
    """CryptoPriceTool with an empty config (none is required)."""
    return CryptoPriceTool(config={})
@pytest.mark.unit
class TestCryptoPriceExecuteAction:
    """execute_action for cryptoprice_get: happy path, missing-currency and
    API-failure responses, and symbol/currency upper-casing in the URL.
    requests.get is patched so no network traffic occurs."""

    def test_unknown_action_raises(self, tool):
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid_action")
    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_successful_price_fetch(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"USD": 65000}
        mock_get.return_value = mock_resp
        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
        assert result["status_code"] == 200
        assert result["price"] == 65000
        assert "successfully" in result["message"]
    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_currency_not_found(self, mock_get, tool):
        # A 200 response lacking the requested currency yields a
        # "Couldn't find" message and no price key.
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"EUR": 60000}
        mock_get.return_value = mock_resp
        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
        assert result["status_code"] == 200
        assert "Couldn't find" in result["message"]
        assert "price" not in result
    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_api_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_get.return_value = mock_resp
        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
        assert result["status_code"] == 500
        assert "Failed" in result["message"]
    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_symbol_case_insensitive(self, mock_get, tool):
        # Lower-case inputs are upper-cased into the fsym/tsyms query params.
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"USD": 100}
        mock_get.return_value = mock_resp
        tool.execute_action("cryptoprice_get", symbol="btc", currency="usd")
        called_url = mock_get.call_args[0][0]
        assert "fsym=BTC" in called_url
        assert "tsyms=USD" in called_url
@pytest.mark.unit
class TestCryptoPriceMetadata:
    """Metadata and configuration declarations of CryptoPriceTool."""

    def test_actions_metadata(self, tool):
        """Exactly one action is declared, requiring symbol and currency."""
        actions = tool.get_actions_metadata()
        assert len(actions) == 1
        action = actions[0]
        assert action["name"] == "cryptoprice_get"
        schema = action["parameters"]
        for field in ("symbol", "currency"):
            assert field in schema["properties"]
            assert field in schema["required"]

    def test_config_requirements(self, tool):
        """The tool needs no configuration."""
        assert tool.get_config_requirements() == {}

View File

@@ -0,0 +1,145 @@
"""Tests for application/agents/tools/duckduckgo.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.duckduckgo import DuckDuckGoSearchTool
@pytest.fixture
def tool():
    """DuckDuckGoSearchTool with default (empty) config."""
    search_tool = DuckDuckGoSearchTool(config={})
    return search_tool
@pytest.mark.unit
class TestDuckDuckGoExecuteAction:
    """Tests for DuckDuckGoSearchTool.execute_action across all three actions."""

    def test_unknown_action_raises(self, tool):
        """Unsupported action names raise ValueError."""
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid")

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_web_search_success(self, mock_client_factory, tool):
        """Web search returns the client's text() results with a 200 status."""
        mock_client = MagicMock()
        mock_client.text.return_value = [
            {"title": "Result 1", "href": "https://example.com", "body": "snippet"}
        ]
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_web_search", query="python")
        assert result["status_code"] == 200
        assert len(result["results"]) == 1
        assert "successfully" in result["message"]

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_image_search_success(self, mock_client_factory, tool):
        """Image search delegates to the client's images() method."""
        mock_client = MagicMock()
        mock_client.images.return_value = [{"image": "https://img.com/1.jpg"}]
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_image_search", query="cats")
        assert result["status_code"] == 200
        assert len(result["results"]) == 1

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_news_search_success(self, mock_client_factory, tool):
        """News search delegates to the client's news() method."""
        mock_client = MagicMock()
        mock_client.news.return_value = [{"title": "News"}]
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_news_search", query="tech")
        assert result["status_code"] == 200
        assert len(result["results"]) == 1

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_search_error_returns_500(self, mock_client_factory, tool):
        """Non-rate-limit client errors surface as a 500 with empty results."""
        mock_client = MagicMock()
        mock_client.text.side_effect = Exception("Network error")
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_web_search", query="test")
        assert result["status_code"] == 500
        assert "failed" in result["message"].lower()
        assert result["results"] == []

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_max_results_capped_at_20(self, mock_client_factory, tool):
        """Web search clamps max_results to 20 before calling the client."""
        mock_client = MagicMock()
        mock_client.text.return_value = []
        mock_client_factory.return_value = mock_client
        tool.execute_action("ddg_web_search", query="test", max_results=100)
        call_kwargs = mock_client.text.call_args[1]
        assert call_kwargs["max_results"] == 20

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_image_max_results_capped_at_50(self, mock_client_factory, tool):
        """Image search clamps max_results to 50."""
        mock_client = MagicMock()
        mock_client.images.return_value = []
        mock_client_factory.return_value = mock_client
        tool.execute_action("ddg_image_search", query="test", max_results=200)
        call_kwargs = mock_client.images.call_args[1]
        assert call_kwargs["max_results"] == 50

    # NOTE: patch decorators apply bottom-up, so mock_client_factory binds to
    # the _get_ddgs_client patch and mock_sleep to the time.sleep patch.
    @patch("application.agents.tools.duckduckgo.time.sleep")
    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_rate_limit_retries(self, mock_client_factory, mock_sleep, tool):
        """A rate-limit error triggers one sleep-and-retry, then succeeds."""
        mock_client = MagicMock()
        # First call raises, second call returns results (retry path).
        mock_client.text.side_effect = [
            Exception("RateLimit exceeded"),
            [{"title": "Result"}],
        ]
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_web_search", query="test")
        assert result["status_code"] == 200
        assert len(result["results"]) == 1
        mock_sleep.assert_called_once()

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_empty_results(self, mock_client_factory, tool):
        """No hits still yields a 200 with an empty result list."""
        mock_client = MagicMock()
        mock_client.text.return_value = []
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_web_search", query="obscure query")
        assert result["status_code"] == 200
        assert result["results"] == []

    @patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
    def test_none_results(self, mock_client_factory, tool):
        """A None return from the client is normalized to an empty list."""
        mock_client = MagicMock()
        mock_client.text.return_value = None
        mock_client_factory.return_value = mock_client
        result = tool.execute_action("ddg_web_search", query="test")
        assert result["status_code"] == 200
        assert result["results"] == []
@pytest.mark.unit
class TestDuckDuckGoMetadata:
    """Metadata and configuration of DuckDuckGoSearchTool."""

    def test_actions_metadata(self, tool):
        """All three search actions are declared."""
        actions = tool.get_actions_metadata()
        assert len(actions) == 3
        declared = {action["name"] for action in actions}
        assert declared >= {"ddg_web_search", "ddg_image_search", "ddg_news_search"}

    def test_config_requirements(self, tool):
        """No configuration is required."""
        assert tool.get_config_requirements() == {}

    def test_custom_timeout(self):
        """A timeout passed via config is stored on the instance."""
        custom_tool = DuckDuckGoSearchTool(config={"timeout": 30})
        assert custom_tool.timeout == 30

View File

@@ -0,0 +1,519 @@
"""Tests for application/agents/tools/mcp_tool.py"""
from unittest.mock import MagicMock, patch
import pytest
# mcp_tool has a circular import at module level (mcp_tool -> tasks -> user -> mcp.py -> mcp_tool).
# We must patch the dependencies BEFORE the module is first imported.
@pytest.fixture(autouse=True)
def _patch_mcp_globals(monkeypatch):
    """Patch module-level MongoDB and cache to avoid real connections.

    Runs automatically for every test in this module. Ensures
    ``application.agents.tools.mcp_tool`` can be imported (working around its
    circular import through ``application.api.user.tasks``) and replaces its
    ``mongo``/``db`` globals and the shared client cache with mocks.
    """
    import sys
    # If the module is already loaded, just patch attributes directly
    if "application.agents.tools.mcp_tool" in sys.modules:
        mcp_mod = sys.modules["application.agents.tools.mcp_tool"]
    else:
        # Break the circular import by pre-populating the tasks import
        # with a mock before mcp_tool tries to import it
        mock_tasks = MagicMock()
        monkeypatch.setitem(sys.modules, "application.api.user.tasks", mock_tasks)
        import application.agents.tools.mcp_tool as mcp_mod
    mock_mongo = MagicMock()
    mock_db = MagicMock()
    # db["collection"] lookups must keep returning mock collections.
    mock_db.__getitem__ = MagicMock(return_value=MagicMock())
    monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
    monkeypatch.setattr(mcp_mod, "db", mock_db)
    # Start each test with an empty client cache so state cannot leak between tests.
    monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
@pytest.fixture
def mcp_config():
    """Plain HTTP MCP config with no authentication."""
    return dict(
        server_url="https://mcp.example.com/api",
        transport_type="http",
        auth_type="none",
        timeout=10,
    )
@pytest.fixture
def bearer_config():
    """MCP config using bearer-token authentication."""
    return dict(
        server_url="https://mcp.example.com/api",
        transport_type="http",
        auth_type="bearer",
        auth_credentials={"bearer_token": "tok_123"},
        timeout=10,
    )
@pytest.mark.unit
class TestMCPToolInit:
    """Constructor behaviour of MCPTool (with client setup patched out)."""

    def test_basic_init(self, mcp_config):
        """Config values are copied onto the instance."""
        from application.agents.tools.mcp_tool import MCPTool
        with patch.object(MCPTool, "_setup_client"):
            tool = MCPTool(mcp_config)
            assert tool.server_url == "https://mcp.example.com/api"
            assert tool.transport_type == "http"
            assert tool.auth_type == "none"
            assert tool.timeout == 10

    def test_bearer_auth_credentials(self, bearer_config):
        """Bearer credentials are stored verbatim."""
        from application.agents.tools.mcp_tool import MCPTool
        with patch.object(MCPTool, "_setup_client"):
            tool = MCPTool(bearer_config)
            assert tool.auth_credentials["bearer_token"] == "tok_123"

    def test_no_server_url_skips_setup(self):
        """An empty server URL suppresses eager client setup."""
        from application.agents.tools.mcp_tool import MCPTool
        with patch.object(MCPTool, "_setup_client") as mock_setup:
            MCPTool({"server_url": "", "auth_type": "none"})
        mock_setup.assert_not_called()

    def test_oauth_skips_setup(self):
        """auth_type 'oauth' suppresses eager client setup at construction."""
        from application.agents.tools.mcp_tool import MCPTool
        with patch.object(MCPTool, "_setup_client") as mock_setup:
            MCPTool(
                {
                    "server_url": "https://mcp.example.com",
                    "auth_type": "oauth",
                }
            )
        mock_setup.assert_not_called()
@pytest.mark.unit
class TestGenerateCacheKey:
    """Composition of MCPTool._cache_key for each supported auth type."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_none_auth(self, mock_setup, mcp_config):
        """No-auth keys contain the auth type and the server host."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        assert "none" in tool._cache_key
        assert "mcp.example.com" in tool._cache_key

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_bearer_auth(self, mock_setup, bearer_config):
        """Bearer keys are tagged with a 'bearer:' component."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(bearer_config)
        assert "bearer:" in tool._cache_key

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_api_key_auth(self, mock_setup):
        """API-key keys are tagged with an 'apikey:' component."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com",
                "auth_type": "api_key",
                "auth_credentials": {"api_key": "sk-test12345678"},
            }
        )
        assert "apikey:" in tool._cache_key

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_basic_auth(self, mock_setup):
        """Basic-auth keys embed the username."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com",
                "auth_type": "basic",
                "auth_credentials": {"username": "user1", "password": "pass"},
            }
        )
        assert "basic:user1" in tool._cache_key
@pytest.mark.unit
class TestCreateTransport:
    """Transport construction for the supported transport types."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_http_transport(self, mock_setup, mcp_config):
        """An explicit http transport type yields a transport object."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        transport = tool._create_transport()
        assert transport is not None

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_sse_transport(self, mock_setup):
        """An explicit sse transport type yields a transport object."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com/sse",
                "transport_type": "sse",
                "auth_type": "none",
            }
        )
        transport = tool._create_transport()
        assert transport is not None

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_auto_detects_sse(self, mock_setup):
        """transport_type 'auto' still builds a transport for an /sse URL."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com/sse",
                "transport_type": "auto",
                "auth_type": "none",
            }
        )
        transport = tool._create_transport()
        assert transport is not None

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_stdio_transport_disabled(self, mock_setup):
        """The stdio transport is rejected with a ValueError."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com",
                "transport_type": "stdio",
                "auth_type": "none",
            }
        )
        with pytest.raises(ValueError, match="STDIO transport is disabled"):
            tool._create_transport()

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_api_key_header_injected(self, mock_setup):
        """API-key auth with a custom header name builds without raising."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com",
                "transport_type": "http",
                "auth_type": "api_key",
                "auth_credentials": {
                    "api_key": "sk-test",
                    "api_key_header": "X-Custom-Key",
                },
            }
        )
        # _create_transport will be called; verify it doesn't raise
        transport = tool._create_transport()
        assert transport is not None

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_basic_auth_header_injected(self, mock_setup):
        """Basic-auth credentials build a transport without raising."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {
                "server_url": "https://mcp.example.com",
                "transport_type": "http",
                "auth_type": "basic",
                "auth_credentials": {"username": "user", "password": "pass"},
            }
        )
        transport = tool._create_transport()
        assert transport is not None
@pytest.mark.unit
class TestFormatTools:
    """Normalization of tool listings by MCPTool._format_tools."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_list_of_dicts(self, mock_setup, mcp_config):
        """Plain dict entries pass through with their name preserved."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        result = tool._format_tools([{"name": "tool1", "description": "desc"}])
        assert len(result) == 1
        assert result[0]["name"] == "tool1"

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_tools_with_name_attribute(self, mock_setup, mcp_config):
        """Objects with name/description/inputSchema attributes are converted."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        mock_tool = MagicMock()
        # Set .name after construction: MagicMock(name=...) would configure
        # the mock's own repr name, not a .name attribute.
        mock_tool.name = "my_tool"
        mock_tool.description = "A tool"
        mock_tool.inputSchema = {"type": "object", "properties": {}}
        result = tool._format_tools([mock_tool])
        assert len(result) == 1
        assert result[0]["name"] == "my_tool"
        assert result[0]["inputSchema"] == {"type": "object", "properties": {}}

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_tools_response_object(self, mock_setup, mcp_config):
        """A response object carrying a .tools attribute is unwrapped."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        resp = MagicMock()
        resp.tools = [{"name": "t1", "description": "d1"}]
        result = tool._format_tools(resp)
        assert len(result) == 1

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_empty(self, mock_setup, mcp_config):
        """Empty or unrecognized input yields an empty list."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        assert tool._format_tools([]) == []
        assert tool._format_tools("unexpected") == []
@pytest.mark.unit
class TestFormatResult:
    """Normalization of call results by MCPTool._format_result."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_result_with_content(self, mock_setup, mcp_config):
        """Content items exposing .text become text-typed entries."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        mock_result = MagicMock()
        text_item = MagicMock()
        text_item.text = "Hello"
        # Remove the auto-created .data attribute so hasattr(item, "data") is
        # False and the formatter treats this as a text-only item.
        del text_item.data
        mock_result.content = [text_item]
        mock_result.isError = False
        result = tool._format_result(mock_result)
        assert result["content"][0]["type"] == "text"
        assert result["content"][0]["text"] == "Hello"
        assert result["isError"] is False

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_format_raw_result(self, mock_setup, mcp_config):
        """Plain dicts without MCP result structure are returned unchanged."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        raw = {"key": "value"}
        assert tool._format_result(raw) == raw
@pytest.mark.unit
class TestExecuteAction:
    """Dispatch of tool calls through MCPTool.execute_action."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_no_server_raises(self, mock_setup):
        """Calling without a configured server raises an explanatory error."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool({"server_url": "", "auth_type": "none"})
        with pytest.raises(Exception, match="No MCP server configured"):
            tool.execute_action("test_action")

    # Patch decorators apply bottom-up: mock_run binds to _run_async_operation.
    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    @patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
    def test_successful_execute(self, mock_run, mock_setup, mcp_config):
        """The action name and kwargs are forwarded to the async runner."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        tool._client = MagicMock()
        mock_run.return_value = {"key": "value"}
        result = tool.execute_action("test_action", param1="val1")
        mock_run.assert_called_once_with("call_tool", "test_action", param1="val1")
        assert result == {"key": "value"}

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    @patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
    def test_empty_kwargs_cleaned(self, mock_run, mock_setup, mcp_config):
        """Empty-string and None kwargs are dropped before dispatch."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        tool._client = MagicMock()
        mock_run.return_value = {}
        tool.execute_action("test", param1="", param2=None, param3="real")
        call_kwargs = mock_run.call_args[1]
        assert "param1" not in call_kwargs
        assert "param2" not in call_kwargs
        assert call_kwargs["param3"] == "real"
@pytest.mark.unit
class TestTestConnection:
    """URL validation performed by MCPTool.test_connection before connecting."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_no_server_url(self, mock_setup):
        """A missing server URL fails fast with an explanatory message."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool({"server_url": "", "auth_type": "none"})
        result = tool.test_connection()
        assert result["success"] is False
        assert "No server URL" in result["message"]

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_invalid_scheme(self, mock_setup):
        """Non-HTTP(S) URL schemes are rejected."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(
            {"server_url": "ftp://bad.com", "auth_type": "none"}
        )
        result = tool.test_connection()
        assert result["success"] is False
        assert "Invalid URL scheme" in result["message"]
@pytest.mark.unit
class TestMapError:
    """Translation of low-level exceptions into user-facing messages."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_timeout_error(self, mock_setup, mcp_config):
        """concurrent.futures timeouts map to a 'Timed out' message."""
        from application.agents.tools.mcp_tool import MCPTool
        import concurrent.futures
        tool = MCPTool(mcp_config)
        err = tool._map_error("test", concurrent.futures.TimeoutError())
        assert "Timed out" in str(err)

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_connection_refused(self, mock_setup, mcp_config):
        """ConnectionRefusedError maps to a 'Connection refused' message."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        err = tool._map_error("test", ConnectionRefusedError())
        assert "Connection refused" in str(err)

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_403_forbidden(self, mock_setup, mcp_config):
        """'403' in the exception text maps to an access-denied message."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        err = tool._map_error("test", Exception("403 Forbidden"))
        assert "Access denied" in str(err)

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_ssl_error(self, mock_setup, mcp_config):
        """SSL failures keep 'SSL' in the mapped message."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        err = tool._map_error("test", Exception("SSL certificate verify failed"))
        assert "SSL" in str(err)

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_unknown_error_passthrough(self, mock_setup, mcp_config):
        """Unrecognized exceptions are returned as the same object."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        original = RuntimeError("something weird")
        err = tool._map_error("test", original)
        assert err is original
@pytest.mark.unit
class TestGetActionsMetadata:
    """Conversion of discovered MCP tools into action metadata."""

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_empty_tools(self, mock_setup, mcp_config):
        """No discovered tools means no actions."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        tool.available_tools = []
        assert tool.get_actions_metadata() == []

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_tools_with_input_schema(self, mock_setup, mcp_config):
        """An inputSchema is carried over into the action's parameters."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        tool.available_tools = [
            {
                "name": "search",
                "description": "Search things",
                "inputSchema": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                    "required": ["query"],
                },
            }
        ]
        meta = tool.get_actions_metadata()
        assert len(meta) == 1
        assert meta[0]["name"] == "search"
        assert "query" in meta[0]["parameters"]["properties"]

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_tools_without_schema(self, mock_setup, mcp_config):
        """Tools without a schema get an empty properties object."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        tool.available_tools = [{"name": "ping", "description": "Ping"}]
        meta = tool.get_actions_metadata()
        assert meta[0]["parameters"]["properties"] == {}

    @patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
    def test_config_requirements(self, mock_setup, mcp_config):
        """server_url and auth_type are declared, with server_url required."""
        from application.agents.tools.mcp_tool import MCPTool
        tool = MCPTool(mcp_config)
        reqs = tool.get_config_requirements()
        assert "server_url" in reqs
        assert "auth_type" in reqs
        assert reqs["server_url"]["required"] is True
@pytest.mark.unit
class TestMCPOAuthManager:
    """Behaviour of the OAuth callback/status helper."""

    def test_handle_callback_success(self):
        """A callback with state and code is recorded in redis."""
        from application.agents.tools.mcp_tool import MCPOAuthManager

        redis_stub = MagicMock()
        manager = MCPOAuthManager(redis_stub)
        assert manager.handle_oauth_callback(state="abc123", code="auth_code") is True
        redis_stub.setex.assert_called()

    def test_handle_callback_no_redis(self):
        """Without a redis client the callback cannot be recorded."""
        from application.agents.tools.mcp_tool import MCPOAuthManager

        manager = MCPOAuthManager(None)
        assert manager.handle_oauth_callback(state="abc", code="code") is False

    def test_handle_callback_error(self):
        """An explicit OAuth error aborts callback handling."""
        from application.agents.tools.mcp_tool import MCPOAuthManager

        manager = MCPOAuthManager(MagicMock())
        outcome = manager.handle_oauth_callback(
            state="abc", code="", error="access_denied"
        )
        assert outcome is False

    def test_get_oauth_status_no_task(self):
        """An empty task id reports the flow as not started."""
        from application.agents.tools.mcp_tool import MCPOAuthManager

        manager = MCPOAuthManager(MagicMock())
        assert manager.get_oauth_status("")["status"] == "not_started"
@pytest.mark.unit
class TestDBTokenStorage:
    """Key derivation logic of DBTokenStorage."""

    def test_get_base_url(self):
        """The base URL drops any path from the server URL."""
        from application.agents.tools.mcp_tool import DBTokenStorage

        base = DBTokenStorage.get_base_url("https://mcp.example.com/api/v1")
        assert base == "https://mcp.example.com"

    def test_get_db_key(self):
        """The DB key combines the normalized server URL and the user id."""
        from application.agents.tools.mcp_tool import DBTokenStorage

        storage = DBTokenStorage(
            server_url="https://mcp.example.com/api",
            user_id="user1",
            db_client=MagicMock(),
        )
        lookup_key = storage.get_db_key()
        assert lookup_key["server_url"] == "https://mcp.example.com"
        assert lookup_key["user_id"] == "user1"

View File

@@ -0,0 +1,146 @@
"""Tests for application/agents/tools/ntfy.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.ntfy import NtfyTool
@pytest.fixture
def tool():
    """NtfyTool configured with an auth token."""
    ntfy_tool = NtfyTool(config={"token": "test_token"})
    return ntfy_tool
@pytest.fixture
def tool_no_token():
    """NtfyTool without any auth token configured."""
    bare_tool = NtfyTool(config={})
    return bare_tool
@pytest.mark.unit
class TestNtfyExecuteAction:
    """Tests for NtfyTool.execute_action (ntfy_send_message)."""

    def test_unknown_action_raises(self, tool):
        """Unsupported action names raise ValueError."""
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("bad_action")

    @patch("application.agents.tools.ntfy.requests.post")
    def test_send_message_basic(self, mock_post, tool):
        """The message body is POSTed as bytes to <server_url>/<topic>."""
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_post.return_value = mock_resp
        result = tool.execute_action(
            "ntfy_send_message",
            server_url="https://ntfy.sh",
            message="Hello",
            topic="test",
        )
        assert result["status_code"] == 200
        assert result["message"] == "Message sent"
        mock_post.assert_called_once()
        call_args = mock_post.call_args
        # Positional arg 0 is the URL; the body is passed as the data kwarg.
        assert call_args[0][0] == "https://ntfy.sh/test"
        assert call_args[1]["data"] == b"Hello"

    @patch("application.agents.tools.ntfy.requests.post")
    def test_send_with_title_and_priority(self, mock_post, tool):
        """Optional title/priority map to X-Title and X-Priority headers."""
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_post.return_value = mock_resp
        tool.execute_action(
            "ntfy_send_message",
            server_url="https://ntfy.sh",
            message="Alert",
            topic="urgent",
            title="Warning",
            priority=5,
        )
        headers = mock_post.call_args[1]["headers"]
        assert headers["X-Title"] == "Warning"
        # Priority is sent as a string header value.
        assert headers["X-Priority"] == "5"

    @patch("application.agents.tools.ntfy.requests.post")
    def test_auth_header_with_token(self, mock_post, tool):
        """A configured token is sent as a Basic Authorization header."""
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_post.return_value = mock_resp
        tool.execute_action(
            "ntfy_send_message",
            server_url="https://ntfy.sh",
            message="Hi",
            topic="t",
        )
        headers = mock_post.call_args[1]["headers"]
        assert headers["Authorization"] == "Basic test_token"

    @patch("application.agents.tools.ntfy.requests.post")
    def test_no_auth_without_token(self, mock_post, tool_no_token):
        """Without a token, no Authorization header is sent."""
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_post.return_value = mock_resp
        tool_no_token.execute_action(
            "ntfy_send_message",
            server_url="https://ntfy.sh",
            message="Hi",
            topic="t",
        )
        headers = mock_post.call_args[1]["headers"]
        assert "Authorization" not in headers

    def test_invalid_priority_raises(self, tool):
        """Priorities outside 1-5 are rejected."""
        with pytest.raises(ValueError, match="between 1 and 5"):
            tool.execute_action(
                "ntfy_send_message",
                server_url="https://ntfy.sh",
                message="Hi",
                topic="t",
                priority=10,
            )

    def test_non_numeric_priority_raises(self, tool):
        """Non-numeric priorities are rejected."""
        with pytest.raises(ValueError, match="convertible to an integer"):
            tool.execute_action(
                "ntfy_send_message",
                server_url="https://ntfy.sh",
                message="Hi",
                topic="t",
                priority="abc",
            )

    @patch("application.agents.tools.ntfy.requests.post")
    def test_trailing_slash_stripped(self, mock_post, tool):
        """A trailing slash on server_url does not double up in the URL."""
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_post.return_value = mock_resp
        tool.execute_action(
            "ntfy_send_message",
            server_url="https://ntfy.sh/",
            message="Hi",
            topic="test",
        )
        assert mock_post.call_args[0][0] == "https://ntfy.sh/test"
@pytest.mark.unit
class TestNtfyMetadata:
    """Metadata and configuration declarations of NtfyTool."""

    def test_actions_metadata(self, tool):
        """A single send action exposing server_url, message and topic."""
        actions = tool.get_actions_metadata()
        assert len(actions) == 1
        send_action = actions[0]
        assert send_action["name"] == "ntfy_send_message"
        properties = send_action["parameters"]["properties"]
        for field in ("server_url", "message", "topic"):
            assert field in properties

    def test_config_requirements(self, tool):
        """The auth token is declared as a secret config value."""
        requirements = tool.get_config_requirements()
        assert "token" in requirements
        assert requirements["token"]["secret"] is True

View File

@@ -0,0 +1,146 @@
"""Tests for application/agents/tools/postgres.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.postgres import PostgresTool
@pytest.fixture
def tool():
    """PostgresTool wired to a dummy connection string."""
    pg_tool = PostgresTool(config={"token": "postgresql://user:pass@localhost/testdb"})
    return pg_tool
@pytest.mark.unit
class TestPostgresExecuteAction:
    """Tests for PostgresTool.execute_action with psycopg2 mocked out."""

    def test_unknown_action_raises(self, tool):
        """Unsupported action names raise ValueError."""
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid")

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_select_query(self, mock_connect, tool):
        """SELECT rows are zipped with column names into row dicts."""
        mock_conn = MagicMock()
        mock_cur = MagicMock()
        # cursor.description yields one tuple per column; element 0 is the name.
        mock_cur.description = [("id",), ("name",)]
        mock_cur.fetchall.return_value = [(1, "Alice"), (2, "Bob")]
        mock_conn.cursor.return_value = mock_cur
        mock_connect.return_value = mock_conn
        result = tool.execute_action(
            "postgres_execute_sql", sql_query="SELECT id, name FROM users"
        )
        assert result["status_code"] == 200
        assert result["response_data"]["column_names"] == ["id", "name"]
        assert len(result["response_data"]["data"]) == 2
        assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
        mock_conn.close.assert_called_once()

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_insert_query(self, mock_connect, tool):
        """Write statements report affected rows and commit the transaction."""
        mock_conn = MagicMock()
        mock_cur = MagicMock()
        mock_cur.rowcount = 1
        mock_conn.cursor.return_value = mock_cur
        mock_connect.return_value = mock_conn
        result = tool.execute_action(
            "postgres_execute_sql",
            sql_query="INSERT INTO users (name) VALUES ('Alice')",
        )
        assert result["status_code"] == 200
        assert "1 rows affected" in result["response_data"]["message"]
        mock_conn.commit.assert_called_once()
        mock_conn.close.assert_called_once()

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_db_error(self, mock_connect, tool):
        """Connection-level psycopg2 errors surface as a 500 payload."""
        import psycopg2
        mock_connect.side_effect = psycopg2.Error("connection refused")
        result = tool.execute_action(
            "postgres_execute_sql", sql_query="SELECT 1"
        )
        assert result["status_code"] == 500
        assert "Database error" in result["error"]

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_get_schema(self, mock_connect, tool):
        """Raw schema rows are grouped into a per-table mapping."""
        mock_conn = MagicMock()
        mock_cur = MagicMock()
        # Presumed row shape: (table, column, type, default, nullable) —
        # matches the assertions below; confirm against the tool's query.
        mock_cur.fetchall.return_value = [
            ("users", "id", "integer", "nextval(...)", "NO"),
            ("users", "name", "varchar", None, "YES"),
            ("posts", "id", "integer", "nextval(...)", "NO"),
        ]
        mock_conn.cursor.return_value = mock_cur
        mock_connect.return_value = mock_conn
        result = tool.execute_action("postgres_get_schema", db_name="testdb")
        assert result["status_code"] == 200
        assert "users" in result["schema"]
        assert "posts" in result["schema"]
        assert len(result["schema"]["users"]) == 2
        assert result["schema"]["users"][0]["column_name"] == "id"
        mock_conn.close.assert_called_once()

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_get_schema_db_error(self, mock_connect, tool):
        """Schema lookup failures also surface as a 500 payload."""
        import psycopg2
        mock_connect.side_effect = psycopg2.Error("auth failed")
        result = tool.execute_action("postgres_get_schema", db_name="testdb")
        assert result["status_code"] == 500
        assert "Database error" in result["error"]

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_connection_closed_on_error(self, mock_connect, tool):
        """The connection is closed even when the query itself fails."""
        import psycopg2
        mock_conn = MagicMock()
        mock_cur = MagicMock()
        mock_cur.execute.side_effect = psycopg2.Error("syntax error")
        mock_conn.cursor.return_value = mock_cur
        mock_connect.return_value = mock_conn
        tool.execute_action("postgres_execute_sql", sql_query="BAD SQL")
        mock_conn.close.assert_called_once()

    @patch("application.agents.tools.postgres.psycopg2.connect")
    def test_select_with_no_description(self, mock_connect, tool):
        """A cursor without a description yields empty column names."""
        mock_conn = MagicMock()
        mock_cur = MagicMock()
        mock_cur.description = None
        mock_cur.fetchall.return_value = []
        mock_conn.cursor.return_value = mock_cur
        mock_connect.return_value = mock_conn
        result = tool.execute_action(
            "postgres_execute_sql", sql_query="SELECT 1 WHERE false"
        )
        assert result["status_code"] == 200
        assert result["response_data"]["column_names"] == []
@pytest.mark.unit
class TestPostgresMetadata:
    """Metadata and configuration declarations of PostgresTool."""

    def test_actions_metadata(self, tool):
        """Both the SQL execution and schema actions are declared."""
        actions = tool.get_actions_metadata()
        assert len(actions) == 2
        declared = {action["name"] for action in actions}
        assert declared == {"postgres_execute_sql", "postgres_get_schema"}

    def test_config_requirements(self, tool):
        """The connection string is declared as a secret config value."""
        requirements = tool.get_config_requirements()
        assert "token" in requirements
        assert requirements["token"]["secret"] is True

View File

@@ -0,0 +1,86 @@
"""Tests for application/agents/tools/read_webpage.py"""
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.agents.tools.read_webpage import ReadWebpageTool
@pytest.fixture
def tool():
    """A fresh ReadWebpageTool instance."""
    webpage_tool = ReadWebpageTool()
    return webpage_tool
@pytest.mark.unit
class TestReadWebpageExecuteAction:
    """Tests for ReadWebpageTool.execute_action; errors come back as strings."""

    def test_unknown_action(self, tool):
        """Unknown actions return an error string rather than raising."""
        result = tool.execute_action("unknown_action")
        assert "Error" in result
        assert "Unknown action" in result

    def test_missing_url(self, tool):
        """A missing url parameter is reported in the error string."""
        result = tool.execute_action("read_webpage")
        assert "Error" in result
        assert "URL parameter is missing" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_successful_fetch(self, mock_get, mock_validate, tool):
        """HTML is fetched and its text content appears in the result."""
        mock_validate.return_value = "https://example.com"
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.text = "<html><body><h1>Title</h1><p>Content</p></body></html>"
        mock_get.return_value = mock_resp
        result = tool.execute_action("read_webpage", url="https://example.com")
        assert "Title" in result
        assert "Content" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_request_error(self, mock_get, mock_validate, tool):
        """Network failures are reported as fetch errors."""
        mock_validate.return_value = "https://example.com"
        mock_get.side_effect = requests.exceptions.ConnectionError("refused")
        result = tool.execute_action("read_webpage", url="https://example.com")
        assert "Error fetching URL" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    def test_ssrf_blocked(self, mock_validate, tool):
        """A URL rejected by SSRF validation is reported as a validation error."""
        from application.core.url_validation import SSRFError
        mock_validate.side_effect = SSRFError("blocked")
        result = tool.execute_action("read_webpage", url="http://169.254.169.254/")
        assert "Error" in result
        assert "validation failed" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_http_error(self, mock_get, mock_validate, tool):
        """Responses that raise via raise_for_status become fetch errors."""
        mock_validate.return_value = "https://example.com/404"
        mock_resp = MagicMock()
        mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError("404")
        mock_get.return_value = mock_resp
        result = tool.execute_action("read_webpage", url="https://example.com/404")
        assert "Error fetching URL" in result
@pytest.mark.unit
class TestReadWebpageMetadata:
    """Metadata contracts exposed by the read_webpage tool."""

    def test_actions_metadata(self, tool):
        """A single action is advertised and it requires a url parameter."""
        meta = tool.get_actions_metadata()
        assert len(meta) == 1
        action = meta[0]
        assert action["name"] == "read_webpage"
        assert "url" in action["parameters"]["properties"]
        assert "url" in action["parameters"]["required"]

    def test_config_requirements(self, tool):
        """The tool declares no configuration requirements."""
        assert tool.get_config_requirements() == {}

View File

@@ -0,0 +1,335 @@
"""Tests for application/agents/tools/spec_parser.py"""
import json
import pytest
from application.agents.tools.spec_parser import (
_extract_metadata,
_generate_action_name,
_get_base_url,
_load_spec,
_param_to_property,
_resolve_ref,
_validate_spec,
parse_spec,
)
# Smallest useful OpenAPI 3 document: one GET /users operation with an
# optional integer query parameter and a single server base URL.  Stored as
# a JSON string because parse_spec/_load_spec accept raw spec text.
MINIMAL_OPENAPI = json.dumps(
    {
        "openapi": "3.0.0",
        "info": {"title": "Test API", "version": "1.0.0"},
        "servers": [{"url": "https://api.example.com"}],
        "paths": {
            "/users": {
                "get": {
                    "operationId": "listUsers",
                    "summary": "List all users",
                    "parameters": [
                        {
                            "name": "limit",
                            "in": "query",
                            "schema": {"type": "integer"},
                        }
                    ],
                    "responses": {"200": {"description": "OK"}},
                }
            }
        },
    }
)
# Swagger 2.0 counterpart: the base URL must be assembled from
# schemes + host + basePath rather than a servers list.
MINIMAL_SWAGGER = json.dumps(
    {
        "swagger": "2.0",
        "info": {"title": "Swagger API", "version": "2.0.0"},
        "host": "api.example.com",
        "basePath": "/v1",
        "schemes": ["https"],
        "paths": {
            "/pets": {
                "get": {
                    "operationId": "listPets",
                    "summary": "List pets",
                    "parameters": [],
                    "responses": {"200": {"description": "OK"}},
                }
            }
        },
    }
)
@pytest.mark.unit
class TestLoadSpec:
    """_load_spec parses JSON or YAML text and rejects empty/invalid input."""

    def test_load_json(self):
        assert _load_spec('{"openapi": "3.0.0"}')["openapi"] == "3.0.0"

    def test_load_yaml(self):
        parsed = _load_spec("openapi: '3.0.0'\ninfo:\n title: Test")
        assert parsed["openapi"] == "3.0.0"

    def test_empty_raises(self):
        with pytest.raises(ValueError, match="Empty"):
            _load_spec("")

    def test_whitespace_only_raises(self):
        # Whitespace-only input counts as empty, not as invalid YAML.
        with pytest.raises(ValueError, match="Empty"):
            _load_spec(" \n ")

    def test_invalid_json_raises(self):
        with pytest.raises(ValueError, match="Invalid"):
            _load_spec("{invalid json}")
@pytest.mark.unit
class TestValidateSpec:
    """_validate_spec structural checks on already-parsed spec dicts."""

    def test_valid_openapi(self):
        # A well-formed OpenAPI 3 document must pass without raising.
        _validate_spec(json.loads(MINIMAL_OPENAPI))

    def test_valid_swagger(self):
        _validate_spec(json.loads(MINIMAL_SWAGGER))

    def test_missing_version_raises(self):
        with pytest.raises(ValueError, match="Unsupported"):
            _validate_spec({"paths": {"/a": {}}})

    def test_no_paths_raises(self):
        with pytest.raises(ValueError, match="No API paths"):
            _validate_spec({"openapi": "3.0.0"})

    def test_empty_paths_raises(self):
        with pytest.raises(ValueError, match="No API paths"):
            _validate_spec({"openapi": "3.0.0", "paths": {}})

    def test_non_dict_raises(self):
        with pytest.raises(ValueError, match="valid object"):
            _validate_spec("not a dict")
@pytest.mark.unit
class TestExtractMetadata:
    """_extract_metadata pulls title/version/base_url from a spec dict."""

    def test_openapi_metadata(self):
        meta = _extract_metadata(json.loads(MINIMAL_OPENAPI), is_swagger=False)
        assert meta["title"] == "Test API"
        assert meta["version"] == "1.0.0"
        assert meta["base_url"] == "https://api.example.com"

    def test_swagger_metadata(self):
        meta = _extract_metadata(json.loads(MINIMAL_SWAGGER), is_swagger=True)
        assert meta["title"] == "Swagger API"
        assert meta["base_url"] == "https://api.example.com/v1"

    def test_missing_info(self):
        # A spec without an info block falls back to a default title.
        bare = {"openapi": "3.0.0", "paths": {"/a": {}}}
        meta = _extract_metadata(bare, is_swagger=False)
        assert meta["title"] == "Untitled API"

    def test_description_truncated(self):
        # Long descriptions are clipped to at most 500 characters.
        verbose = {
            "openapi": "3.0.0",
            "info": {"title": "T", "description": "x" * 1000},
            "paths": {"/a": {}},
        }
        meta = _extract_metadata(verbose, is_swagger=False)
        assert len(meta["description"]) <= 500
@pytest.mark.unit
class TestGetBaseUrl:
    """_get_base_url for OpenAPI servers[] and Swagger host/basePath/schemes."""

    def test_openapi_servers(self):
        # The trailing slash on a server URL is stripped.
        doc = {"servers": [{"url": "https://api.example.com/v2/"}]}
        assert _get_base_url(doc, is_swagger=False) == "https://api.example.com/v2"

    def test_openapi_no_servers(self):
        assert _get_base_url({}, is_swagger=False) == ""

    def test_swagger_with_host(self):
        doc = {"host": "api.test.com", "basePath": "/v1", "schemes": ["https"]}
        assert _get_base_url(doc, is_swagger=True) == "https://api.test.com/v1"

    def test_swagger_no_host(self):
        assert _get_base_url({}, is_swagger=True) == ""

    def test_swagger_default_scheme(self):
        # When no schemes are listed, https is assumed.
        doc = {"host": "api.test.com", "basePath": ""}
        assert _get_base_url(doc, is_swagger=True) == "https://api.test.com"
@pytest.mark.unit
class TestGenerateActionName:
    """_generate_action_name: operationId preferred, else a method+path slug."""

    def test_uses_operation_id(self):
        result = _generate_action_name({"operationId": "getUser"}, "get", "/users")
        assert result == "getUser"

    def test_fallback_to_method_path(self):
        result = _generate_action_name({}, "get", "/users/{id}")
        assert result.startswith("get_")
        assert "users" in result

    def test_truncates_long_names(self):
        # Generated names are capped at 64 characters.
        result = _generate_action_name({}, "get", "/" + "a" * 200)
        assert len(result) <= 64

    def test_sanitizes_special_chars(self):
        result = _generate_action_name({"operationId": "get.user@v2"}, "get", "/")
        assert "." not in result
        assert "@" not in result
@pytest.mark.unit
class TestParamToProperty:
    """_param_to_property maps OpenAPI parameter objects to tool properties."""

    def test_string_param(self):
        result = _param_to_property(
            {"name": "q", "in": "query", "schema": {"type": "string"}}
        )
        assert result["type"] == "string"
        assert result["required"] is False

    def test_integer_param(self):
        result = _param_to_property(
            {
                "name": "limit",
                "in": "query",
                "required": True,
                "schema": {"type": "integer"},
            }
        )
        assert result["type"] == "integer"
        assert result["required"] is True
        assert result["filled_by_llm"] is True

    def test_number_maps_to_integer(self):
        # "number" is coerced to the integer property type.
        result = _param_to_property(
            {"name": "score", "in": "query", "schema": {"type": "number"}}
        )
        assert result["type"] == "integer"
@pytest.mark.unit
class TestResolveRef:
    """_resolve_ref follows $ref pointers into components/ and definitions/."""

    def test_no_ref(self):
        # Objects without $ref pass through unchanged.
        plain = {"type": "string"}
        assert _resolve_ref(plain, {}, {}) == plain

    def test_components_ref(self):
        components = {
            "schemas": {
                "User": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}},
                }
            }
        }
        resolved = _resolve_ref({"$ref": "#/components/schemas/User"}, components, {})
        assert resolved["type"] == "object"

    def test_definitions_ref(self):
        resolved = _resolve_ref(
            {"$ref": "#/definitions/Pet"}, {}, {"Pet": {"type": "object"}}
        )
        assert resolved["type"] == "object"

    def test_unsupported_ref(self):
        # Refs outside the two known roots resolve to None.
        assert _resolve_ref({"$ref": "#/external/something"}, {}, {}) is None

    def test_non_dict_returns_none(self):
        assert _resolve_ref("string", {}, {}) is None
        assert _resolve_ref(42, {}, {}) is None
@pytest.mark.unit
class TestParseSpec:
    """End-to-end parse_spec tests: metadata plus generated actions for
    OpenAPI 3 and Swagger 2 documents, in both JSON and YAML form."""

    def test_openapi_full_parse(self):
        # The single GET /users operation becomes one action whose URL joins
        # the server base URL with the path, carrying the query parameter.
        metadata, actions = parse_spec(MINIMAL_OPENAPI)
        assert metadata["title"] == "Test API"
        assert len(actions) == 1
        assert actions[0]["name"] == "listUsers"
        assert actions[0]["method"] == "GET"
        assert actions[0]["url"] == "https://api.example.com/users"
        assert "limit" in actions[0]["query_params"]["properties"]

    def test_swagger_full_parse(self):
        metadata, actions = parse_spec(MINIMAL_SWAGGER)
        assert metadata["title"] == "Swagger API"
        assert len(actions) == 1
        assert actions[0]["name"] == "listPets"
        assert actions[0]["method"] == "GET"

    def test_multiple_methods(self):
        # Two operations under one path produce two actions; the POST's
        # requestBody schema is exposed via the action's body properties.
        spec = json.dumps(
            {
                "openapi": "3.0.0",
                "info": {"title": "T", "version": "1"},
                "paths": {
                    "/items": {
                        "get": {"operationId": "listItems", "responses": {}},
                        "post": {
                            "operationId": "createItem",
                            "requestBody": {
                                "content": {
                                    "application/json": {
                                        "schema": {
                                            "type": "object",
                                            "properties": {
                                                "name": {"type": "string"}
                                            },
                                            "required": ["name"],
                                        }
                                    }
                                }
                            },
                            "responses": {},
                        },
                    }
                },
            }
        )
        metadata, actions = parse_spec(spec)
        assert len(actions) == 2
        names = {a["name"] for a in actions}
        assert "listItems" in names
        assert "createItem" in names
        create = next(a for a in actions if a["name"] == "createItem")
        assert "name" in create["body"]["properties"]

    def test_header_params(self):
        # Parameters with in=header land in the action's headers schema.
        spec = json.dumps(
            {
                "openapi": "3.0.0",
                "info": {"title": "T", "version": "1"},
                "paths": {
                    "/data": {
                        "get": {
                            "operationId": "getData",
                            "parameters": [
                                {"name": "X-API-Key", "in": "header", "schema": {"type": "string"}}
                            ],
                            "responses": {},
                        }
                    }
                },
            }
        )
        _, actions = parse_spec(spec)
        assert "X-API-Key" in actions[0]["headers"]["properties"]

    def test_invalid_spec_raises(self):
        with pytest.raises(ValueError):
            parse_spec("")

    def test_yaml_spec(self):
        # YAML input goes through the same pipeline as JSON.
        yaml_spec = """
openapi: "3.0.0"
info:
  title: YAML API
  version: "1.0"
paths:
  /health:
    get:
      operationId: healthCheck
      responses:
        "200":
          description: OK
"""
        metadata, actions = parse_spec(yaml_spec)
        assert metadata["title"] == "YAML API"
        assert actions[0]["name"] == "healthCheck"

View File

@@ -0,0 +1,79 @@
"""Tests for application/agents/tools/telegram.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.agents.tools.telegram import TelegramTool
@pytest.fixture
def tool():
    # TelegramTool wired with a dummy bot token; the execute_action tests
    # patch requests.post, so no real Telegram API traffic occurs.
    return TelegramTool(config={"token": "bot123:ABC"})
@pytest.mark.unit
class TestTelegramExecuteAction:
    """Behaviour of TelegramTool.execute_action with requests.post mocked."""

    def test_unknown_action_raises(self, tool):
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid")

    @patch("application.agents.tools.telegram.requests.post")
    def test_send_message(self, mock_post, tool):
        reply = MagicMock()
        reply.status_code = 200
        mock_post.return_value = reply
        result = tool.execute_action(
            "telegram_send_message", text="Hello", chat_id="12345"
        )
        assert result["status_code"] == 200
        assert result["message"] == "Message sent"
        # The token is embedded in the sendMessage URL and the payload
        # carries the text and target chat.
        posted_url = mock_post.call_args[0][0]
        posted_data = mock_post.call_args[1]["data"]
        assert "bot123:ABC/sendMessage" in posted_url
        assert posted_data["text"] == "Hello"
        assert posted_data["chat_id"] == "12345"

    @patch("application.agents.tools.telegram.requests.post")
    def test_send_image(self, mock_post, tool):
        reply = MagicMock()
        reply.status_code = 200
        mock_post.return_value = reply
        result = tool.execute_action(
            "telegram_send_image", image_url="https://img.com/cat.jpg", chat_id="12345"
        )
        assert result["status_code"] == 200
        assert result["message"] == "Image sent"
        posted_url = mock_post.call_args[0][0]
        posted_data = mock_post.call_args[1]["data"]
        assert "bot123:ABC/sendPhoto" in posted_url
        assert posted_data["photo"] == "https://img.com/cat.jpg"

    @patch("application.agents.tools.telegram.requests.post")
    def test_api_error_status(self, mock_post, tool):
        # Non-2xx statuses are passed through in the result.
        reply = MagicMock()
        reply.status_code = 403
        mock_post.return_value = reply
        result = tool.execute_action(
            "telegram_send_message", text="Hi", chat_id="999"
        )
        assert result["status_code"] == 403
@pytest.mark.unit
class TestTelegramMetadata:
    """Metadata contracts of the Telegram tool."""

    def test_actions_metadata(self, tool):
        meta = tool.get_actions_metadata()
        assert len(meta) == 2
        names = {entry["name"] for entry in meta}
        assert {"telegram_send_message", "telegram_send_image"} <= names

    def test_config_requirements(self, tool):
        # The bot token is required and flagged as a secret.
        reqs = tool.get_config_requirements()
        assert "token" in reqs
        assert reqs["token"]["secret"] is True

View File

@@ -0,0 +1,433 @@
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
WorkflowGraph,
)
def _make_agent(**overrides):
    """Create a WorkflowAgent with mocked base class dependencies.

    Any keyword argument overrides the matching constructor default below.
    NOTE(review): patching log_activity around the import only neutralizes
    decoration performed at first import of the module — confirm this holds
    if another test imports workflow_agent earlier.
    """
    defaults = {
        "endpoint": "https://api.example.com",
        "llm_name": "openai",
        "model_id": "gpt-4",
        "api_key": "test_key",
        "user_api_key": None,
        "prompt": "You are helpful.",
        "chat_history": [],
        "decoded_token": {"sub": "user1"},
        "attachments": [],
        "json_schema": None,
    }
    defaults.update(overrides)
    # Replace the decorator with a pass-through so methods stay unwrapped.
    with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
        from application.agents.workflow_agent import WorkflowAgent

        agent = WorkflowAgent(**defaults)
    return agent
class TestWorkflowAgentInit:
    """Constructor wiring of WorkflowAgent."""

    @pytest.mark.unit
    def test_sets_attributes(self):
        agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
        assert agent.workflow_id == "wf1"
        assert agent.workflow_owner == "owner1"
        # No engine is created until a run starts.
        assert agent._engine is None

    @pytest.mark.unit
    def test_embedded_workflow(self):
        payload = {"nodes": [], "edges": [], "name": "Test"}
        agent = _make_agent(workflow=payload)
        assert agent._workflow_data == payload
class TestParseEmbeddedWorkflow:
    """_parse_embedded_workflow builds a WorkflowGraph from inline dict data."""

    @pytest.mark.unit
    def test_parses_valid_workflow(self):
        wf_data = {
            "name": "Test Workflow",
            "description": "A test",
            "nodes": [
                {"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
                {"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
            ],
            "edges": [
                {"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
            ],
        }
        agent = _make_agent(workflow=wf_data, workflow_id="wf1")
        graph = agent._parse_embedded_workflow()
        assert graph is not None
        assert len(graph.nodes) == 2
        assert len(graph.edges) == 1
        assert graph.workflow.name == "Test Workflow"

    @pytest.mark.unit
    def test_edge_source_id_alias(self):
        # snake_case edge keys (source_id/target_id/...) are accepted as
        # aliases for the camelCase keys used in the test above.
        wf_data = {
            "nodes": [{"id": "n1", "type": "start", "data": {}}],
            "edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
        }
        agent = _make_agent(workflow=wf_data)
        graph = agent._parse_embedded_workflow()
        assert graph is not None
        assert graph.edges[0].source_id == "n1"

    @pytest.mark.unit
    def test_invalid_data_returns_none(self):
        # Schema-invalid node payloads make parsing fail softly with None
        # rather than raising.
        agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
        graph = agent._parse_embedded_workflow()
        assert graph is None
class TestLoadWorkflowGraph:
    """_load_workflow_graph prefers embedded data, then the database."""

    @pytest.mark.unit
    def test_uses_embedded_when_available(self):
        agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
        agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
        assert agent._load_workflow_graph() == "parsed_graph"

    @pytest.mark.unit
    def test_uses_database_when_workflow_id(self):
        agent = _make_agent(workflow_id="wf1")
        agent._load_from_database = MagicMock(return_value="db_graph")
        assert agent._load_workflow_graph() == "db_graph"

    @pytest.mark.unit
    def test_returns_none_when_nothing(self):
        # Neither embedded data nor a workflow id: nothing to load.
        agent = _make_agent()
        assert agent._load_workflow_graph() is None
class TestLoadFromDatabase:
    """_load_from_database: id validation, owner resolution, graph
    versioning, the unversioned-nodes fallback, and error swallowing."""

    @pytest.mark.unit
    def test_invalid_workflow_id_returns_none(self):
        # "invalid!" is not a valid ObjectId string, so loading aborts early.
        agent = _make_agent(workflow_id="invalid!")
        result = agent._load_from_database()
        assert result is None

    @pytest.mark.unit
    def test_no_owner_returns_none(self):
        # No explicit owner and an empty decoded token: nothing to query by.
        agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
        agent.workflow_owner = None
        result = agent._load_from_database()
        assert result is None

    @pytest.mark.unit
    def test_uses_decoded_token_sub(self):
        # Without workflow_owner, the decoded token's "sub" identifies the
        # owner; the lookup runs but finds no workflow document.
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            decoded_token={"sub": "user1"},
        )
        agent.workflow_owner = None
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = None
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            result = agent._load_from_database()
            assert result is None  # workflow_doc not found

    @pytest.mark.unit
    def test_successful_load(self):
        # Workflow document plus one versioned node yields a graph.
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )
        mock_wf_coll = MagicMock()
        mock_wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "Test WF",
            "user": "owner1",
            "current_graph_version": 1,
        }
        mock_nodes_coll = MagicMock()
        mock_nodes_coll.find.return_value = [
            {"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
             "title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
        ]
        mock_edges_coll = MagicMock()
        mock_edges_coll.find.return_value = []

        # Route each collection name to its dedicated mock.
        def getitem(name):
            return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]

        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(side_effect=getitem)
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            result = agent._load_from_database()
            assert result is not None
            assert len(result.nodes) == 1

    @pytest.mark.unit
    def test_invalid_graph_version(self):
        # A non-integer current_graph_version does not abort the load.
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )
        mock_wf_coll = MagicMock()
        mock_wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "WF",
            "user": "owner1",
            "current_graph_version": "bad",
        }
        mock_nodes_coll = MagicMock()
        mock_nodes_coll.find.return_value = []
        mock_edges_coll = MagicMock()
        mock_edges_coll.find.return_value = []

        def getitem(name):
            return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]

        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(side_effect=getitem)
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            result = agent._load_from_database()
            assert result is not None  # Defaults to version 1

    @pytest.mark.unit
    def test_fallback_nodes_without_version(self):
        """When graph_version=1 finds no nodes, falls back to nodes without version field."""
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )
        mock_wf_coll = MagicMock()
        mock_wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "WF",
            "user": "owner1",
            "current_graph_version": 1,
        }
        # First nodes query (versioned) is empty; the second (fallback)
        # returns one node.
        call_count = [0]

        def nodes_find(query):
            call_count[0] += 1
            if call_count[0] == 1:
                return []  # No versioned nodes
            return [{"id": "n1", "workflow_id": "wf", "type": "start",
                     "title": "S", "position": {"x": 0, "y": 0}, "config": {}}]

        mock_nodes_coll = MagicMock()
        mock_nodes_coll.find.side_effect = nodes_find
        # Edge queries return nothing on either attempt.
        edge_call = [0]

        def edges_find(query):
            edge_call[0] += 1
            if edge_call[0] == 1:
                return []
            return []

        mock_edges_coll = MagicMock()
        mock_edges_coll.find.side_effect = edges_find

        def getitem(name):
            return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]

        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(side_effect=getitem)
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            result = agent._load_from_database()
            assert result is not None
            assert len(result.nodes) == 1

    @pytest.mark.unit
    def test_exception_returns_none(self):
        # Any database error is swallowed and reported as "no graph".
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
            MockMongo.get_client.side_effect = Exception("db error")
            result = agent._load_from_database()
            assert result is None
class TestGenInner:
    """_gen_inner: error event without a graph, engine-driven run with one."""

    @pytest.mark.unit
    def test_no_graph_yields_error(self):
        # With no loadable graph the generator emits an error event.
        agent = _make_agent()
        agent._load_workflow_graph = MagicMock(return_value=None)
        events = list(agent._gen_inner("query", None))
        assert any(e.get("type") == "error" for e in events)

    @pytest.mark.unit
    def test_successful_execution(self):
        # Engine events are forwarded verbatim and the run is persisted once.
        agent = _make_agent(workflow_id="wf1")
        mock_graph = MagicMock(spec=WorkflowGraph)
        agent._load_workflow_graph = MagicMock(return_value=mock_graph)
        agent._save_workflow_run = MagicMock()
        mock_engine = MagicMock()
        mock_engine.execute.return_value = iter([{"answer": "result"}])
        with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine):
            events = list(agent._gen_inner("query", None))
        assert len(events) == 1
        agent._save_workflow_run.assert_called_once_with("query")
class TestSaveWorkflowRun:
    """_save_workflow_run persists engine state to Mongo, best-effort."""

    @pytest.mark.unit
    def test_no_engine_returns_early(self):
        # Without an engine there is nothing to persist; must be a no-op.
        agent = _make_agent()
        agent._engine = None
        agent._save_workflow_run("query")  # Should not raise

    @pytest.mark.unit
    def test_saves_to_mongo(self):
        agent = _make_agent(workflow_id="wf1")
        mock_engine = MagicMock()
        mock_engine.state = {"query": "test"}
        mock_engine.execution_log = []
        mock_engine.get_execution_summary.return_value = []
        agent._engine = mock_engine
        mock_collection = MagicMock()
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            agent._save_workflow_run("query")
            # Exactly one run document is inserted.
            mock_collection.insert_one.assert_called_once()

    @pytest.mark.unit
    def test_exception_does_not_propagate(self):
        # Persistence failures are swallowed so they cannot break a response.
        agent = _make_agent(workflow_id="wf1")
        mock_engine = MagicMock()
        mock_engine.state = {}
        mock_engine.execution_log = []
        mock_engine.get_execution_summary.return_value = []
        agent._engine = mock_engine
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
            MockMongo.get_client.side_effect = Exception("db fail")
            agent._save_workflow_run("query")  # Should not raise
class TestDetermineRunStatus:
    """Run status is FAILED only when an execution-log entry failed."""

    @pytest.mark.unit
    def test_no_engine_returns_completed(self):
        agent = _make_agent()
        agent._engine = None
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_empty_log_returns_completed(self):
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = []
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_failed_log_returns_failed(self):
        # One failed entry among completed ones marks the whole run failed.
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = [
            {"status": "completed"},
            {"status": "failed"},
        ]
        assert agent._determine_run_status() == ExecutionStatus.FAILED

    @pytest.mark.unit
    def test_all_completed_returns_completed(self):
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = [{"status": "completed"}] * 2
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED
class TestSerializeState:
    """_serialize_state converts arbitrary state into JSON-friendly values."""

    @pytest.mark.unit
    def test_serializes_primitives(self):
        # JSON-native values pass through untouched.
        agent = _make_agent()
        payload = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
        assert agent._serialize_state(payload) == payload

    @pytest.mark.unit
    def test_serializes_nested_dict(self):
        agent = _make_agent()
        serialized = agent._serialize_state({"nested": {"key": "value"}})
        assert serialized["nested"]["key"] == "value"

    @pytest.mark.unit
    def test_serializes_list(self):
        agent = _make_agent()
        serialized = agent._serialize_state({"items": [1, 2, "three"]})
        assert serialized["items"] == [1, 2, "three"]

    @pytest.mark.unit
    def test_serializes_tuple(self):
        # Tuples come back as lists.
        agent = _make_agent()
        serialized = agent._serialize_state({"tup": (1, 2)})
        assert serialized["tup"] == [1, 2]

    @pytest.mark.unit
    def test_serializes_datetime(self):
        agent = _make_agent()
        stamp = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        serialized = agent._serialize_state({"time": stamp})
        assert "2025-01-01" in serialized["time"]

    @pytest.mark.unit
    def test_serializes_unknown_to_str(self):
        # Anything unrecognised is stringified rather than dropped.
        agent = _make_agent()
        serialized = agent._serialize_state({"obj": object()})
        assert isinstance(serialized["obj"], str)

View File

@@ -0,0 +1,573 @@
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
template context, source data, structured output parsing, get_execution_summary."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
NodeType,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
Workflow,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
def _make_graph(nodes, edges):
    """Wrap the given nodes/edges in a WorkflowGraph with a stub Workflow."""
    return WorkflowGraph(
        workflow=Workflow(name="Test", description="test workflow"),
        nodes=nodes,
        edges=edges,
    )
def _make_node(id, type, title="Node", config=None, position=None):
    """Construct a WorkflowNode with test-friendly defaults."""
    return WorkflowNode(
        id=id,
        workflow_id="wf1",
        type=type,
        title=title,
        position={"x": 0, "y": 0} if position is None else position,
        config={} if config is None else config,
    )
def _make_edge(id, source, target, source_handle=None, target_handle=None):
    """Construct a WorkflowEdge between two node ids, handles optional."""
    return WorkflowEdge(
        id=id,
        workflow_id="wf1",
        source=source,
        target=target,
        sourceHandle=source_handle,
        targetHandle=target_handle,
    )
def _make_agent():
    """Return a MagicMock standing in for the agent the engine drives."""
    stub = MagicMock()
    stub.configure_mock(
        chat_history=[],
        endpoint="https://api.example.com",
        llm_name="openai",
        model_id="gpt-4",
        api_key="key",
        decoded_token={"sub": "user1"},
        retrieved_docs=None,
    )
    return stub
class TestExecuteLoop:
    """WorkflowEngine.execute: start-node discovery, traversal errors,
    the step cap, and branches ending without an end node."""

    @pytest.mark.unit
    def test_no_start_node_yields_error(self):
        # An empty graph has no start node; expect an error event saying so.
        graph = _make_graph([], [])
        engine = WorkflowEngine(graph, _make_agent())
        events = list(engine.execute({}, "query"))
        assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events)

    @pytest.mark.unit
    def test_start_to_end(self):
        nodes = [
            _make_node("n1", NodeType.START, "Start"),
            _make_node("n2", NodeType.END, "End", config={"config": {}}),
        ]
        edges = [_make_edge("e1", "n1", "n2")]
        graph = _make_graph(nodes, edges)
        engine = WorkflowEngine(graph, _make_agent())
        events = list(engine.execute({}, "hello"))
        step_events = [e for e in events if e.get("type") == "workflow_step"]
        assert len(step_events) >= 2  # At least start + end

    @pytest.mark.unit
    def test_node_not_found_yields_error(self):
        # An edge pointing at a nonexistent node id is reported as an error.
        nodes = [_make_node("n1", NodeType.START)]
        edges = [_make_edge("e1", "n1", "nonexistent")]
        graph = _make_graph(nodes, edges)
        engine = WorkflowEngine(graph, _make_agent())
        events = list(engine.execute({}, "q"))
        assert any("not found" in e.get("error", "") for e in events)

    @pytest.mark.unit
    def test_node_execution_error_yields_error(self):
        # A state node with an unparseable expression fails its step.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
        ]
        edges = [_make_edge("e1", "n1", "n2")]
        graph = _make_graph(nodes, edges)
        engine = WorkflowEngine(graph, _make_agent())
        events = list(engine.execute({}, "q"))
        failed_events = [e for e in events if e.get("status") == "failed"]
        assert len(failed_events) >= 1

    @pytest.mark.unit
    def test_max_steps_limit(self):
        # Create a cycle: start -> note -> note (self-loop); the step cap
        # must terminate execution.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
        graph = _make_graph(nodes, edges)
        engine = WorkflowEngine(graph, _make_agent())
        engine.MAX_EXECUTION_STEPS = 5
        events = list(engine.execute({}, "q"))
        # Should not run forever
        assert len(events) > 0

    @pytest.mark.unit
    def test_branch_ends_without_end_node(self):
        # A dead-end branch (no outgoing edges) still terminates cleanly.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        edges = [_make_edge("e1", "n1", "n2")]  # n2 has no outgoing edges
        graph = _make_graph(nodes, edges)
        engine = WorkflowEngine(graph, _make_agent())
        events = list(engine.execute({}, "q"))
        assert len(events) > 0
class TestInitializeState:
    """_initialize_state seeds query, caller keys, and chat history."""

    @pytest.mark.unit
    def test_sets_query_and_history(self):
        stub_agent = _make_agent()
        stub_agent.chat_history = [{"prompt": "hi", "response": "hey"}]
        engine = WorkflowEngine(_make_graph([], []), stub_agent)
        engine._initialize_state({"custom": "value"}, "test query")
        assert engine.state["query"] == "test query"
        assert "custom" in engine.state
        assert engine.state["chat_history"] is not None
class TestGetNextNodeId:
    """Edge traversal, including condition-result handle matching."""

    @pytest.mark.unit
    def test_no_edges_returns_none(self):
        graph = _make_graph([_make_node("n1", NodeType.START)], [])
        engine = WorkflowEngine(graph, _make_agent())
        assert engine._get_next_node_id("n1") is None

    @pytest.mark.unit
    def test_returns_first_edge_target(self):
        graph = _make_graph(
            [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)],
            [_make_edge("e1", "n1", "n2")],
        )
        engine = WorkflowEngine(graph, _make_agent())
        assert engine._get_next_node_id("n1") == "n2"

    @pytest.mark.unit
    def test_condition_uses_matched_handle(self):
        graph = _make_graph(
            [
                _make_node("n1", NodeType.CONDITION),
                _make_node("n2", NodeType.END, "Yes End"),
                _make_node("n3", NodeType.END, "No End"),
            ],
            [
                _make_edge("e1", "n1", "n2", source_handle="yes"),
                _make_edge("e2", "n1", "n3", source_handle="no"),
            ],
        )
        engine = WorkflowEngine(graph, _make_agent())
        engine._condition_result = "no"
        assert engine._get_next_node_id("n1") == "n3"
        # The stored condition outcome is consumed by the lookup.
        assert engine._condition_result is None

    @pytest.mark.unit
    def test_condition_no_matching_handle_returns_none(self):
        graph = _make_graph(
            [_make_node("n1", NodeType.CONDITION)],
            [_make_edge("e1", "n1", "n2", source_handle="yes")],
        )
        engine = WorkflowEngine(graph, _make_agent())
        engine._condition_result = "nonexistent"
        assert engine._get_next_node_id("n1") is None
class TestExecuteStateNode:
    """State nodes evaluate configured operations into engine state."""

    @pytest.mark.unit
    def test_evaluates_operations(self):
        ops = [{"expression": "x + 1", "target_variable": "result"}]
        node = _make_node("n1", NodeType.STATE, config={"config": {"operations": ops}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 5}
        list(engine._execute_state_node(node))
        assert engine.state["result"] == 6

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        # An operation with an empty expression writes nothing.
        ops = [{"expression": "", "target_variable": "result"}]
        node = _make_node("n1", NodeType.STATE, config={"config": {"operations": ops}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}
        list(engine._execute_state_node(node))
        assert "result" not in engine.state
class TestExecuteConditionNode:
    """Condition nodes select the first truthy case's handle; anything
    unmatched (or unevaluable) falls through to the next case or 'else'."""

    @pytest.mark.unit
    def test_matches_first_true_case(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "x > 10", "source_handle": "high"},
                    {"expression": "x > 5", "source_handle": "medium"},
                ]
            }
        })
        graph = _make_graph([node], [])
        engine = WorkflowEngine(graph, _make_agent())
        engine.state = {"x": 7}
        list(engine._execute_condition_node(node))
        # x=7 fails the first case and matches the second.
        assert engine._condition_result == "medium"

    @pytest.mark.unit
    def test_falls_through_to_else(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "x > 100", "source_handle": "high"},
                ]
            }
        })
        graph = _make_graph([node], [])
        engine = WorkflowEngine(graph, _make_agent())
        engine.state = {"x": 1}
        list(engine._execute_condition_node(node))
        assert engine._condition_result == "else"

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        # A whitespace-only expression is ignored, not treated as truthy.
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": " ", "source_handle": "a"},
                    {"expression": "true", "source_handle": "b"},
                ]
            }
        })
        graph = _make_graph([node], [])
        engine = WorkflowEngine(graph, _make_agent())
        engine.state = {}
        list(engine._execute_condition_node(node))
        assert engine._condition_result == "b"

    @pytest.mark.unit
    def test_cel_error_continues(self):
        # An unparseable expression does not abort evaluation of later cases.
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "bad!!!", "source_handle": "a"},
                    {"expression": "true", "source_handle": "b"},
                ]
            }
        })
        graph = _make_graph([node], [])
        engine = WorkflowEngine(graph, _make_agent())
        engine.state = {}
        list(engine._execute_condition_node(node))
        assert engine._condition_result == "b"
class TestExecuteEndNode:
    """End nodes emit an answer event only when an output template is set."""

    @pytest.mark.unit
    def test_with_output_template(self):
        node = _make_node("n1", NodeType.END, config={
            "config": {"output_template": "Result: {{ query }}"}
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"query": "hello"}
        engine._format_template = MagicMock(return_value="Result: hello")
        emitted = list(engine._execute_end_node(node))
        assert len(emitted) == 1
        assert emitted[0]["answer"] == "Result: hello"

    @pytest.mark.unit
    def test_without_output_template(self):
        # No template configured: nothing is emitted.
        node = _make_node("n1", NodeType.END, config={"config": {}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        assert list(engine._execute_end_node(node)) == []
class TestParseStructuredOutput:
    """_parse_structured_output returns (ok, data) and never raises."""

    @staticmethod
    def _engine():
        # Minimal engine; parsing does not depend on graph contents.
        return WorkflowEngine(_make_graph([], []), _make_agent())

    @pytest.mark.unit
    def test_valid_json(self):
        ok, payload = self._engine()._parse_structured_output('{"key": "value"}')
        assert ok is True
        assert payload == {"key": "value"}

    @pytest.mark.unit
    def test_invalid_json(self):
        ok, payload = self._engine()._parse_structured_output("not json")
        assert ok is False
        assert payload is None

    @pytest.mark.unit
    def test_empty_string(self):
        ok, payload = self._engine()._parse_structured_output("")
        assert ok is False
        assert payload is None

    @pytest.mark.unit
    def test_whitespace_only(self):
        ok, payload = self._engine()._parse_structured_output("   ")
        assert ok is False
        assert payload is None
class TestNormalizeNodeJsonSchema:
    """Tests for WorkflowEngine._normalize_node_json_schema."""

    @pytest.mark.unit
    def test_none_returns_none(self):
        """A missing schema normalizes to None."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        assert wf._normalize_node_json_schema(None, "Node") is None

    @pytest.mark.unit
    def test_valid_schema(self):
        """A well-formed object schema normalizes to a non-None value."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        payload = {"type": "object", "properties": {"name": {"type": "string"}}}
        assert wf._normalize_node_json_schema(payload, "Node") is not None

    @pytest.mark.unit
    def test_invalid_schema_raises(self):
        """Validation failures surface as ValueError with a descriptive message."""
        from application.core.json_schema_utils import JsonSchemaValidationError

        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        target = (
            "application.agents.workflows.workflow_engine."
            "normalize_json_schema_payload"
        )
        with patch(target, side_effect=JsonSchemaValidationError("bad schema")):
            with pytest.raises(ValueError, match="Invalid JSON schema"):
                wf._normalize_node_json_schema({"bad": True}, "TestNode")
class TestValidateStructuredOutput:
    """Tests for WorkflowEngine._validate_structured_output."""

    @pytest.mark.unit
    def test_valid_output_passes(self):
        """Output matching the schema does not raise."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        wf._validate_structured_output(schema, {"name": "Alice"})  # Should not raise

    @pytest.mark.unit
    def test_invalid_output_raises(self):
        """Output missing a required key raises ValueError."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        with pytest.raises(ValueError, match="did not match schema"):
            wf._validate_structured_output(schema, {})

    @pytest.mark.unit
    def test_no_jsonschema_module(self):
        """Validation is a no-op when the jsonschema module is unavailable."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        with patch("application.agents.workflows.workflow_engine.jsonschema", None):
            wf._validate_structured_output({"type": "object"}, {})  # Should not raise
class TestFormatTemplate:
    """Tests for WorkflowEngine._format_template."""

    @pytest.mark.unit
    def test_renders_template(self):
        """A renderable template is passed through the template engine."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        wf.state = {"query": "hello"}
        wf._build_template_context = MagicMock(return_value={"query": "hello"})
        wf._template_engine = MagicMock()
        wf._template_engine.render.return_value = "hello world"
        assert wf._format_template("{{ query }} world") == "hello world"

    @pytest.mark.unit
    def test_render_error_returns_raw(self):
        """On TemplateRenderError the raw template string is returned unchanged."""
        from application.templates.template_engine import TemplateRenderError

        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        wf._build_template_context = MagicMock(return_value={})
        wf._template_engine = MagicMock()
        wf._template_engine.render.side_effect = TemplateRenderError("fail")
        assert wf._format_template("{{ bad }}") == "{{ bad }}"
class TestBuildTemplateContext:
    """Tests for WorkflowEngine._build_template_context."""

    @staticmethod
    def _engine_with_state(state):
        # All tests use an agent with no retrieved docs; only state varies.
        agent = _make_agent()
        agent.retrieved_docs = None
        wf = WorkflowEngine(_make_graph([], []), agent)
        wf.state = state
        return wf

    @pytest.mark.unit
    def test_includes_state_variables(self):
        """State variables appear in the context alongside the agent namespace."""
        ctx = self._engine_with_state(
            {"query": "hello", "custom_var": "value"}
        )._build_template_context()
        assert ctx["agent"]["query"] == "hello"
        assert "custom_var" in ctx

    @pytest.mark.unit
    def test_reserved_namespace_gets_prefixed(self):
        """A state key that collides with a reserved name gets an agent_ prefix."""
        ctx = self._engine_with_state(
            {"source": "my_source_val"}
        )._build_template_context()
        assert ctx.get("agent_source") == "my_source_val"

    @pytest.mark.unit
    def test_passthrough_data(self):
        """Passthrough data is exposed under either key form."""
        ctx = self._engine_with_state(
            {"passthrough": {"key": "val"}}
        )._build_template_context()
        assert "passthrough" in ctx or "agent_passthrough" in ctx

    @pytest.mark.unit
    def test_tools_data(self):
        """Tool results in state still leave the agent namespace present."""
        ctx = self._engine_with_state(
            {"tools": {"tool1": "result"}}
        )._build_template_context()
        assert "agent" in ctx
class TestGetSourceTemplateData:
    """Tests for WorkflowEngine._get_source_template_data."""

    @staticmethod
    def _data_for(retrieved_docs):
        # Build an engine around an agent holding the given retrieved docs
        # and return the (docs, together) pair the engine derives from them.
        agent = _make_agent()
        agent.retrieved_docs = retrieved_docs
        wf = WorkflowEngine(_make_graph([], []), agent)
        return wf._get_source_template_data()

    @pytest.mark.unit
    def test_no_docs_returns_none(self):
        """No retrieved docs yields (None, None)."""
        assert self._data_for(None) == (None, None)

    @pytest.mark.unit
    def test_empty_docs_returns_none(self):
        """An empty doc list yields no docs data."""
        docs, _ = self._data_for([])
        assert docs is None

    @pytest.mark.unit
    def test_docs_with_filename(self):
        """Docs with a filename surface both the name and the text."""
        docs, together = self._data_for(
            [{"text": "content", "filename": "doc.txt"}]
        )
        assert docs is not None
        assert "doc.txt" in together
        assert "content" in together

    @pytest.mark.unit
    def test_docs_without_filename(self):
        """A doc with only text contributes just its text."""
        _, together = self._data_for([{"text": "content only"}])
        assert together == "content only"

    @pytest.mark.unit
    def test_skips_non_dict_docs(self):
        """Entries that are not dicts are ignored."""
        _, together = self._data_for(["not a dict", {"text": "ok"}])
        assert together == "ok"

    @pytest.mark.unit
    def test_skips_non_string_text(self):
        """A non-string text field is ignored."""
        _, together = self._data_for([{"text": 123}])
        assert together is None

    @pytest.mark.unit
    def test_doc_with_title_fallback(self):
        """When filename is absent the title is used as the label."""
        _, together = self._data_for([{"text": "content", "title": "doc_title"}])
        assert "doc_title" in together

    @pytest.mark.unit
    def test_doc_with_source_fallback(self):
        """When filename and title are absent the source field is used."""
        _, together = self._data_for([{"text": "content", "source": "src"}])
        assert "src" in together
class TestGetExecutionSummary:
    """Tests for WorkflowEngine.get_execution_summary."""

    @pytest.mark.unit
    def test_returns_log_entries(self):
        """Each execution-log dict becomes one summary entry."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        ts = datetime.now(timezone.utc)
        wf.execution_log = [
            {
                "node_id": "n1",
                "node_type": "start",
                "status": "completed",
                "started_at": ts,
                "completed_at": ts,
                "error": None,
                "state_snapshot": {},
            }
        ]
        entries = wf.get_execution_summary()
        assert [(e.node_id, e.status) for e in entries] == [
            ("n1", ExecutionStatus.COMPLETED)
        ]

    @pytest.mark.unit
    def test_empty_log(self):
        """An empty log produces an empty summary."""
        wf = WorkflowEngine(_make_graph([], []), _make_agent())
        assert wf.get_execution_summary() == []

View File

@@ -0,0 +1,235 @@
"""Tests for application/api/answer/routes/answer.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.fixture
def mock_stream_processor():
    """Create a mock StreamProcessor."""
    target = "application.api.answer.routes.answer.StreamProcessor"
    with patch(target) as processor_cls:
        stub = MagicMock()
        stub.decoded_token = {"sub": "test_user"}
        stub.conversation_id = str(ObjectId())
        stub.agent_config = {}
        stub.agent_id = str(ObjectId())
        stub.is_shared_usage = False
        stub.shared_token = None
        stub.model_id = "gpt-4"
        stub.build_agent.return_value = MagicMock()
        processor_cls.return_value = stub
        yield stub
@pytest.fixture
def answer_client(mock_mongo_db, flask_app):
    """Create a test client with the answer route registered."""
    from flask_restx import Api
    from application.api.answer.routes.answer import answer_ns

    flask_app.config["TESTING"] = True
    Api(flask_app).add_namespace(answer_ns)
    return flask_app.test_client()
@pytest.mark.unit
class TestAnswerResourcePost:
    """Tests for POST /api/answer.

    Each test patches AnswerResource's collaborator methods
    (validate_request / check_usage / complete_stream /
    process_response_stream) so only the route's orchestration
    and response shaping is exercised.
    """

    def test_missing_question_returns_400(self, answer_client, mock_stream_processor):
        """A request body without 'question' is rejected with HTTP 400."""
        resp = answer_client.post(
            "/api/answer",
            data=json.dumps({}),
            content_type="application/json",
        )
        assert resp.status_code == 400

    def test_successful_answer(self, answer_client, mock_stream_processor):
        """Happy path: the processed stream becomes a JSON answer payload."""
        conv_id = str(ObjectId())
        with patch.object(
            mock_stream_processor.build_agent.return_value,
            "gen",
            return_value=iter([]),
        ):
            with patch(
                "application.api.answer.routes.answer.AnswerResource.validate_request",
                return_value=None,
            ), patch(
                "application.api.answer.routes.answer.AnswerResource.check_usage",
                return_value=None,
            ), patch(
                "application.api.answer.routes.answer.AnswerResource.complete_stream",
                return_value=iter(
                    [
                        f'data: {json.dumps({"type": "answer", "answer": "Hello"})}\n\n',
                        f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
                        f'data: {json.dumps({"type": "end"})}\n\n',
                    ]
                ),
            ), patch(
                "application.api.answer.routes.answer.AnswerResource.process_response_stream",
                return_value=(conv_id, "Hello", [], [], "", None),
            ):
                resp = answer_client.post(
                    "/api/answer",
                    data=json.dumps({"question": "What is Python?"}),
                    content_type="application/json",
                )
                assert resp.status_code == 200
                data = resp.get_json()
                assert data["answer"] == "Hello"
                assert data["conversation_id"] == conv_id

    def test_unauthorized_returns_401(self, answer_client, mock_stream_processor):
        """A missing decoded token yields a 401 Unauthorized JSON error."""
        mock_stream_processor.decoded_token = None
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ):
            resp = answer_client.post(
                "/api/answer",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 401
            assert resp.get_json()["error"] == "Unauthorized"

    def test_usage_exceeded_returns_error(self, answer_client, mock_stream_processor):
        """check_usage returning an (error, status) pair short-circuits the route."""
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.check_usage",
        ) as mock_check:
            # flask_app_context is the module-level helper defined below.
            with flask_app_context(answer_client):
                mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
                resp = answer_client.post(
                    "/api/answer",
                    data=json.dumps({"question": "test"}),
                    content_type="application/json",
                )
                assert resp.status_code == 429

    def test_stream_error_returns_400(self, answer_client, mock_stream_processor):
        """An error reported by process_response_stream becomes an HTTP 400."""
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.complete_stream",
            return_value=iter([]),
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.process_response_stream",
            return_value=(None, None, None, None, None, "Stream error"),
        ):
            resp = answer_client.post(
                "/api/answer",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 400
            assert resp.get_json()["error"] == "Stream error"

    def test_exception_returns_500(self, answer_client, mock_stream_processor):
        """An unexpected exception inside streaming is surfaced as HTTP 500."""
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.complete_stream",
            side_effect=RuntimeError("unexpected"),
        ):
            resp = answer_client.post(
                "/api/answer",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 500
            assert "error" in resp.get_json()

    def test_structured_info_merged_into_result(
        self, answer_client, mock_stream_processor
    ):
        """A 7th structured-info tuple element is merged into the JSON result."""
        conv_id = str(ObjectId())
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.complete_stream",
            return_value=iter([]),
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.process_response_stream",
            return_value=(
                conv_id,
                '{"key": "val"}',
                [],
                [],
                "",
                None,
                {"structured": True, "schema": {"type": "object"}},
            ),
        ):
            resp = answer_client.post(
                "/api/answer",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 200
            data = resp.get_json()
            assert data["structured"] is True
            assert data["schema"] == {"type": "object"}

    def test_result_contains_all_expected_fields(
        self, answer_client, mock_stream_processor
    ):
        """Every element of the processed tuple maps to a response field."""
        conv_id = str(ObjectId())
        with patch(
            "application.api.answer.routes.answer.AnswerResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.complete_stream",
            return_value=iter([]),
        ), patch(
            "application.api.answer.routes.answer.AnswerResource.process_response_stream",
            return_value=(
                conv_id,
                "answer text",
                [{"title": "src"}],
                [{"tool": "t"}],
                "thinking...",
                None,
            ),
        ):
            resp = answer_client.post(
                "/api/answer",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            data = resp.get_json()
            assert data["conversation_id"] == conv_id
            assert data["answer"] == "answer text"
            assert data["sources"] == [{"title": "src"}]
            assert data["tool_calls"] == [{"tool": "t"}]
            assert data["thought"] == "thinking..."
def flask_app_context(client):
    """Helper to get app context from test client."""
    app = client.application
    return app.app_context()

View File

@@ -0,0 +1,195 @@
"""Tests for application/api/answer/routes/stream.py"""
import json
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.fixture
def mock_stream_processor():
    """Create a mock StreamProcessor for stream tests."""
    target = "application.api.answer.routes.stream.StreamProcessor"
    with patch(target) as processor_cls:
        stub = MagicMock()
        stub.decoded_token = {"sub": "test_user"}
        stub.conversation_id = str(ObjectId())
        stub.agent_config = {}
        stub.agent_id = str(ObjectId())
        stub.is_shared_usage = False
        stub.shared_token = None
        stub.model_id = "gpt-4"
        stub.build_agent.return_value = MagicMock()
        processor_cls.return_value = stub
        yield stub
@pytest.fixture
def stream_client(mock_mongo_db, flask_app):
    """Create a test client with the stream route registered."""
    from flask_restx import Api
    from application.api.answer.routes.stream import answer_ns

    flask_app.config["TESTING"] = True
    Api(flask_app).add_namespace(answer_ns)
    return flask_app.test_client()
@pytest.mark.unit
class TestStreamResourcePost:
    """Tests for POST /stream (server-sent-events endpoint).

    StreamResource collaborators are patched per test so only the
    route's SSE framing and error handling is exercised.
    """

    def test_missing_question_returns_400(self, stream_client, mock_stream_processor):
        """A request body without 'question' is rejected with HTTP 400."""
        resp = stream_client.post(
            "/stream",
            data=json.dumps({}),
            content_type="application/json",
        )
        assert resp.status_code == 400

    def test_successful_stream(self, stream_client, mock_stream_processor):
        """Happy path: events from complete_stream are forwarded as SSE data."""
        def fake_stream(*args, **kwargs):
            yield f'data: {json.dumps({"type": "answer", "answer": "Hi"})}\n\n'
            yield f'data: {json.dumps({"type": "end"})}\n\n'
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.stream.StreamResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.stream.StreamResource.complete_stream",
            side_effect=fake_stream,
        ):
            resp = stream_client.post(
                "/stream",
                data=json.dumps({"question": "What is Python?"}),
                content_type="application/json",
            )
            assert resp.status_code == 200
            assert "text/event-stream" in resp.content_type
            data = resp.get_data(as_text=True)
            assert '"type": "answer"' in data
            assert '"answer": "Hi"' in data

    def test_unauthorized_returns_401_stream(
        self, stream_client, mock_stream_processor
    ):
        """A missing decoded token yields 401 with an SSE-formatted error."""
        mock_stream_processor.decoded_token = None
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ):
            resp = stream_client.post(
                "/stream",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 401
            assert "text/event-stream" in resp.content_type
            data = resp.get_data(as_text=True)
            assert "Unauthorized" in data

    def test_usage_exceeded_returns_error(
        self, stream_client, mock_stream_processor
    ):
        """check_usage returning an (error, status) pair short-circuits the route."""
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.stream.StreamResource.check_usage",
        ) as mock_check:
            mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
            resp = stream_client.post(
                "/stream",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 429

    def test_value_error_returns_400_stream(
        self, stream_client, mock_stream_processor
    ):
        """A ValueError from build_agent maps to a 400 'Malformed request body' SSE."""
        mock_stream_processor.build_agent.side_effect = ValueError("bad data")
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ):
            resp = stream_client.post(
                "/stream",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 400
            assert "text/event-stream" in resp.content_type
            data = resp.get_data(as_text=True)
            assert "Malformed request body" in data

    def test_general_exception_returns_400_stream(
        self, stream_client, mock_stream_processor
    ):
        """Any other exception maps to a 400 'Unknown error occurred' SSE."""
        mock_stream_processor.build_agent.side_effect = RuntimeError("crash")
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ):
            resp = stream_client.post(
                "/stream",
                data=json.dumps({"question": "test"}),
                content_type="application/json",
            )
            assert resp.status_code == 400
            assert "text/event-stream" in resp.content_type
            data = resp.get_data(as_text=True)
            assert "Unknown error occurred" in data

    def test_index_in_data_requires_conversation_id(
        self, stream_client, mock_stream_processor
    ):
        """When 'index' is present, validate_request is called with require_conversation_id=True."""
        resp = stream_client.post(
            "/stream",
            data=json.dumps({"question": "test", "index": 0}),
            content_type="application/json",
        )
        # Should get 400 since conversation_id is missing
        assert resp.status_code == 400

    def test_stream_passes_attachments_and_index(
        self, stream_client, mock_stream_processor
    ):
        """Verify attachments and index params are forwarded to complete_stream."""
        def fake_stream(*args, **kwargs):
            yield f'data: {json.dumps({"type": "end"})}\n\n'
        conv_id = str(ObjectId())
        with patch(
            "application.api.answer.routes.stream.StreamResource.validate_request",
            return_value=None,
        ), patch(
            "application.api.answer.routes.stream.StreamResource.check_usage",
            return_value=None,
        ), patch(
            "application.api.answer.routes.stream.StreamResource.complete_stream",
            side_effect=fake_stream,
        ) as mock_complete:
            resp = stream_client.post(
                "/stream",
                data=json.dumps(
                    {
                        "question": "test",
                        "conversation_id": conv_id,
                        "index": 3,
                        "attachments": ["att1", "att2"],
                    }
                ),
                content_type="application/json",
            )
            assert resp.status_code == 200
            call_kwargs = mock_complete.call_args
            assert call_kwargs.kwargs.get("index") == 3
            assert call_kwargs.kwargs.get("attachment_ids") == ["att1", "att2"]

View File

@@ -0,0 +1,303 @@
"""Tests for application/api/answer/services/compression/message_builder.py"""
import pytest
from application.api.answer.services.compression.message_builder import MessageBuilder
@pytest.mark.unit
class TestBuildFromCompressedContext:
    """Tests for MessageBuilder.build_from_compressed_context.

    Asserts on exact message counts and ordering; the inline
    count comments document the expected arithmetic.
    """

    def test_no_compression_returns_system_only(self):
        """No summary and no history yields just the system message."""
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="You are helpful.",
            compressed_summary=None,
            recent_queries=[],
        )
        assert len(messages) == 1
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "You are helpful."

    def test_with_recent_queries_no_compression(self):
        """Each recent query expands to a user/assistant pair in order."""
        queries = [
            {"prompt": "Hello", "response": "Hi there!"},
            {"prompt": "How are you?", "response": "I'm fine."},
        ]
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System prompt",
            compressed_summary=None,
            recent_queries=queries,
        )
        # system + 2 * (user + assistant) = 5
        assert len(messages) == 5
        assert messages[1] == {"role": "user", "content": "Hello"}
        assert messages[2] == {"role": "assistant", "content": "Hi there!"}
        assert messages[3] == {"role": "user", "content": "How are you?"}
        assert messages[4] == {"role": "assistant", "content": "I'm fine."}

    def test_with_compressed_summary_appended_to_system(self):
        """The compressed summary is folded into the system message content."""
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="You are helpful.",
            compressed_summary="Previous: user asked about Python.",
            recent_queries=[{"prompt": "More?", "response": "Sure."}],
        )
        system_content = messages[0]["content"]
        assert "This session is being continued" in system_content
        assert "Previous: user asked about Python." in system_content

    def test_mid_execution_context_type(self):
        """context_type='mid_execution' uses the mid-execution preamble."""
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System",
            compressed_summary="Summary here",
            recent_queries=[{"prompt": "q", "response": "r"}],
            context_type="mid_execution",
        )
        system_content = messages[0]["content"]
        assert "Context window limit reached" in system_content

    def test_include_tool_calls(self):
        """With include_tool_calls=True, tool calls become assistant/tool message pairs."""
        queries = [
            {
                "prompt": "Search for X",
                "response": "Found X",
                "tool_calls": [
                    {
                        "call_id": "call-1",
                        "action_name": "search",
                        "arguments": {"q": "X"},
                        "result": "X found",
                    }
                ],
            }
        ]
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System",
            compressed_summary=None,
            recent_queries=queries,
            include_tool_calls=True,
        )
        # system + user + assistant + tool_call_assistant + tool_response = 5
        assert len(messages) == 5
        assert messages[3]["role"] == "assistant"
        assert "function_call" in messages[3]["content"][0]
        assert messages[4]["role"] == "tool"
        assert "function_response" in messages[4]["content"][0]

    def test_tool_calls_not_included_by_default(self):
        """With include_tool_calls=False, tool call data is dropped."""
        queries = [
            {
                "prompt": "Search",
                "response": "Found",
                "tool_calls": [
                    {
                        "call_id": "c1",
                        "action_name": "search",
                        "arguments": {},
                        "result": "ok",
                    }
                ],
            }
        ]
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System",
            compressed_summary=None,
            recent_queries=queries,
            include_tool_calls=False,
        )
        # system + user + assistant = 3 (no tool messages)
        assert len(messages) == 3

    def test_tool_call_without_call_id_generates_uuid(self):
        """A tool call missing call_id still gets a non-empty generated id."""
        queries = [
            {
                "prompt": "q",
                "response": "r",
                "tool_calls": [
                    {
                        "action_name": "act",
                        "arguments": {},
                        "result": "res",
                    }
                ],
            }
        ]
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="S",
            compressed_summary=None,
            recent_queries=queries,
            include_tool_calls=True,
        )
        tool_msg = messages[3]["content"][0]
        call_id = tool_msg["function_call"]["call_id"]
        assert call_id is not None
        assert len(call_id) > 0

    def test_continuation_message_when_no_recent_queries_but_has_summary(self):
        """A summary with no recent queries gets a 'continue' user message."""
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System",
            compressed_summary="Everything was compressed",
            recent_queries=[],
        )
        # system + continuation user message = 2
        assert len(messages) == 2
        assert messages[1]["role"] == "user"
        assert "continue" in messages[1]["content"].lower()

    def test_no_continuation_when_no_summary(self):
        """Without a summary there is no continuation message."""
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="System",
            compressed_summary=None,
            recent_queries=[],
        )
        assert len(messages) == 1

    def test_queries_without_prompt_or_response_skipped(self):
        """Queries missing both prompt and response contribute no messages."""
        queries = [
            {"other_field": "value"},
            {"prompt": "real", "response": "answer"},
        ]
        messages = MessageBuilder.build_from_compressed_context(
            system_prompt="S",
            compressed_summary=None,
            recent_queries=queries,
        )
        # system + 1 valid query (user + assistant) = 3
        assert len(messages) == 3
@pytest.mark.unit
class TestAppendCompressionContext:
    """Tests for MessageBuilder._append_compression_context."""

    def test_pre_request_context(self):
        """'pre_request' appends the session-continuation preamble after the prompt."""
        combined = MessageBuilder._append_compression_context(
            "Original prompt", "Summary text", "pre_request"
        )
        assert combined.startswith("Original prompt")
        assert "This session is being continued" in combined
        assert "Summary text" in combined

    def test_mid_execution_context(self):
        """'mid_execution' appends the context-limit preamble instead."""
        combined = MessageBuilder._append_compression_context(
            "Original prompt", "Summary text", "mid_execution"
        )
        assert "Context window limit reached" in combined
        assert "Summary text" in combined

    def test_removes_existing_compression_context(self):
        """A previously appended pre_request context is replaced, not duplicated."""
        stale = "Original prompt\n\n---\n\nThis session is being continued from old"
        combined = MessageBuilder._append_compression_context(
            stale, "New summary", "pre_request"
        )
        # The marker must appear exactly once after re-appending.
        assert combined.count("This session is being continued") == 1
        assert "New summary" in combined

    def test_removes_mid_execution_context(self):
        """A previously appended mid_execution context is replaced, not duplicated."""
        stale = "Original\n\n---\n\nContext window limit reached during execution. Old."
        combined = MessageBuilder._append_compression_context(
            stale, "New", "mid_execution"
        )
        assert combined.count("Context window limit reached") == 1
@pytest.mark.unit
class TestRebuildMessagesAfterCompression:
    """Tests for MessageBuilder.rebuild_messages_after_compression."""

    def test_basic_rebuild(self):
        """Old turns are replaced by the summary plus the recent queries."""
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "old message"},
            {"role": "assistant", "content": "old reply"},
        ]
        recent = [{"prompt": "new q", "response": "new r"}]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary="Everything was compressed.",
            recent_queries=recent,
        )
        assert result is not None
        # system + user + assistant = 3
        assert len(result) == 3
        assert "Context window limit reached" in result[0]["content"]
        assert result[1] == {"role": "user", "content": "new q"}
        assert result[2] == {"role": "assistant", "content": "new r"}

    def test_returns_none_without_system_message(self):
        """Rebuilding requires a system message; otherwise None is returned."""
        messages = [
            {"role": "user", "content": "hello"},
        ]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary="summary",
            recent_queries=[],
        )
        assert result is None

    def test_no_summary_keeps_system_unchanged(self):
        """With no summary the system message content is left untouched."""
        messages = [{"role": "system", "content": "Be helpful."}]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary=None,
            recent_queries=[],
        )
        assert result is not None
        assert result[0]["content"] == "Be helpful."

    def test_include_tool_calls_in_rebuild(self):
        """Tool calls expand into assistant/tool messages when requested."""
        messages = [{"role": "system", "content": "S"}]
        recent = [
            {
                "prompt": "q",
                "response": "r",
                "tool_calls": [
                    {
                        "call_id": "c1",
                        "action_name": "act",
                        "arguments": {"a": 1},
                        "result": "done",
                    }
                ],
            }
        ]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary="s",
            recent_queries=recent,
            include_tool_calls=True,
        )
        # system + user + assistant + tool_call + tool_response = 5
        assert len(result) == 5

    def test_continuation_added_when_no_recent_queries(self):
        """A summary with no recent queries gets a 'continue' user message."""
        messages = [{"role": "system", "content": "S"}]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary="All compressed",
            recent_queries=[],
        )
        assert len(result) == 2
        assert result[1]["role"] == "user"
        assert "continue" in result[1]["content"].lower()

    def test_include_current_execution_preserves_extra_messages(self):
        """include_current_execution keeps messages beyond the recent queries."""
        messages = [
            {"role": "system", "content": "S"},
            {"role": "user", "content": "q1"},
            {"role": "assistant", "content": "r1"},
            {"role": "user", "content": "current execution msg"},
        ]
        recent = [{"prompt": "q1", "response": "r1"}]
        result = MessageBuilder.rebuild_messages_after_compression(
            messages=messages,
            compressed_summary="summary",
            recent_queries=recent,
            include_current_execution=True,
        )
        assert result is not None
        # Should include the current execution message
        contents = [m.get("content") for m in result]
        assert "current execution msg" in contents

View File

@@ -0,0 +1,447 @@
"""Tests for application/api/answer/services/compression/orchestrator.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.types import (
CompressionMetadata,
CompressionResult,
)
@pytest.fixture
def mock_conversation_service():
    """A MagicMock standing in for the conversation service."""
    return MagicMock()
@pytest.fixture
def mock_threshold_checker():
    """A MagicMock standing in for the compression threshold checker."""
    return MagicMock()
@pytest.fixture
def orchestrator(mock_conversation_service, mock_threshold_checker):
    """A CompressionOrchestrator wired with mocked collaborators."""
    return CompressionOrchestrator(
        threshold_checker=mock_threshold_checker,
        conversation_service=mock_conversation_service,
    )
@pytest.fixture
def sample_conversation():
    """A three-query conversation with empty compression metadata."""
    history = [{"prompt": f"q{i}", "response": f"r{i}"} for i in range(3)]
    return {
        "queries": history,
        "compression_metadata": {},
        "agent_id": "agent-1",
    }
@pytest.fixture
def decoded_token():
    """A decoded auth token payload for a fixed test user."""
    return {"sub": "user123"}
@pytest.mark.unit
class TestCompressIfNeeded:
def test_conversation_not_found_returns_failure(
self, orchestrator, mock_conversation_service
):
mock_conversation_service.get_conversation.return_value = None
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is False
assert "not found" in result.error
def test_no_compression_needed(
self,
orchestrator,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
mock_threshold_checker.should_compress.return_value = False
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is True
assert result.compression_performed is False
assert len(result.recent_queries) == 3
def test_compression_performed_successfully(
self,
orchestrator,
mock_conversation_service,
mock_threshold_checker,
sample_conversation,
decoded_token,
):
mock_conversation_service.get_conversation.return_value = sample_conversation
mock_threshold_checker.should_compress.return_value = True
mock_metadata = MagicMock(spec=CompressionMetadata)
mock_metadata.compression_ratio = 5.0
mock_metadata.original_token_count = 1000
mock_metadata.compressed_token_count = 200
mock_metadata.to_dict.return_value = {"query_index": 2}
with patch.object(
orchestrator, "_perform_compression"
) as mock_perform:
mock_perform.return_value = CompressionResult.success_with_compression(
"compressed summary",
[{"prompt": "q2", "response": "r2"}],
mock_metadata,
)
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token=decoded_token,
)
assert result.success is True
assert result.compression_performed is True
assert result.compressed_summary == "compressed summary"
mock_perform.assert_called_once()
def test_exception_returns_failure(
self,
orchestrator,
mock_conversation_service,
):
mock_conversation_service.get_conversation.side_effect = RuntimeError("DB down")
result = orchestrator.compress_if_needed(
conversation_id="conv1",
user_id="user1",
model_id="gpt-4",
decoded_token={"sub": "user1"},
)
assert result.success is False
assert "DB down" in result.error
def test_custom_query_tokens(
    self,
    orchestrator,
    mock_conversation_service,
    mock_threshold_checker,
    sample_conversation,
):
    """An explicit current_query_tokens value is forwarded to the checker."""
    mock_threshold_checker.should_compress.return_value = False
    mock_conversation_service.get_conversation.return_value = sample_conversation

    orchestrator.compress_if_needed(
        conversation_id="conv1",
        user_id="user1",
        model_id="gpt-4",
        decoded_token={"sub": "user1"},
        current_query_tokens=1000,
    )

    mock_threshold_checker.should_compress.assert_called_once_with(
        sample_conversation, "gpt-4", 1000
    )
@pytest.mark.unit
class TestPerformCompression:
    """Unit tests for CompressionOrchestrator._perform_compression."""

    # NOTE: stacked @patch decorators are applied bottom-up, so the mock
    # parameters below appear in reverse decorator order.
    @patch(
        "application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
    )
    @patch(
        "application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
    )
    @patch("application.api.answer.services.compression.orchestrator.LLMCreator")
    @patch("application.api.answer.services.compression.orchestrator.CompressionService")
    @patch("application.api.answer.services.compression.orchestrator.settings")
    def test_successful_compression(
        self,
        mock_settings,
        MockCompressionService,
        MockLLMCreator,
        mock_get_api_key,
        mock_get_provider,
        mock_conversation_service,
        mock_threshold_checker,
        sample_conversation,
        decoded_token,
    ):
        """Happy path: provider lookup, LLM creation and compression succeed."""
        mock_settings.COMPRESSION_MODEL_OVERRIDE = None
        mock_get_provider.return_value = "openai"
        mock_get_api_key.return_value = "sk-test"
        MockLLMCreator.create_llm.return_value = MagicMock()
        # Metadata returned by the compression service instance.
        mock_metadata = MagicMock(spec=CompressionMetadata)
        mock_metadata.compression_ratio = 5.0
        mock_metadata.original_token_count = 500
        mock_metadata.compressed_token_count = 100
        mock_svc_instance = MagicMock()
        mock_svc_instance.compress_and_save.return_value = mock_metadata
        mock_svc_instance.get_compressed_context.return_value = (
            "compressed text",
            [{"prompt": "q2", "response": "r2"}],
        )
        MockCompressionService.return_value = mock_svc_instance
        # After compression, reload conversation
        mock_conversation_service.get_conversation.return_value = sample_conversation
        orch = CompressionOrchestrator(
            conversation_service=mock_conversation_service,
            threshold_checker=mock_threshold_checker,
        )
        result = orch._perform_compression(
            "conv1", sample_conversation, "gpt-4", decoded_token
        )
        assert result.success is True
        assert result.compression_performed is True
        assert result.compressed_summary == "compressed text"
        mock_svc_instance.compress_and_save.assert_called_once()

    @patch(
        "application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
    )
    @patch(
        "application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
    )
    @patch("application.api.answer.services.compression.orchestrator.LLMCreator")
    @patch("application.api.answer.services.compression.orchestrator.settings")
    def test_uses_compression_model_override(
        self,
        mock_settings,
        MockLLMCreator,
        mock_get_api_key,
        mock_get_provider,
        mock_conversation_service,
        mock_threshold_checker,
        decoded_token,
    ):
        """COMPRESSION_MODEL_OVERRIDE replaces the requested model for compression."""
        mock_settings.COMPRESSION_MODEL_OVERRIDE = "gpt-3.5-turbo"
        mock_get_provider.return_value = "openai"
        mock_get_api_key.return_value = "sk-test"
        MockLLMCreator.create_llm.return_value = MagicMock()
        conversation = {"queries": [{"prompt": "q", "response": "r"}], "agent_id": "a"}
        with patch(
            "application.api.answer.services.compression.orchestrator.CompressionService"
        ) as MockCS:
            mock_svc = MagicMock()
            mock_svc.compress_and_save.return_value = MagicMock(
                compression_ratio=3.0,
                original_token_count=300,
                compressed_token_count=100,
            )
            mock_svc.get_compressed_context.return_value = ("s", [])
            MockCS.return_value = mock_svc
            mock_conversation_service.get_conversation.return_value = conversation
            orch = CompressionOrchestrator(
                conversation_service=mock_conversation_service,
                threshold_checker=mock_threshold_checker,
            )
            orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
            # Verify the override model was used
            mock_get_provider.assert_called_with("gpt-3.5-turbo")

    @patch(
        "application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
    )
    @patch(
        "application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
    )
    @patch("application.api.answer.services.compression.orchestrator.LLMCreator")
    @patch("application.api.answer.services.compression.orchestrator.CompressionService")
    @patch("application.api.answer.services.compression.orchestrator.settings")
    def test_no_queries_returns_no_compression(
        self,
        mock_settings,
        MockCompressionService,
        MockLLMCreator,
        mock_get_api_key,
        mock_get_provider,
        mock_conversation_service,
        mock_threshold_checker,
        decoded_token,
    ):
        """An empty query list succeeds without performing compression."""
        mock_settings.COMPRESSION_MODEL_OVERRIDE = None
        mock_get_provider.return_value = "openai"
        mock_get_api_key.return_value = "sk-test"
        MockLLMCreator.create_llm.return_value = MagicMock()
        conversation = {"queries": [], "agent_id": "a"}
        orch = CompressionOrchestrator(
            conversation_service=mock_conversation_service,
            threshold_checker=mock_threshold_checker,
        )
        result = orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
        assert result.success is True
        assert result.compression_performed is False

    def test_exception_returns_failure(
        self,
        mock_conversation_service,
        mock_threshold_checker,
        decoded_token,
    ):
        """Provider-resolution errors are caught and reported as a failure result."""
        conversation = {
            "queries": [{"prompt": "q", "response": "r"}],
            "agent_id": "a",
        }
        with patch(
            "application.api.answer.services.compression.orchestrator.settings"
        ) as mock_settings, patch(
            "application.api.answer.services.compression.orchestrator.get_provider_from_model_id",
            side_effect=RuntimeError("provider error"),
        ):
            mock_settings.COMPRESSION_MODEL_OVERRIDE = None
            orch = CompressionOrchestrator(
                conversation_service=mock_conversation_service,
                threshold_checker=mock_threshold_checker,
            )
            result = orch._perform_compression(
                "c1", conversation, "gpt-4", decoded_token
            )
            assert result.success is False
            assert "provider error" in result.error
@pytest.mark.unit
class TestCompressMidExecution:
    """Unit tests for CompressionOrchestrator.compress_mid_execution."""

    def test_with_provided_conversation(
        self,
        orchestrator,
        sample_conversation,
        decoded_token,
    ):
        """A caller-supplied conversation is used directly, no DB lookup."""
        with patch.object(orchestrator, "_perform_compression") as perform_stub:
            perform_stub.return_value = CompressionResult.success_no_compression([])
            orchestrator.compress_mid_execution(
                conversation_id="conv1",
                user_id="user1",
                model_id="gpt-4",
                decoded_token=decoded_token,
                current_conversation=sample_conversation,
            )
        perform_stub.assert_called_once_with(
            "conv1", sample_conversation, "gpt-4", decoded_token
        )

    def test_loads_conversation_when_not_provided(
        self,
        orchestrator,
        mock_conversation_service,
        sample_conversation,
        decoded_token,
    ):
        """Without an explicit conversation, it is fetched from the service."""
        mock_conversation_service.get_conversation.return_value = sample_conversation
        with patch.object(orchestrator, "_perform_compression") as perform_stub:
            perform_stub.return_value = CompressionResult.success_no_compression([])
            orchestrator.compress_mid_execution(
                conversation_id="conv1",
                user_id="user1",
                model_id="gpt-4",
                decoded_token=decoded_token,
            )
        mock_conversation_service.get_conversation.assert_called_once_with(
            "conv1", "user1"
        )
        perform_stub.assert_called_once()

    def test_conversation_not_found_returns_failure(
        self,
        orchestrator,
        mock_conversation_service,
        decoded_token,
    ):
        """A missing conversation yields a failure mentioning 'not found'."""
        mock_conversation_service.get_conversation.return_value = None
        outcome = orchestrator.compress_mid_execution(
            conversation_id="conv1",
            user_id="user1",
            model_id="gpt-4",
            decoded_token=decoded_token,
        )
        assert outcome.success is False
        assert "not found" in outcome.error

    def test_exception_returns_failure(
        self,
        orchestrator,
        mock_conversation_service,
        decoded_token,
    ):
        """Service errors are converted into a failure result, not raised."""
        mock_conversation_service.get_conversation.side_effect = RuntimeError("fail")
        outcome = orchestrator.compress_mid_execution(
            conversation_id="conv1",
            user_id="user1",
            model_id="gpt-4",
            decoded_token=decoded_token,
        )
        assert outcome.success is False
        assert "fail" in outcome.error
@pytest.mark.unit
class TestOrchestratorInit:
    """Constructor wiring of CompressionOrchestrator."""

    def test_default_threshold_checker(self, mock_conversation_service):
        """Omitting the checker argument installs a default instance."""
        orch = CompressionOrchestrator(conversation_service=mock_conversation_service)
        assert orch.conversation_service is mock_conversation_service
        assert orch.threshold_checker is not None

    def test_custom_threshold_checker(
        self, mock_conversation_service, mock_threshold_checker
    ):
        """An explicitly supplied checker is stored verbatim."""
        orch = CompressionOrchestrator(
            threshold_checker=mock_threshold_checker,
            conversation_service=mock_conversation_service,
        )
        assert orch.threshold_checker is mock_threshold_checker

View File

@@ -0,0 +1,423 @@
"""Tests for application/api/answer/services/compression/service.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import CompressionMetadata
@pytest.fixture
def mock_llm():
    """LLM stub whose gen() yields a <summary>-tagged string."""
    stub = MagicMock()
    stub.gen.return_value = "<summary>Compressed summary content</summary>"
    return stub
@pytest.fixture
def mock_conversation_service():
    """Conversation-service stub with an inspectable metadata updater."""
    service = MagicMock()
    service.update_compression_metadata = MagicMock()
    return service
@pytest.fixture
def sample_conversation():
    """Three-query conversation; the final query carries one tool call."""
    tool_call = {
        "tool_name": "search",
        "action_name": "web_search",
        "arguments": {"q": "python tools"},
        "result": "Found 10 results",
        "status": "success",
    }
    return {
        "queries": [
            {"prompt": "What is Python?", "response": "A programming language."},
            {"prompt": "Tell me more.", "response": "It's versatile and popular."},
            {
                "prompt": "What about tools?",
                "response": "Python has many tools.",
                "tool_calls": [tool_call],
            },
        ],
        "compression_metadata": {},
    }
@pytest.mark.unit
class TestCompressionServiceInit:
    """Constructor behaviour of CompressionService."""

    @patch("application.api.answer.services.compression.service.settings")
    def test_default_prompt_builder(self, mock_settings, mock_llm):
        """With no builder supplied, one is constructed internally."""
        mock_settings.COMPRESSION_PROMPT_VERSION = "v1.0"
        with patch(
            "application.api.answer.services.compression.service.CompressionPromptBuilder"
        ):
            svc = CompressionService(llm=mock_llm, model_id="gpt-4")
        assert svc.llm is mock_llm
        assert svc.model_id == "gpt-4"

    def test_custom_prompt_builder(self, mock_llm):
        """A caller-supplied builder is stored as-is."""
        builder = MagicMock()
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=builder
        )
        assert svc.prompt_builder is builder
@pytest.mark.unit
class TestCompressConversation:
    """Unit tests for CompressionService.compress_conversation."""

    def test_successful_compression(self, mock_llm, sample_conversation):
        """Happy path: metadata reflects token counts and the 10x ratio."""
        mock_builder = MagicMock()
        mock_builder.build_prompt.return_value = [
            {"role": "system", "content": "Compress"},
            {"role": "user", "content": "Conversation..."},
        ]
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with patch(
            "application.api.answer.services.compression.service.TokenCounter"
        ) as MockTC:
            # 1000 original tokens vs 100 compressed -> ratio 10.0.
            MockTC.count_query_tokens.return_value = 1000
            MockTC.count_message_tokens.return_value = 100
            result = svc.compress_conversation(sample_conversation, 2)
            assert isinstance(result, CompressionMetadata)
            assert result.query_index == 2
            assert result.compressed_summary == "Compressed summary content"
            assert result.original_token_count == 1000
            assert result.compressed_token_count == 100
            assert result.compression_ratio == 10.0
            assert result.model_used == "gpt-4"
            assert result.compression_prompt_version == "v1.0"

    def test_invalid_index_negative(self, mock_llm, sample_conversation):
        """A negative index is rejected with ValueError."""
        mock_builder = MagicMock()
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
            svc.compress_conversation(sample_conversation, -1)

    def test_invalid_index_too_large(self, mock_llm, sample_conversation):
        """An index past the end of the query list is rejected."""
        mock_builder = MagicMock()
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
            svc.compress_conversation(sample_conversation, 10)

    def test_with_existing_compressions(self, mock_llm):
        """Prior compression points are forwarded to the prompt builder."""
        conversation = {
            "queries": [
                {"prompt": "q1", "response": "r1"},
                {"prompt": "q2", "response": "r2"},
            ],
            "compression_metadata": {
                "compression_points": [
                    {
                        "query_index": 0,
                        "compressed_summary": "Previous summary",
                    }
                ]
            },
        }
        mock_builder = MagicMock()
        mock_builder.build_prompt.return_value = [
            {"role": "system", "content": "Compress"},
            {"role": "user", "content": "..."},
        ]
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with patch(
            "application.api.answer.services.compression.service.TokenCounter"
        ) as MockTC:
            MockTC.count_query_tokens.return_value = 500
            MockTC.count_message_tokens.return_value = 50
            result = svc.compress_conversation(conversation, 1)
            assert isinstance(result, CompressionMetadata)
            # Verify existing compressions were passed to prompt builder
            # (second positional argument of build_prompt).
            call_args = mock_builder.build_prompt.call_args
            assert call_args[0][1] == [
                {"query_index": 0, "compressed_summary": "Previous summary"}
            ]

    def test_zero_compressed_tokens_ratio(self, mock_llm, sample_conversation):
        """A zero compressed-token count must not divide by zero: ratio is 0."""
        mock_builder = MagicMock()
        mock_builder.build_prompt.return_value = [
            {"role": "system", "content": "C"},
            {"role": "user", "content": "..."},
        ]
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with patch(
            "application.api.answer.services.compression.service.TokenCounter"
        ) as MockTC:
            MockTC.count_query_tokens.return_value = 1000
            MockTC.count_message_tokens.return_value = 0
            result = svc.compress_conversation(sample_conversation, 2)
            assert result.compression_ratio == 0

    def test_llm_error_propagates(self, sample_conversation):
        """LLM failures are not swallowed by the service."""
        llm = MagicMock()
        llm.gen.side_effect = RuntimeError("LLM error")
        mock_builder = MagicMock()
        mock_builder.build_prompt.return_value = [
            {"role": "system", "content": "C"},
            {"role": "user", "content": "..."},
        ]
        mock_builder.version = "v1.0"
        svc = CompressionService(
            llm=llm, model_id="gpt-4", prompt_builder=mock_builder
        )
        with patch(
            "application.api.answer.services.compression.service.TokenCounter"
        ) as MockTC:
            MockTC.count_query_tokens.return_value = 100
            with pytest.raises(RuntimeError, match="LLM error"):
                svc.compress_conversation(sample_conversation, 2)
@pytest.mark.unit
class TestCompressAndSave:
    """Persistence behaviour of CompressionService.compress_and_save."""

    def test_saves_metadata_to_db(
        self, mock_llm, mock_conversation_service, sample_conversation
    ):
        """The computed metadata dict is written through the service."""
        builder = MagicMock()
        builder.version = "v1.0"
        builder.build_prompt.return_value = [
            {"role": "system", "content": "C"},
            {"role": "user", "content": "..."},
        ]
        svc = CompressionService(
            llm=mock_llm,
            model_id="gpt-4",
            conversation_service=mock_conversation_service,
            prompt_builder=builder,
        )
        with patch(
            "application.api.answer.services.compression.service.TokenCounter"
        ) as token_counter:
            token_counter.count_query_tokens.return_value = 500
            token_counter.count_message_tokens.return_value = 50
            meta = svc.compress_and_save("conv_123", sample_conversation, 2)
        assert isinstance(meta, CompressionMetadata)
        mock_conversation_service.update_compression_metadata.assert_called_once_with(
            "conv_123", meta.to_dict()
        )

    def test_raises_without_conversation_service(self, mock_llm, sample_conversation):
        """Saving without a conversation service is a usage error."""
        builder = MagicMock()
        builder.version = "v1.0"
        svc = CompressionService(
            llm=mock_llm, model_id="gpt-4", prompt_builder=builder
        )
        with pytest.raises(ValueError, match="conversation_service required"):
            svc.compress_and_save("conv_123", sample_conversation, 2)
@pytest.mark.unit
class TestGetCompressedContext:
    """Behaviour of CompressionService.get_compressed_context."""

    def _service(self, llm):
        # Shared construction shorthand for these tests.
        return CompressionService(llm=llm, model_id="gpt-4")

    def test_no_compression_returns_full_history(self, mock_llm):
        """No metadata at all -> (None, full query list)."""
        svc = self._service(mock_llm)
        convo = {
            "queries": [{"prompt": "q1", "response": "r1"}],
            "compression_metadata": {},
        }
        summary, history = svc.get_compressed_context(convo)
        assert summary is None
        assert history == [{"prompt": "q1", "response": "r1"}]

    def test_no_compression_points_returns_full_history(self, mock_llm):
        """Compressed flag without any points still yields full history."""
        svc = self._service(mock_llm)
        convo = {
            "queries": [{"prompt": "q1", "response": "r1"}],
            "compression_metadata": {"is_compressed": True, "compression_points": []},
        }
        summary, history = svc.get_compressed_context(convo)
        assert summary is None
        assert len(history) == 1

    def test_with_compression_returns_summary_and_recent(self, mock_llm):
        """A compression point splits the history at its query_index."""
        svc = self._service(mock_llm)
        point = {
            "query_index": 1,
            "compressed_summary": "Summary of q0 and q1",
            "compressed_token_count": 50,
            "original_token_count": 500,
        }
        convo = {
            "queries": [
                {"prompt": "q0", "response": "r0"},
                {"prompt": "q1", "response": "r1"},
                {"prompt": "q2", "response": "r2"},
            ],
            "compression_metadata": {
                "is_compressed": True,
                "compression_points": [point],
            },
        }
        summary, history = svc.get_compressed_context(convo)
        assert summary == "Summary of q0 and q1"
        assert len(history) == 1
        assert history[0]["prompt"] == "q2"

    def test_none_queries_returns_empty(self, mock_llm):
        """queries=None degrades gracefully to an empty history."""
        svc = self._service(mock_llm)
        convo = {"queries": None, "compression_metadata": {}}
        summary, history = svc.get_compressed_context(convo)
        assert summary is None
        assert history == []

    def test_exception_falls_back_to_full_history(self, mock_llm):
        """Malformed metadata falls back to the raw query list."""
        svc = self._service(mock_llm)
        convo = {
            "queries": [{"prompt": "q", "response": "r"}],
            "compression_metadata": {
                "is_compressed": True,
                "compression_points": "invalid",  # This will cause an error
            },
        }
        summary, history = svc.get_compressed_context(convo)
        assert summary is None
        assert history == [{"prompt": "q", "response": "r"}]

    def test_exception_with_none_queries_returns_empty(self, mock_llm):
        """The fallback path with queries=None still returns an empty list."""
        svc = self._service(mock_llm)
        # Force exception by making compression_points non-iterable.
        convo = {
            "queries": None,
            "compression_metadata": {
                "is_compressed": True,
                "compression_points": "bad",
            },
        }
        summary, history = svc.get_compressed_context(convo)
        assert summary is None
        assert history == []
@pytest.mark.unit
class TestExtractSummary:
    """Tag parsing in CompressionService._extract_summary."""

    def _extract(self, llm, response):
        # Shared shorthand: build a service and extract from the response.
        svc = CompressionService(llm=llm, model_id="gpt-4")
        return svc._extract_summary(response)

    def test_extracts_from_summary_tags(self, mock_llm):
        """Content inside <summary> wins over surrounding text."""
        text = "<analysis>Some analysis</analysis><summary>The actual summary</summary>"
        assert self._extract(mock_llm, text) == "The actual summary"

    def test_removes_analysis_tags_when_no_summary(self, mock_llm):
        """Without <summary>, the <analysis> block is stripped."""
        text = "<analysis>analysis text</analysis>Raw summary text here"
        assert self._extract(mock_llm, text) == "Raw summary text here"

    def test_returns_full_response_when_no_tags(self, mock_llm):
        """Tag-free responses come back unchanged."""
        text = "Just a plain text response"
        assert self._extract(mock_llm, text) == "Just a plain text response"

    def test_multiline_summary(self, mock_llm):
        """Multi-line summaries survive extraction."""
        extracted = self._extract(mock_llm, "<summary>Line 1\nLine 2\nLine 3</summary>")
        assert "Line 1" in extracted
        assert "Line 3" in extracted

    def test_strips_whitespace(self, mock_llm):
        """Whitespace just inside the tags is trimmed."""
        assert self._extract(mock_llm, "<summary> Trimmed </summary>") == "Trimmed"
@pytest.mark.unit
class TestLogToolCallStats:
    """_log_tool_call_stats is logging-only and must never raise."""

    def _service(self, llm):
        return CompressionService(llm=llm, model_id="gpt-4")

    def test_no_tool_calls(self, mock_llm):
        """Queries without tool calls are handled quietly."""
        self._service(mock_llm)._log_tool_call_stats(
            [{"prompt": "q", "response": "r"}]
        )

    def test_with_tool_calls(self, mock_llm):
        """Multiple tool calls on one query are tallied without error."""
        calls = [
            {
                "tool_name": "search",
                "action_name": "web",
                "result": "result text",
            },
            {
                "tool_name": "search",
                "action_name": "web",
                "result": "more text",
            },
        ]
        self._service(mock_llm)._log_tool_call_stats(
            [{"prompt": "q", "response": "r", "tool_calls": calls}]
        )

    def test_empty_queries(self, mock_llm):
        """An empty query list is a no-op."""
        self._service(mock_llm)._log_tool_call_stats([])

    def test_tool_call_with_none_result(self, mock_llm):
        """A None tool-call result must not crash the stats pass."""
        call = {
            "tool_name": "t",
            "action_name": "a",
            "result": None,
        }
        self._service(mock_llm)._log_tool_call_stats(
            [{"prompt": "q", "response": "r", "tool_calls": [call]}]
        )

View File

@@ -0,0 +1,131 @@
"""Tests for application/api/answer/services/compression/types.py"""
from datetime import datetime, timezone
import pytest
from application.api.answer.services.compression.types import (
CompressionMetadata,
CompressionResult,
)
@pytest.mark.unit
class TestCompressionMetadata:
    """Field round-tripping through CompressionMetadata.to_dict."""

    def _make_metadata(self, **overrides):
        """Build metadata with sensible defaults, applying any overrides."""
        kwargs = {
            "timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc),
            "query_index": 5,
            "compressed_summary": "Summary of conversation",
            "original_token_count": 5000,
            "compressed_token_count": 500,
            "compression_ratio": 10.0,
            "model_used": "gpt-4",
            "compression_prompt_version": "v1.0",
        }
        kwargs.update(overrides)
        return CompressionMetadata(**kwargs)

    def test_to_dict_contains_all_fields(self):
        """Every constructor field appears unchanged in the dict."""
        dumped = self._make_metadata().to_dict()
        assert dumped["timestamp"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
        assert dumped["query_index"] == 5
        assert dumped["compressed_summary"] == "Summary of conversation"
        assert dumped["original_token_count"] == 5000
        assert dumped["compressed_token_count"] == 500
        assert dumped["compression_ratio"] == 10.0
        assert dumped["model_used"] == "gpt-4"
        assert dumped["compression_prompt_version"] == "v1.0"

    def test_to_dict_returns_dict_type(self):
        assert isinstance(self._make_metadata().to_dict(), dict)

    def test_to_dict_field_count(self):
        """to_dict exposes exactly the eight known fields."""
        assert len(self._make_metadata().to_dict()) == 8

    def test_attributes_accessible(self):
        meta = self._make_metadata(query_index=10, compression_ratio=5.5)
        assert meta.query_index == 10
        assert meta.compression_ratio == 5.5

    def test_zero_compressed_tokens(self):
        """Zeroed token counts round-trip without distortion."""
        dumped = self._make_metadata(
            compressed_token_count=0, compression_ratio=0
        ).to_dict()
        assert dumped["compressed_token_count"] == 0
        assert dumped["compression_ratio"] == 0
@pytest.mark.unit
class TestCompressionResult:
    """Factory constructors and helpers on CompressionResult."""

    def test_success_with_compression(self):
        """success_with_compression populates every field."""
        meta = CompressionMetadata(
            timestamp=datetime.now(timezone.utc),
            query_index=3,
            compressed_summary="summary",
            original_token_count=1000,
            compressed_token_count=100,
            compression_ratio=10.0,
            model_used="gpt-4",
            compression_prompt_version="v1.0",
        )
        recent = [{"prompt": "q1", "response": "r1"}]
        res = CompressionResult.success_with_compression("summary", recent, meta)
        assert res.success is True
        assert res.compression_performed is True
        assert res.compressed_summary == "summary"
        assert res.recent_queries == recent
        assert res.metadata is meta
        assert res.error is None

    def test_success_no_compression(self):
        """success_no_compression keeps the queries, nothing else set."""
        recent = [{"prompt": "q1", "response": "r1"}]
        res = CompressionResult.success_no_compression(recent)
        assert res.success is True
        assert res.compression_performed is False
        assert res.compressed_summary is None
        assert res.recent_queries == recent
        assert res.metadata is None
        assert res.error is None

    def test_failure(self):
        """failure records the message and clears everything else."""
        res = CompressionResult.failure("something went wrong")
        assert res.success is False
        assert res.error == "something went wrong"
        assert res.compression_performed is False
        assert res.compressed_summary is None
        assert res.recent_queries == []
        assert res.metadata is None

    def test_as_history_extracts_prompt_response(self):
        """as_history keeps only prompt/response pairs, dropping extras."""
        res = CompressionResult.success_no_compression(
            [
                {"prompt": "Hello", "response": "Hi", "extra": "ignored"},
                {"prompt": "How?", "response": "Fine"},
            ]
        )
        history = res.as_history()
        assert len(history) == 2
        assert history[0] == {"prompt": "Hello", "response": "Hi"}
        assert history[1] == {"prompt": "How?", "response": "Fine"}

    def test_as_history_empty_queries(self):
        assert CompressionResult.success_no_compression([]).as_history() == []

    def test_default_recent_queries_is_empty_list(self):
        """recent_queries defaults to [] rather than None."""
        res = CompressionResult(success=True)
        assert res.recent_queries == []
        assert res.as_history() == []

    def test_success_no_compression_with_empty_list(self):
        res = CompressionResult.success_no_compression([])
        assert res.success is True
        assert res.recent_queries == []

View File

@@ -0,0 +1,331 @@
"""Tests for application/api/answer/services/stream_processor.py — get_prompt and helpers."""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.stream_processor import get_prompt
class TestGetPrompt:
    """get_prompt: built-in presets plus MongoDB-backed custom prompt IDs."""

    @pytest.mark.unit
    def test_default_preset(self):
        # Presets resolve locally to non-empty strings.
        prompt = get_prompt("default")
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    @pytest.mark.unit
    def test_creative_preset(self):
        prompt = get_prompt("creative")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_strict_preset(self):
        prompt = get_prompt("strict")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_reduce_preset(self):
        prompt = get_prompt("reduce")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_default_preset(self):
        prompt = get_prompt("agentic_default")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_creative_preset(self):
        prompt = get_prompt("agentic_creative")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_strict_preset(self):
        prompt = get_prompt("agentic_strict")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_mongo_prompt_by_id(self):
        # A non-preset id is treated as an ObjectId and looked up in Mongo.
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = {"_id": "abc", "content": "Custom prompt"}
        prompt = get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)
        assert prompt == "Custom prompt"

    @pytest.mark.unit
    def test_mongo_prompt_not_found_raises(self):
        # An id that resolves to no document is an invalid prompt ID.
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = None
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)

    @pytest.mark.unit
    def test_invalid_id_raises(self):
        # Lookup failures (e.g. malformed ObjectId) also surface as ValueError.
        mock_collection = MagicMock()
        mock_collection.find_one.side_effect = Exception("bad id")
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("not-an-objectid", prompts_collection=mock_collection)

    @pytest.mark.unit
    def test_mongo_fallback_when_no_collection(self):
        """When no collection passed, it reads from MongoDB."""
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = {"content": "From DB"}
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            prompt = get_prompt("507f1f77bcf86cd799439011")
            assert prompt == "From DB"
class TestStreamProcessorInit:
    """Constructor attribute wiring for StreamProcessor."""

    @pytest.mark.unit
    def test_init_sets_attributes(self):
        # The import happens inside the patch context so module-level
        # MongoDB/settings lookups see the mocks.
        mock_db = MagicMock()
        mock_client = {"docsgpt": mock_db}
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = mock_client
            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"conversation_id": "conv1", "agent_id": "a1"},
                decoded_token={"sub": "user1"},
            )
            assert sp.conversation_id == "conv1"
            assert sp.initial_user_id == "user1"
            assert sp.agent_id == "a1"
            assert sp.history == []
            assert sp.attachments == []

    @pytest.mark.unit
    def test_init_no_token(self):
        # Without a decoded token there is no initial user id.
        mock_db = MagicMock()
        mock_client = {"docsgpt": mock_db}
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = mock_client
            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token=None)
            assert sp.initial_user_id is None
class TestGetAttachmentsContent:
    """StreamProcessor._get_attachments_content lookup behaviour."""

    @pytest.mark.unit
    def test_empty_ids_returns_empty(self):
        # No attachment ids -> no DB lookups, empty result.
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}
            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content([], "u")
            assert result == []

    @pytest.mark.unit
    def test_returns_matching_attachments(self):
        # A resolvable id yields its attachment document.
        mock_db = MagicMock()
        mock_attachments = MagicMock()
        mock_attachments.find_one.return_value = {"_id": "att1", "content": "data"}
        mock_db.__getitem__ = MagicMock(return_value=mock_attachments)
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}
            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content(["507f1f77bcf86cd799439011"], "u")
            assert len(result) == 1

    @pytest.mark.unit
    def test_invalid_attachment_id_continues(self):
        # A bad id is skipped rather than aborting the whole batch.
        mock_db = MagicMock()
        mock_attachments = MagicMock()
        mock_attachments.find_one.side_effect = Exception("bad id")
        mock_db.__getitem__ = MagicMock(return_value=mock_attachments)
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
                patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}
            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content(["bad"], "u")
            assert result == []
class TestResolveAgentId:
@pytest.mark.unit
def test_from_request_data(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"agent_id": "agent_123"},
decoded_token={"sub": "u"},
)
assert sp._resolve_agent_id() == "agent_123"
@pytest.mark.unit
def test_no_agent_no_conversation(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
assert sp._resolve_agent_id() is None
@pytest.mark.unit
def test_from_conversation(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"conversation_id": "conv1"},
decoded_token={"sub": "u"},
)
sp.conversation_service = MagicMock()
sp.conversation_service.get_conversation.return_value = {"agent_id": "from_conv"}
assert sp._resolve_agent_id() == "from_conv"
@pytest.mark.unit
def test_conversation_not_found(self):
    """A missing conversation document resolves the agent id to None."""
    db_stub = MagicMock()
    mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
    settings_patch = patch("application.api.answer.services.stream_processor.settings")
    with mongo_patch as mongo_cls, settings_patch as cfg:
        cfg.MONGO_DB_NAME = "docsgpt"
        mongo_cls.get_client.return_value = {"docsgpt": db_stub}
        from application.api.answer.services.stream_processor import StreamProcessor

        processor = StreamProcessor(
            request_data={"conversation_id": "conv1"},
            decoded_token={"sub": "u"},
        )
        # Lookup succeeds but the conversation does not exist.
        processor.conversation_service = MagicMock()
        processor.conversation_service.get_conversation.return_value = None
        assert processor._resolve_agent_id() is None
@pytest.mark.unit
def test_conversation_lookup_exception(self):
    """A failing conversation lookup is swallowed and resolves to None."""
    db_stub = MagicMock()
    mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
    settings_patch = patch("application.api.answer.services.stream_processor.settings")
    with mongo_patch as mongo_cls, settings_patch as cfg:
        cfg.MONGO_DB_NAME = "docsgpt"
        mongo_cls.get_client.return_value = {"docsgpt": db_stub}
        from application.api.answer.services.stream_processor import StreamProcessor

        processor = StreamProcessor(
            request_data={"conversation_id": "conv1"},
            decoded_token={"sub": "u"},
        )
        # The service raises; _resolve_agent_id must not propagate it.
        processor.conversation_service = MagicMock()
        processor.conversation_service.get_conversation.side_effect = Exception("db error")
        assert processor._resolve_agent_id() is None
class TestGetPromptContent:
    """Behaviour of StreamProcessor._get_prompt_content."""

    @pytest.mark.unit
    def test_caches_result(self):
        """Repeated lookups return the same, non-None cached content."""
        db_stub = MagicMock()
        mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
        settings_patch = patch("application.api.answer.services.stream_processor.settings")
        with mongo_patch as mongo_cls, settings_patch as cfg:
            cfg.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db_stub}
            from application.api.answer.services.stream_processor import StreamProcessor

            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "default"}
            first = processor._get_prompt_content()
            second = processor._get_prompt_content()
            assert first == second
            assert first is not None

    @pytest.mark.unit
    def test_no_prompt_id(self):
        """An agent config without a prompt_id yields no prompt content."""
        db_stub = MagicMock()
        mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
        settings_patch = patch("application.api.answer.services.stream_processor.settings")
        with mongo_patch as mongo_cls, settings_patch as cfg:
            cfg.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db_stub}
            from application.api.answer.services.stream_processor import StreamProcessor

            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {}
            assert processor._get_prompt_content() is None

    @pytest.mark.unit
    def test_invalid_prompt_id_returns_none(self):
        """A failing prompt-collection lookup is reported as None."""
        prompts_stub = MagicMock()
        prompts_stub.find_one.side_effect = Exception("bad")
        db_stub = MagicMock()
        db_stub.__getitem__ = MagicMock(return_value=prompts_stub)
        mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
        settings_patch = patch("application.api.answer.services.stream_processor.settings")
        with mongo_patch as mongo_cls, settings_patch as cfg:
            cfg.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db_stub}
            from application.api.answer.services.stream_processor import StreamProcessor

            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "bad_id"}
            assert processor._get_prompt_content() is None
class TestGetRequiredToolActions:
    """Behaviour of StreamProcessor._get_required_tool_actions."""

    @pytest.mark.unit
    def test_no_prompt_returns_none(self):
        """Without a configured prompt there are no required tool actions."""
        db_stub = MagicMock()
        mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
        settings_patch = patch("application.api.answer.services.stream_processor.settings")
        with mongo_patch as mongo_cls, settings_patch as cfg:
            cfg.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db_stub}
            from application.api.answer.services.stream_processor import StreamProcessor

            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {}
            assert processor._get_required_tool_actions() is None

    @pytest.mark.unit
    def test_no_template_syntax_returns_empty(self):
        """Plain prompt text with no template syntax maps to an empty dict."""
        db_stub = MagicMock()
        mongo_patch = patch("application.api.answer.services.stream_processor.MongoDB")
        settings_patch = patch("application.api.answer.services.stream_processor.settings")
        with mongo_patch as mongo_cls, settings_patch as cfg:
            cfg.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db_stub}
            from application.api.answer.services.stream_processor import StreamProcessor

            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "default"}
            processor._prompt_content = "No template syntax here"
            assert processor._get_required_tool_actions() == {}

View File

@@ -0,0 +1,333 @@
"""Tests for application/api/connector/routes.py"""
import json
from unittest.mock import MagicMock, patch
import mongomock
import pytest
@pytest.fixture
def app():
    """Flask app with authentication stubbed out to a fixed test user."""
    auth_patch = patch("application.app.handle_auth", return_value={"sub": "test_user"})
    with auth_patch:
        from application.app import app as flask_app

        flask_app.config["TESTING"] = True
        # Keep the auth patch active for the whole test.
        yield flask_app
@pytest.fixture
def client(app):
    """Werkzeug test client bound to the patched Flask app."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def mock_sessions(monkeypatch):
    """Swap the connector route collections for in-memory mongomock ones."""
    memory_client = mongomock.MongoClient()
    database = memory_client["docsgpt"]
    sessions_coll = database["connector_sessions"]
    sources_coll = database["sources"]
    monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions_coll)
    monkeypatch.setattr("application.api.connector.routes.sources_collection", sources_coll)
    return {"sessions": sessions_coll, "sources": sources_coll}
class TestConnectorAuth:
    """Tests for GET /api/connectors/auth (OAuth authorization URL)."""

    @pytest.mark.unit
    def test_missing_provider(self, client):
        # No ?provider query parameter -> bad request.
        resp = client.get("/api/connectors/auth")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unsupported_provider(self, client):
        # Provider name the backend does not support -> bad request.
        resp = client.get("/api/connectors/auth?provider=dropbox")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unauthorized(self, client, app):
        # Auth handler returns no token for this request.
        with patch("application.app.handle_auth", return_value=None):
            resp = client.get("/api/connectors/auth?provider=google_drive")
            data = json.loads(resp.data)
            # decoded_token is None -> 401
            assert resp.status_code == 401 or data.get("error") == "Unauthorized"

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        # Happy path: supported provider and a working auth backend.
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.is_supported.return_value = True
            mock_auth = MagicMock()
            mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
            MockCC.create_auth.return_value = mock_auth
            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert "authorization_url" in data

    @pytest.mark.unit
    def test_exception_returns_500(self, client, mock_sessions):
        # Auth backend blows up -> server error, not an unhandled exception.
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.is_supported.return_value = True
            MockCC.create_auth.side_effect = Exception("oauth fail")
            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 500
class TestConnectorFiles:
    """Tests for POST /api/connectors/files (remote file listing)."""

    @pytest.mark.unit
    def test_missing_params(self, client):
        # session_token is required alongside provider.
        resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        # Token not present in the sessions collection -> unauthorized.
        resp = client.post("/api/connectors/files", json={
            "provider": "google_drive",
            "session_token": "bad_token",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        # Valid session plus a loader that yields one fully-described file.
        mock_sessions["sessions"].insert_one({
            "session_token": "valid_tok",
            "user": "test_user",
            "provider": "google_drive",
        })
        mock_doc = MagicMock()
        mock_doc.doc_id = "f1"
        mock_doc.extra_info = {
            "file_name": "test.pdf",
            "mime_type": "application/pdf",
            "size": 1024,
            "modified_time": "2025-01-01T12:00:00.000Z",
            "is_folder": False,
        }
        mock_loader = MagicMock()
        mock_loader.load_data.return_value = [mock_doc]
        mock_loader.next_page_token = None
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.create_connector.return_value = mock_loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive",
                "session_token": "valid_tok",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert len(data["files"]) == 1

    @pytest.mark.unit
    def test_no_modified_time(self, client, mock_sessions):
        # File metadata lacking optional fields must still serialize cleanly.
        mock_sessions["sessions"].insert_one({
            "session_token": "tok2",
            "user": "test_user",
            "provider": "google_drive",
        })
        mock_doc = MagicMock()
        mock_doc.doc_id = "f1"
        mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"}
        mock_loader = MagicMock()
        mock_loader.load_data.return_value = [mock_doc]
        mock_loader.next_page_token = None
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.create_connector.return_value = mock_loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive", "session_token": "tok2",
            })
            assert resp.status_code == 200
class TestConnectorValidateSession:
    """Tests for POST /api/connectors/validate-session (token validity/refresh)."""

    @pytest.mark.unit
    def test_missing_params(self, client):
        # session_token is required alongside provider.
        resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        # Unknown token -> unauthorized.
        resp = client.post("/api/connectors/validate-session", json={
            "provider": "google_drive", "session_token": "bad",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_valid_non_expired(self, client, mock_sessions):
        # Stored session with a live access token validates without refresh.
        mock_sessions["sessions"].insert_one({
            "session_token": "valid",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
            "user_email": "user@example.com",
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = False
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "valid",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert data["expired"] is False

    @pytest.mark.unit
    def test_expired_with_refresh(self, client, mock_sessions):
        # Expired access token but a refresh token available -> refreshed, 200.
        mock_sessions["sessions"].insert_one({
            "session_token": "expired_tok",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = True
            mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "expired_tok",
            })
            assert resp.status_code == 200

    @pytest.mark.unit
    def test_expired_no_refresh(self, client, mock_sessions):
        # Expired token and no refresh token -> session cannot be recovered.
        mock_sessions["sessions"].insert_one({
            "session_token": "exp_no_ref",
            "user": "test_user",
            "token_info": {"access_token": "at", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = True
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "exp_no_ref",
            })
            assert resp.status_code == 401
class TestConnectorDisconnect:
    """Tests for POST /api/connectors/disconnect."""

    @pytest.mark.unit
    def test_missing_provider(self, client):
        """A disconnect request without a provider is rejected."""
        resp = client.post("/api/connectors/disconnect", json={})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_success_with_session(self, client, mock_sessions):
        """Disconnecting an existing session succeeds and reports success."""
        mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"})
        payload = {"provider": "google_drive", "session_token": "del_me"}
        resp = client.post("/api/connectors/disconnect", json=payload)
        assert resp.status_code == 200
        body = json.loads(resp.data)
        assert body["success"] is True

    @pytest.mark.unit
    def test_success_without_session(self, client, mock_sessions):
        """Disconnecting without a session token is still a success."""
        resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
        assert resp.status_code == 200
class TestConnectorSync:
    """Tests for POST /api/connectors/sync (re-ingest a connector source)."""

    @pytest.mark.unit
    def test_missing_params(self, client, mock_sessions):
        # session_token is required alongside source_id.
        resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_source_not_found(self, client, mock_sessions):
        # A well-formed but unknown ObjectId -> not found.
        from bson.objectid import ObjectId
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(ObjectId()), "session_token": "tok",
        })
        assert resp.status_code == 404

    @pytest.mark.unit
    def test_unauthorized_source(self, client, mock_sessions):
        # Source belongs to a different user -> forbidden.
        sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 403

    @pytest.mark.unit
    def test_missing_provider_in_remote_data(self, client, mock_sessions):
        # remote_data JSON without a provider key -> bad request.
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user", "name": "src", "remote_data": json.dumps({}),
        }).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        # Happy path: a Celery ingest task is queued and its id returned.
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user",
            "name": "src",
            "remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
        }).inserted_id
        mock_task = MagicMock()
        mock_task.id = "task_123"
        with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
            mock_ingest.delay.return_value = mock_task
            resp = client.post("/api/connectors/sync", json={
                "source_id": str(sid), "session_token": "tok",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["task_id"] == "task_123"
class TestConnectorCallbackStatus:
    """Tests for GET /api/connectors/callback-status (OAuth result page)."""

    @pytest.mark.unit
    def test_success_status(self, client):
        """A success callback renders a page mentioning 'success'."""
        url = "/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com"
        resp = client.get(url)
        assert resp.status_code == 200
        assert b"success" in resp.data

    @pytest.mark.unit
    def test_error_status(self, client):
        """An error callback renders a page mentioning 'error'."""
        resp = client.get("/api/connectors/callback-status?status=error&message=Failed")
        assert resp.status_code == 200
        assert b"error" in resp.data

    @pytest.mark.unit
    def test_cancelled_status(self, client):
        """A cancelled callback renders a page mentioning 'cancelled'."""
        resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
        assert resp.status_code == 200
        assert b"cancelled" in resp.data

    @pytest.mark.unit
    def test_unknown_status_defaults_to_error(self, client):
        """Any unrecognized status value falls back to the error page."""
        resp = client.get("/api/connectors/callback-status?status=badvalue")
        assert resp.status_code == 200
        assert b"error" in resp.data

    @pytest.mark.unit
    def test_html_escaping(self, client):
        """An injected script tag must come back escaped, not executable."""
        resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
        assert resp.status_code == 200
        # The raw <script> tag should be escaped (not executable)
        assert b"<script>alert(1)</script>" not in resp.data
class TestBuildCallbackRedirect:
    """Tests for the build_callback_redirect helper."""

    @pytest.mark.unit
    def test_builds_url(self):
        """The helper produces a callback-status URL carrying the params."""
        from application.api.connector.routes import build_callback_redirect

        redirect_url = build_callback_redirect({"status": "success", "message": "OK"})
        assert redirect_url.startswith("/api/connectors/callback-status?")
        assert "status=success" in redirect_url

View File

@@ -0,0 +1,339 @@
"""Tests for application/core/model_settings.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
class TestModelProvider:
    """The ModelProvider enum exposes every supported backend."""

    @pytest.mark.unit
    def test_all_providers_exist(self):
        """Each provider member compares equal to its expected string value."""
        expected = {
            ModelProvider.OPENAI: "openai",
            ModelProvider.ANTHROPIC: "anthropic",
            ModelProvider.GOOGLE: "google",
            ModelProvider.GROQ: "groq",
            ModelProvider.DOCSGPT: "docsgpt",
            ModelProvider.HUGGINGFACE: "huggingface",
            ModelProvider.NOVITA: "novita",
            ModelProvider.OPENROUTER: "openrouter",
            ModelProvider.SAGEMAKER: "sagemaker",
            ModelProvider.PREMAI: "premai",
            ModelProvider.LLAMA_CPP: "llama.cpp",
            ModelProvider.AZURE_OPENAI: "azure_openai",
        }
        for member, value in expected.items():
            assert member == value
class TestModelCapabilities:
    """Defaults and overrides on ModelCapabilities."""

    @pytest.mark.unit
    def test_defaults(self):
        """A bare instance carries the documented default flags and costs."""
        defaults = ModelCapabilities()
        assert defaults.supports_tools is False
        assert defaults.supports_structured_output is False
        assert defaults.supports_streaming is True
        assert defaults.supported_attachment_types == []
        assert defaults.context_window == 128000
        assert defaults.input_cost_per_token is None
        assert defaults.output_cost_per_token is None

    @pytest.mark.unit
    def test_custom_values(self):
        """Constructor keyword arguments override the defaults."""
        custom = ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            context_window=32000,
            input_cost_per_token=0.001,
        )
        assert custom.supports_tools is True
        assert custom.context_window == 32000
class TestAvailableModel:
    """Serialization behaviour of AvailableModel.to_dict."""

    @pytest.mark.unit
    def test_to_dict_basic(self):
        """Core fields are serialized; base_url is omitted when unset."""
        gpt4 = AvailableModel(
            id="gpt-4",
            provider=ModelProvider.OPENAI,
            display_name="GPT-4",
            description="OpenAI GPT-4",
        )
        serialized = gpt4.to_dict()
        assert serialized["id"] == "gpt-4"
        assert serialized["provider"] == "openai"
        assert serialized["display_name"] == "GPT-4"
        assert serialized["enabled"] is True
        assert "base_url" not in serialized

    @pytest.mark.unit
    def test_to_dict_with_base_url(self):
        """A configured base_url shows up in the serialized dict."""
        local = AvailableModel(
            id="local-model",
            provider=ModelProvider.OPENAI,
            display_name="Local",
            base_url="http://localhost:11434",
        )
        assert local.to_dict()["base_url"] == "http://localhost:11434"

    @pytest.mark.unit
    def test_to_dict_includes_capabilities(self):
        """Capability flags are flattened into the serialized dict."""
        capabilities = ModelCapabilities(supports_tools=True, context_window=64000)
        model = AvailableModel(
            id="m1",
            provider=ModelProvider.ANTHROPIC,
            display_name="M1",
            capabilities=capabilities,
        )
        serialized = model.to_dict()
        assert serialized["supports_tools"] is True
        assert serialized["context_window"] == 64000
class TestModelRegistry:
    """Tests for the ModelRegistry singleton and its provider loaders."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self):
        """Reset singleton between tests."""
        ModelRegistry._instance = None
        ModelRegistry._initialized = False
        yield
        ModelRegistry._instance = None
        ModelRegistry._initialized = False

    @pytest.mark.unit
    def test_singleton(self):
        # Two constructions yield the same instance.
        with patch.object(ModelRegistry, "_load_models"):
            r1 = ModelRegistry()
            r2 = ModelRegistry()
            assert r1 is r2

    @pytest.mark.unit
    def test_get_instance(self):
        # Accessor returns a registry instance.
        with patch.object(ModelRegistry, "_load_models"):
            r = ModelRegistry.get_instance()
            assert isinstance(r, ModelRegistry)

    @pytest.mark.unit
    def test_get_model(self):
        # Known id returns the stored model; unknown id returns None.
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
            reg.models["test"] = model
            assert reg.get_model("test") is model
            assert reg.get_model("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
            assert len(reg.get_all_models()) == 2

    @pytest.mark.unit
    def test_get_enabled_models(self):
        # Disabled models are filtered out.
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
            enabled = reg.get_enabled_models()
            assert len(enabled) == 1
            assert enabled[0].id == "m1"

    @pytest.mark.unit
    def test_model_exists(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            assert reg.model_exists("m1") is True
            assert reg.model_exists("m2") is False

    @pytest.mark.unit
    def test_parse_model_names(self):
        # Comma-separated names are split, trimmed, and empty input yields [].
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
            assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
            assert reg._parse_model_names("single") == ["single"]
            assert reg._parse_model_names("") == []
            assert reg._parse_model_names(None) == []

    @pytest.mark.unit
    def test_add_docsgpt_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_docsgpt_models(mock_settings)
            assert "docsgpt-local" in reg.models

    @pytest.mark.unit
    def test_add_huggingface_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_huggingface_models(mock_settings)
            assert "huggingface-local" in reg.models

    @pytest.mark.unit
    def test_load_models_with_openai_key(self):
        # Only the OpenAI key is configured; registry should still populate.
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None
        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_load_models_custom_openai_base_url(self):
        # Custom base URL + LLM_NAME list registers each named model.
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = "llama3,gemma"
        mock_settings.API_KEY = None
        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert "llama3" in reg.models
            assert "gemma" in reg.models

    @pytest.mark.unit
    def test_default_model_selection_from_llm_name(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
            reg.default_model_id = "gpt-4"
            assert reg.default_model_id == "gpt-4"

    @pytest.mark.unit
    def test_add_anthropic_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_google_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = "google-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_google_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_groq_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GROQ_API_KEY = "groq-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_groq_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_openrouter_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = "or-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_novita_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.NOVITA_API_KEY = "novita-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_novita_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_azure_openai_models_specific(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.LLM_PROVIDER = "azure_openai"
            mock_settings.LLM_NAME = "nonexistent-model"
            reg._add_azure_openai_models(mock_settings)
            # Falls through to adding all azure models
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_anthropic_models_no_key_with_provider(self):
        # Provider configured even without an API key still registers models.
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = None
            mock_settings.LLM_PROVIDER = "anthropic"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_default_model_fallback_to_first(self):
        # With no keys at all the registry still picks a default model.
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = None
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None
        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            # Should have at least docsgpt-local
            assert reg.default_model_id is not None

View File

@@ -0,0 +1,577 @@
#!/usr/bin/env python3
"""
Integration tests for DocsGPT workflow management endpoints.
Uses Flask test client with real MongoDB (must be running).
Endpoints tested:
- /api/workflows (POST) - Create workflow
- /api/workflows/<id> (GET) - Get workflow
- /api/workflows/<id> (PUT) - Update workflow
- /api/workflows/<id> (DELETE) - Delete workflow
Run:
pytest tests/integration/test_workflows.py -v
"""
import time
import pytest
from jose import jwt
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def app():
    """Real Flask application (requires a running MongoDB instance)."""
    from application.app import app as flask_app

    flask_app.config["TESTING"] = True
    return flask_app
@pytest.fixture(scope="module")
def client(app):
    """Flask test client, with a Bearer token injected when JWT auth is on.

    For AUTH_TYPE simple_jwt/session_jwt a signed token is attached to
    every request; any other mode already resolves to {"sub": "local"}
    on the backend so no token is needed.
    """
    from application.core.settings import settings

    test_client = app.test_client()
    if settings.AUTH_TYPE not in ("simple_jwt", "session_jwt"):
        return test_client
    secret = settings.JWT_SECRET_KEY
    if not secret:
        pytest.skip("JWT_SECRET_KEY not configured")
    claims = {"sub": f"test_workflow_integration_{int(time.time())}"}
    token = jwt.encode(claims, secret, algorithm="HS256")
    test_client.environ_base["HTTP_AUTHORIZATION"] = f"Bearer {token}"
    return test_client
@pytest.fixture(scope="module")
def created_ids():
    """Module-scoped list collecting workflow ids for end-of-module cleanup."""
    ids = []
    return ids
@pytest.fixture(autouse=True, scope="module")
def cleanup(client, created_ids):
    """After the module finishes, best-effort delete every created workflow."""
    yield
    for workflow_id in created_ids:
        try:
            client.delete(f"/api/workflows/{workflow_id}")
        except Exception:
            # Cleanup must never fail the test run.
            pass
# ---------------------------------------------------------------------------
# Payload helpers
# ---------------------------------------------------------------------------
def simple_workflow(suffix=""):
    """Start -> End."""
    start_node = {
        "id": "start_1", "type": "start", "title": "Start",
        "position": {"x": 0, "y": 0}, "data": {},
    }
    end_node = {
        "id": "end_1", "type": "end", "title": "End",
        "position": {"x": 400, "y": 0}, "data": {},
    }
    return {
        "name": f"Simple WF {int(time.time())}{suffix}",
        "description": "integration test",
        "nodes": [start_node, end_node],
        "edges": [{"id": "edge_1", "source": "start_1", "target": "end_1"}],
    }
def linear_workflow(suffix=""):
    """Start -> Agent -> End."""
    agent_data = {
        "agent_type": "classic",
        "system_prompt": "You are helpful.",
        "prompt_template": "",
        "stream_to_user": False,
    }
    nodes = [
        {"id": "start_1", "type": "start", "title": "Start",
         "position": {"x": 0, "y": 0}, "data": {}},
        {"id": "agent_1", "type": "agent", "title": "Agent",
         "position": {"x": 200, "y": 0}, "data": agent_data},
        {"id": "end_1", "type": "end", "title": "End",
         "position": {"x": 400, "y": 0}, "data": {}},
    ]
    edges = [
        {"id": "edge_1", "source": "start_1", "target": "agent_1"},
        {"id": "edge_2", "source": "agent_1", "target": "end_1"},
    ]
    return {
        "name": f"Linear WF {int(time.time())}{suffix}",
        "description": "integration test",
        "nodes": nodes,
        "edges": edges,
    }
def multi_input_end_workflow(suffix=""):
    """Condition branches into two agents, both converging on one end node.
    Graph:
        start -> condition --(case_1)--> agent_a --\
                           --(else)----> agent_b ---+--> end
    """
    def _agent(node_id, prompt, y):
        # Classic, non-streaming agent node at a given vertical position.
        return {"id": node_id, "type": "agent", "title": f"Agent {node_id[-1].upper()}",
                "position": {"x": 400, "y": y}, "data": {
                    "agent_type": "classic",
                    "system_prompt": prompt,
                    "prompt_template": "",
                    "stream_to_user": False,
                }}

    nodes = [
        {"id": "start_1", "type": "start", "title": "Start",
         "position": {"x": 0, "y": 100}, "data": {}},
        {"id": "cond_1", "type": "condition", "title": "Branch",
         "position": {"x": 200, "y": 100}, "data": {
             "mode": "simple",
             "cases": [
                 {"name": "Case 1", "expression": "true",
                  "sourceHandle": "case_1"},
             ],
         }},
        _agent("agent_a", "Branch A", 0),
        _agent("agent_b", "Branch B", 200),
        {"id": "end_1", "type": "end", "title": "End",
         "position": {"x": 600, "y": 100}, "data": {}},
    ]
    edges = [
        {"id": "e1", "source": "start_1", "target": "cond_1"},
        {"id": "e2", "source": "cond_1", "target": "agent_a",
         "sourceHandle": "case_1"},
        {"id": "e3", "source": "cond_1", "target": "agent_b",
         "sourceHandle": "else"},
        # Both agents feed into the SAME end node
        {"id": "e4", "source": "agent_a", "target": "end_1"},
        {"id": "e5", "source": "agent_b", "target": "end_1"},
    ]
    return {
        "name": f"Multi-Input End {int(time.time())}{suffix}",
        "description": "end node with multiple inputs",
        "nodes": nodes,
        "edges": edges,
    }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_id(resp):
"""Pull workflow id from create/update response."""
body = resp.get_json()
data = body.get("data") or body
return data.get("id")
def _get_graph(client, wf_id):
"""Fetch workflow and return (nodes, edges)."""
resp = client.get(f"/api/workflows/{wf_id}")
assert resp.status_code == 200, resp.get_data(as_text=True)
body = resp.get_json()
data = body.get("data") or body
return data.get("nodes", []), data.get("edges", [])
# ===========================================================================
# CRUD tests
# ===========================================================================
class TestWorkflowCRUD:
    """Create / read / update / delete round-trips against /api/workflows."""

    def test_create_simple_workflow(self, client, created_ids):
        # Minimal start->end graph is accepted.
        resp = client.post("/api/workflows", json=simple_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_create_linear_workflow(self, client, created_ids):
        # start->agent->end graph is accepted.
        resp = client.post("/api/workflows", json=linear_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_get_workflow_returns_nodes_and_edges(self, client, created_ids):
        # Round-trip: what was posted comes back with the same graph size.
        resp = client.post("/api/workflows", json=simple_workflow(" get"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 2
        assert len(edges) == 1

    def test_update_workflow(self, client, created_ids):
        # PUT replaces the simple graph with the larger linear one.
        resp = client.post("/api/workflows", json=simple_workflow(" upd"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        update_resp = client.put(
            f"/api/workflows/{wf_id}", json=linear_workflow(" updated")
        )
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3  # start, agent, end
        assert len(edges) == 2

    def test_delete_workflow(self, client):
        # Deleted workflow is no longer retrievable.
        resp = client.post("/api/workflows", json=simple_workflow(" del"))
        wf_id = _extract_id(resp)
        del_resp = client.delete(f"/api/workflows/{wf_id}")
        assert del_resp.status_code == 200
        get_resp = client.get(f"/api/workflows/{wf_id}")
        assert get_resp.status_code in (400, 404)

    def test_reject_workflow_without_end_node(self, client):
        # Validation: a graph with no end node must be rejected.
        payload = {
            "name": "No End",
            "nodes": [
                {"id": "s", "type": "start", "title": "Start",
                 "position": {"x": 0, "y": 0}, "data": {}},
            ],
            "edges": [],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 400, resp.get_data(as_text=True)
# ===========================================================================
# Multi-input end node tests
# ===========================================================================
class TestMultiInputEndNode:
    """Verify that an end node can receive edges from multiple source nodes.

    The ``multi_input_end_workflow`` payload builder (defined outside this
    section) apparently produces a graph with 5 edges, at least 2 of which
    target the end node — TODO confirm against its definition.
    """

    def test_create_multi_input_end_workflow_accepted(self, client, created_ids):
        """Backend must accept a workflow where two edges target the same end node."""
        resp = client.post("/api/workflows", json=multi_input_end_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_multi_input_end_all_edges_persisted(self, client, created_ids):
        """After round-trip, both edges into the end node must still be present."""
        resp = client.post(
            "/api/workflows", json=multi_input_end_workflow(" persist")
        )
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, edges = _get_graph(client, wf_id)
        # Locate end node(s) — a set so the membership test below is O(1).
        end_ids = {n["id"] for n in nodes if n["type"] == "end"}
        assert end_ids, "no end node in response"
        # Count edges targeting any end node
        edges_to_end = [e for e in edges if e["target"] in end_ids]
        assert len(edges_to_end) >= 2, (
            f"Expected >=2 edges to end, got {len(edges_to_end)}: {edges_to_end}"
        )

    def test_multi_input_end_total_edge_count(self, client, created_ids):
        """All 5 edges of the multi-input graph must survive persistence."""
        resp = client.post(
            "/api/workflows", json=multi_input_end_workflow(" count")
        )
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        _, edges = _get_graph(client, wf_id)
        assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"

    def test_update_to_multi_input_end_preserves_edges(self, client, created_ids):
        """Updating a simple workflow to multi-input end keeps all edges."""
        # Create simple
        resp = client.post("/api/workflows", json=simple_workflow(" pre"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        # Update to multi-input end
        update_resp = client.put(
            f"/api/workflows/{wf_id}",
            json=multi_input_end_workflow(" post"),
        )
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
        nodes, edges = _get_graph(client, wf_id)
        end_ids = {n["id"] for n in nodes if n["type"] == "end"}
        edges_to_end = [e for e in edges if e["target"] in end_ids]
        assert len(edges_to_end) >= 2, (
            f"Expected >=2 edges to end after update, got {len(edges_to_end)}"
        )
# ---------------------------------------------------------------------------
# Source-aware payload helpers
# ---------------------------------------------------------------------------
def workflow_with_sources(sources, suffix=""):
    """Build a Start -> Agent (carrying *sources*) -> End workflow payload.

    The name is timestamped so repeated test runs do not collide on a
    uniqueness constraint; *suffix* further disambiguates within one run.
    """
    agent_data = {
        "agent_type": "classic",
        "system_prompt": "You are helpful.",
        "prompt_template": "",
        "stream_to_user": False,
        "sources": sources,
        "tools": [],
    }
    nodes = [
        {"id": "start_1", "type": "start", "title": "Start",
         "position": {"x": 0, "y": 0}, "data": {}},
        {"id": "agent_1", "type": "agent", "title": "Agent",
         "position": {"x": 200, "y": 0}, "data": agent_data},
        {"id": "end_1", "type": "end", "title": "End",
         "position": {"x": 400, "y": 0}, "data": {}},
    ]
    edges = [
        {"id": "edge_1", "source": "start_1", "target": "agent_1"},
        {"id": "edge_2", "source": "agent_1", "target": "end_1"},
    ]
    return {
        "name": f"Source WF {int(time.time())}{suffix}",
        "description": "integration test with sources",
        "nodes": nodes,
        "edges": edges,
    }
def workflow_multi_agent_sources(suffix=""):
    """Build Start -> Agent A (sources A) -> Agent B (sources B) -> End.

    The two agents deliberately carry distinct source lists so tests can
    verify per-node source persistence.
    """
    agent_a = {
        "id": "agent_a", "type": "agent", "title": "Agent A",
        "position": {"x": 200, "y": 0},
        "data": {
            "agent_type": "agentic",
            "system_prompt": "Agent A prompt",
            "prompt_template": "",
            "stream_to_user": False,
            "sources": ["src_alpha", "src_beta"],
            "tools": [],
        },
    }
    agent_b = {
        "id": "agent_b", "type": "agent", "title": "Agent B",
        "position": {"x": 400, "y": 0},
        "data": {
            "agent_type": "classic",
            "system_prompt": "Agent B prompt",
            "prompt_template": "",
            "stream_to_user": True,
            "sources": ["src_gamma"],
            "tools": [],
        },
    }
    start = {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 0}, "data": {}}
    end = {"id": "end_1", "type": "end", "title": "End",
           "position": {"x": 600, "y": 0}, "data": {}}
    return {
        "name": f"Multi-Agent Sources {int(time.time())}{suffix}",
        "description": "two agents with different sources",
        "nodes": [start, agent_a, agent_b, end],
        "edges": [
            {"id": "e1", "source": "start_1", "target": "agent_a"},
            {"id": "e2", "source": "agent_a", "target": "agent_b"},
            {"id": "e3", "source": "agent_b", "target": "end_1"},
        ],
    }
def _find_agent_node(nodes, node_id):
"""Find a specific node by id."""
return next((n for n in nodes if n["id"] == node_id), None)
# ===========================================================================
# Workflow integration tests
# ===========================================================================
class TestWorkflowIntegration:
    """Verify end-to-end workflow create → get → update → get round-trips."""

    def test_linear_workflow_round_trip(self, client, created_ids):
        """Create a linear workflow and verify all nodes/edges survive the round-trip."""
        payload = linear_workflow(" round-trip")
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)
        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3
        assert len(edges) == 2
        # Verify node types come back keyed by the same ids that were sent.
        types = {n["id"]: n["type"] for n in nodes}
        assert types["start_1"] == "start"
        assert types["agent_1"] == "agent"
        assert types["end_1"] == "end"

    def test_agent_config_persisted(self, client, created_ids):
        """Agent node config (type, prompts, stream_to_user) round-trips correctly."""
        payload = linear_workflow(" config")
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        # Values below mirror what linear_workflow puts in the agent node.
        assert agent["data"]["agent_type"] == "classic"
        assert agent["data"]["system_prompt"] == "You are helpful."
        assert agent["data"]["stream_to_user"] is False

    def test_update_workflow_replaces_graph(self, client, created_ids):
        """Updating a workflow fully replaces nodes and edges."""
        resp = client.post("/api/workflows", json=simple_workflow(" replace"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 2
        # Update to linear: node/edge counts must reflect the new graph only.
        update_resp = client.put(
            f"/api/workflows/{wf_id}", json=linear_workflow(" replaced")
        )
        assert update_resp.status_code == 200
        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3
        assert len(edges) == 2
# ===========================================================================
# Source-specific integration tests
# ===========================================================================
class TestWorkflowSources:
    """Verify that agent node sources are persisted and retrieved correctly."""

    def test_create_workflow_with_single_source(self, client, created_ids):
        """A workflow with one source on an agent node persists it."""
        payload = workflow_with_sources(["default"])
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None, "Agent node not found"
        assert agent["data"].get("sources") == ["default"], (
            f"Expected sources=['default'], got {agent['data'].get('sources')}"
        )

    def test_create_workflow_with_multiple_sources(self, client, created_ids):
        """Multiple sources on an agent node are all persisted."""
        sources = ["src_1", "src_2", "src_3"]
        payload = workflow_with_sources(sources)
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        # Order must be preserved as well as membership.
        assert agent["data"].get("sources") == sources

    def test_create_workflow_with_empty_sources(self, client, created_ids):
        """An agent node with empty sources list is accepted and persisted."""
        payload = workflow_with_sources([])
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        # Empty list must round-trip as [] — not None and not dropped.
        assert agent["data"].get("sources") == []

    def test_update_workflow_sources(self, client, created_ids):
        """Updating a workflow replaces agent sources."""
        # Create with original sources
        payload = workflow_with_sources(["old_src"])
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        # Update with new sources
        updated_payload = workflow_with_sources(["new_src_1", "new_src_2"], " upd")
        update_resp = client.put(f"/api/workflows/{wf_id}", json=updated_payload)
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        assert agent["data"].get("sources") == ["new_src_1", "new_src_2"]

    def test_multi_agent_independent_sources(self, client, created_ids):
        """Each agent node keeps its own distinct sources list."""
        payload = workflow_multi_agent_sources()
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        nodes, _ = _get_graph(client, wf_id)
        agent_a = _find_agent_node(nodes, "agent_a")
        agent_b = _find_agent_node(nodes, "agent_b")
        assert agent_a is not None, "Agent A not found"
        assert agent_b is not None, "Agent B not found"
        # Sources must not bleed between sibling agent nodes.
        assert agent_a["data"].get("sources") == ["src_alpha", "src_beta"]
        assert agent_b["data"].get("sources") == ["src_gamma"]

    def test_sources_survive_workflow_update(self, client, created_ids):
        """Sources survive when a workflow is updated without changing sources."""
        payload = workflow_with_sources(["persistent_src"])
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        # Update keeping same sources
        update_resp = client.put(f"/api/workflows/{wf_id}", json=payload)
        assert update_resp.status_code == 200
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent["data"].get("sources") == ["persistent_src"]

    def test_remove_sources_on_update(self, client, created_ids):
        """Clearing sources on update results in empty list."""
        payload = workflow_with_sources(["will_be_removed"])
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)
        # Update with no sources
        cleared_payload = workflow_with_sources([], " cleared")
        update_resp = client.put(f"/api/workflows/{wf_id}", json=cleared_payload)
        assert update_resp.status_code == 200
        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent["data"].get("sources") == []

View File

@@ -71,6 +71,7 @@ class TestLLMHandlerCreator:
expected_handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler,
"default": OpenAILLMHandler,
}

View File

@@ -0,0 +1,165 @@
"""Tests for the Novita LLM provider.
Novita uses an OpenAI-compatible API, so NovitaLLM extends OpenAILLM.
These tests verify the Novita-specific configuration is applied correctly.
"""
import types
from unittest.mock import patch
import pytest
from application.llm.novita import NOVITA_BASE_URL, NovitaLLM
class FakeChatCompletions:
    """Stand-in for the OpenAI ``chat.completions`` endpoint used in tests.

    Records the kwargs of the last ``create`` call and returns canned
    responses: a single message for non-streaming calls, two delta chunks
    ("part1", "part2") for streaming calls.
    """

    def __init__(self):
        # kwargs of the most recent create() call, for assertions.
        self.last_kwargs = None

    class _Msg:
        def __init__(self, content=None):
            self.content = content

    class _Delta:
        def __init__(self, content=None):
            self.content = content

    class _Choice:
        def __init__(self, content=None, delta=None):
            self.message = FakeChatCompletions._Msg(content=content)
            self.delta = FakeChatCompletions._Delta(content=delta)

    class _StreamChunk:
        def __init__(self, choice):
            self.choices = [choice]

    class _Response:
        def __init__(self, choices=None, lines=None):
            self._choices = choices or []
            self._lines = lines or []

        @property
        def choices(self):
            return self._choices

        def __iter__(self):
            # Streaming responses are iterated chunk by chunk.
            return iter(self._lines)

    def create(self, **kwargs):
        self.last_kwargs = kwargs
        if kwargs.get("stream"):
            chunks = [
                FakeChatCompletions._StreamChunk(
                    FakeChatCompletions._Choice(delta=part)
                )
                for part in ("part1", "part2")
            ]
            return FakeChatCompletions._Response(lines=chunks)
        answer = FakeChatCompletions._Choice(content="novita response")
        return FakeChatCompletions._Response(choices=[answer])
class FakeClient:
    """Fake OpenAI client exposing only the ``chat.completions`` surface."""

    def __init__(self):
        completions = FakeChatCompletions()
        self.chat = types.SimpleNamespace(completions=completions)
@pytest.mark.unit
def test_novita_base_url_constant():
    """Verify the Novita base URL is correctly defined."""
    # Pin the provider endpoint so an accidental edit fails fast.
    assert NOVITA_BASE_URL == "https://api.novita.ai/openai"
@pytest.mark.unit
def test_novita_llm_uses_novita_base_url():
    """Verify NovitaLLM uses the Novita API endpoint."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    # The client should be configured with Novita's base URL; the trailing
    # slash is added by the underlying OpenAI client's URL normalization.
    assert str(llm.client.base_url) == NOVITA_BASE_URL + "/"
@pytest.mark.unit
def test_novita_llm_uses_novita_api_key():
    """Verify NovitaLLM prioritizes NOVITA_API_KEY from settings."""
    with patch("application.llm.novita.settings") as mock_settings:
        mock_settings.NOVITA_API_KEY = "novita-test-key"
        mock_settings.API_KEY = "fallback-key"  # must lose to NOVITA_API_KEY
        mock_settings.OPENAI_BASE_URL = None
        llm = NovitaLLM(api_key=None, user_api_key=None)
        assert llm.api_key == "novita-test-key"
@pytest.mark.unit
def test_novita_llm_falls_back_to_api_key():
    """Verify NovitaLLM falls back to API_KEY when NOVITA_API_KEY is not set."""
    with patch("application.llm.novita.settings") as mock_settings:
        mock_settings.NOVITA_API_KEY = None  # unset: generic key should win
        mock_settings.API_KEY = "fallback-key"
        mock_settings.OPENAI_BASE_URL = None
        llm = NovitaLLM(api_key=None, user_api_key=None)
        assert llm.api_key == "fallback-key"
@pytest.mark.unit
def test_novita_llm_explicit_api_key_takes_precedence():
    """Verify explicitly passed API key takes precedence over settings."""
    with patch("application.llm.novita.settings") as mock_settings:
        # Both settings keys populated: the constructor argument must win.
        mock_settings.NOVITA_API_KEY = "settings-key"
        mock_settings.API_KEY = "fallback-key"
        mock_settings.OPENAI_BASE_URL = None
        llm = NovitaLLM(api_key="explicit-key", user_api_key=None)
        assert llm.api_key == "explicit-key"
@pytest.mark.unit
def test_novita_llm_custom_base_url():
    """Verify custom base_url can override the default Novita URL."""
    custom_url = "https://custom.novita.endpoint/v1"
    llm = NovitaLLM(api_key="test-key", user_api_key=None, base_url=custom_url)
    # Trailing slash is appended by the underlying OpenAI client.
    assert str(llm.client.base_url) == custom_url + "/"
@pytest.mark.unit
def test_novita_llm_supports_tools():
    """Verify NovitaLLM supports function calling/tools."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    # `is True` guards against a truthy non-boolean return value.
    assert llm.supports_tools() is True
@pytest.mark.unit
def test_novita_llm_supports_structured_output():
    """Verify NovitaLLM supports structured output."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    # `is True` guards against a truthy non-boolean return value.
    assert llm.supports_structured_output() is True
@pytest.mark.unit
def test_novita_llm_gen_calls_client():
    """Verify NovitaLLM.gen calls the OpenAI-compatible client correctly.

    The real client is swapped for FakeClient, so no network I/O occurs.
    (The previously requested ``monkeypatch`` fixture was unused and has
    been removed.)
    """
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    llm.client = FakeClient()
    msgs = [{"role": "user", "content": "hello"}]
    # NOTE(review): _raw_gen is called with the instance passed again as the
    # first positional argument — matches the calling convention used by the
    # base LLM wrapper; confirm against application.llm.base.
    result = llm._raw_gen(llm, model="moonshotai/kimi-k2.5", messages=msgs, stream=False)
    assert result == "novita response"
    assert llm.client.chat.completions.last_kwargs["model"] == "moonshotai/kimi-k2.5"
@pytest.mark.unit
def test_novita_llm_gen_stream_yields_chunks():
    """Verify NovitaLLM streaming yields chunks correctly.

    Uses FakeClient, which emits two delta chunks ("part1", "part2").
    (The previously requested ``monkeypatch`` fixture was unused and has
    been removed.)
    """
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    llm.client = FakeClient()
    msgs = [{"role": "user", "content": "hi"}]
    gen = llm._raw_gen_stream(llm, model="moonshotai/kimi-k2.5", messages=msgs, stream=True)
    chunks = list(gen)
    joined = "".join(chunks)
    assert "part1" in joined
    assert "part2" in joined

0
tests/parser/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,110 @@
"""Tests for connector base classes."""
import pytest
from application.parser.connectors.base import BaseConnectorAuth, BaseConnectorLoader
from application.parser.schema.base import Document
class ConcreteAuth(BaseConnectorAuth):
    """Minimal concrete subclass used to exercise the auth ABC in tests."""

    def get_authorization_url(self, state=None):
        # Echo the state back so tests can confirm it is threaded through.
        return "https://example.com/auth?state={}".format(state)

    def exchange_code_for_tokens(self, authorization_code):
        return dict(access_token="tok", code=authorization_code)

    def refresh_access_token(self, refresh_token):
        return dict(access_token="new_tok", refresh_token=refresh_token)

    def is_token_expired(self, token_info):
        # Tests drive expiry through an explicit "expired" flag.
        return token_info.get("expired", False)
class ConcreteLoader(BaseConnectorLoader):
    """Minimal concrete subclass used to exercise the loader ABC in tests."""

    def __init__(self, session_token):
        self.session_token = session_token

    def load_data(self, inputs):
        # A single canned document is enough to verify the interface.
        doc = Document(text="test", doc_id="1", extra_info={})
        return [doc]

    def download_to_directory(self, local_dir, source_config=None):
        return dict(files_downloaded=0, directory_path=local_dir)
class TestBaseConnectorAuth:
    """Behavior of BaseConnectorAuth, exercised via the ConcreteAuth subclass."""

    @pytest.mark.unit
    def test_sanitize_token_info_extracts_standard_fields(self):
        auth = ConcreteAuth()
        token_info = {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 12345,
            "extra_field": "should_not_appear",
        }
        result = auth.sanitize_token_info(token_info)
        # Only the four standard fields survive; unknown keys are dropped.
        assert result == {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 12345,
        }

    @pytest.mark.unit
    def test_sanitize_token_info_with_extra_kwargs(self):
        auth = ConcreteAuth()
        token_info = {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 100,
        }
        # Extra keyword arguments are merged into the sanitized dict.
        result = auth.sanitize_token_info(token_info, custom_field="custom_val")
        assert result["custom_field"] == "custom_val"
        assert result["access_token"] == "at"

    @pytest.mark.unit
    def test_sanitize_token_info_missing_fields_returns_none(self):
        auth = ConcreteAuth()
        result = auth.sanitize_token_info({})
        # Missing standard fields default to None rather than raising.
        assert result["access_token"] is None
        assert result["refresh_token"] is None
        assert result["token_uri"] is None
        assert result["expiry"] is None

    @pytest.mark.unit
    def test_abstract_methods_invocable_on_concrete(self):
        # Smoke-test that every abstract method is callable on the subclass.
        auth = ConcreteAuth()
        assert "example.com" in auth.get_authorization_url("s1")
        assert auth.exchange_code_for_tokens("code1")["access_token"] == "tok"
        assert auth.refresh_access_token("rt")["access_token"] == "new_tok"
        assert auth.is_token_expired({"expired": True}) is True
        assert auth.is_token_expired({"expired": False}) is False
class TestBaseConnectorLoader:
    """Behavior of BaseConnectorLoader, exercised via the ConcreteLoader subclass."""

    @pytest.mark.unit
    def test_concrete_loader_init(self):
        loader = ConcreteLoader("session123")
        assert loader.session_token == "session123"

    @pytest.mark.unit
    def test_concrete_loader_load_data(self):
        loader = ConcreteLoader("s")
        docs = loader.load_data({})
        assert len(docs) == 1
        assert docs[0].text == "test"

    @pytest.mark.unit
    def test_concrete_loader_download_to_directory(self):
        loader = ConcreteLoader("s")
        # source_config is omitted to exercise its default of None.
        result = loader.download_to_directory("/tmp/test")
        assert result["directory_path"] == "/tmp/test"
        assert result["files_downloaded"] == 0

View File

@@ -0,0 +1,102 @@
"""Tests for ConnectorCreator factory class."""
from unittest.mock import patch, MagicMock
import pytest
class TestConnectorCreator:
    """Factory behavior of ConnectorCreator for the supported connectors."""

    @pytest.fixture(autouse=True)
    def _patch_settings(self):
        """Patch settings so connector imports don't fail on missing credentials.

        The import happens inside the patch context because module import
        triggers credential reads; the class is stashed on ``self`` for the
        tests to use.
        """
        mock_settings = MagicMock()
        mock_settings.GOOGLE_CLIENT_ID = "gid"
        mock_settings.GOOGLE_CLIENT_SECRET = "gsecret"
        mock_settings.CONNECTOR_REDIRECT_BASE_URI = "https://redirect"
        mock_settings.MICROSOFT_CLIENT_ID = "mid"
        mock_settings.MICROSOFT_CLIENT_SECRET = "msecret"
        mock_settings.MICROSOFT_TENANT_ID = "tid"
        mock_settings.MONGO_DB_NAME = "test_db"
        with patch("application.core.settings.settings", mock_settings), \
                patch("application.parser.connectors.share_point.auth.settings", mock_settings), \
                patch("application.parser.connectors.google_drive.auth.settings", mock_settings), \
                patch("application.parser.connectors.share_point.auth.ConfidentialClientApplication"):
            from application.parser.connectors.connector_creator import ConnectorCreator
            self.ConnectorCreator = ConnectorCreator
            yield

    @pytest.mark.unit
    def test_get_supported_connectors(self):
        supported = self.ConnectorCreator.get_supported_connectors()
        assert "google_drive" in supported
        assert "share_point" in supported

    @pytest.mark.unit
    def test_is_supported_valid(self):
        assert self.ConnectorCreator.is_supported("google_drive") is True
        assert self.ConnectorCreator.is_supported("share_point") is True

    @pytest.mark.unit
    def test_is_supported_case_insensitive(self):
        # Provider names must be normalized regardless of caller casing.
        assert self.ConnectorCreator.is_supported("Google_Drive") is True
        assert self.ConnectorCreator.is_supported("SHARE_POINT") is True

    @pytest.mark.unit
    def test_is_supported_invalid(self):
        assert self.ConnectorCreator.is_supported("dropbox") is False
        assert self.ConnectorCreator.is_supported("") is False

    @pytest.mark.unit
    def test_create_auth_google_drive(self):
        auth = self.ConnectorCreator.create_auth("google_drive")
        from application.parser.connectors.google_drive.auth import GoogleDriveAuth
        assert isinstance(auth, GoogleDriveAuth)

    @pytest.mark.unit
    def test_create_auth_share_point(self):
        auth = self.ConnectorCreator.create_auth("share_point")
        from application.parser.connectors.share_point.auth import SharePointAuth
        assert isinstance(auth, SharePointAuth)

    @pytest.mark.unit
    def test_create_auth_invalid_raises(self):
        with pytest.raises(ValueError, match="No auth class found"):
            self.ConnectorCreator.create_auth("invalid")

    @pytest.mark.unit
    def test_create_connector_invalid_raises(self):
        with pytest.raises(ValueError, match="No connector class found"):
            self.ConnectorCreator.create_connector("invalid", "session_tok")

    @pytest.mark.unit
    def test_create_connector_google_drive(self):
        # The loader builds credentials/service during __init__, so the auth
        # class it uses must be mocked out entirely.
        with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as MockAuth:
            mock_auth_instance = MagicMock()
            mock_auth_instance.get_token_info_from_session.return_value = {
                "access_token": "at", "refresh_token": "rt"
            }
            mock_creds = MagicMock()
            mock_creds.token = "at"
            mock_creds.expired = False
            mock_auth_instance.create_credentials_from_token_info.return_value = mock_creds
            mock_auth_instance.build_drive_service.return_value = MagicMock()
            MockAuth.return_value = mock_auth_instance
            loader = self.ConnectorCreator.create_connector("google_drive", "session_tok")
            from application.parser.connectors.google_drive.loader import GoogleDriveLoader
            assert isinstance(loader, GoogleDriveLoader)

    @pytest.mark.unit
    def test_create_connector_share_point(self):
        with patch("application.parser.connectors.share_point.loader.SharePointAuth") as MockAuth:
            mock_auth_instance = MagicMock()
            mock_auth_instance.get_token_info_from_session.return_value = {
                "access_token": "at", "refresh_token": "rt"
            }
            MockAuth.return_value = mock_auth_instance
            loader = self.ConnectorCreator.create_connector("share_point", "session_tok")
            from application.parser.connectors.share_point.loader import SharePointLoader
            assert isinstance(loader, SharePointLoader)

View File

@@ -0,0 +1,441 @@
"""Tests for GoogleDriveAuth."""
import datetime
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def mock_settings():
    """Settings stub carrying the fields GoogleDriveAuth reads."""
    stub = MagicMock()
    stub.configure_mock(
        GOOGLE_CLIENT_ID="test-client-id",
        GOOGLE_CLIENT_SECRET="test-client-secret",
        CONNECTOR_REDIRECT_BASE_URI="https://redirect.example.com/callback",
        MONGO_DB_NAME="test_db",
    )
    return stub
@pytest.fixture
def auth(mock_settings):
    """GoogleDriveAuth instance constructed against the patched settings.

    NOTE: the patch is only active while __init__ runs; the returned object
    keeps whatever credential values it copied from ``mock_settings``.
    """
    with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
        from application.parser.connectors.google_drive.auth import GoogleDriveAuth
        return GoogleDriveAuth()
class TestGoogleDriveAuthInit:
    """Constructor behavior: credential capture and missing-credential errors."""

    @pytest.mark.unit
    def test_init_sets_credentials(self, auth, mock_settings):
        # Values copied from settings at construction time.
        assert auth.client_id == "test-client-id"
        assert auth.client_secret == "test-client-secret"
        assert auth.redirect_uri == "https://redirect.example.com/callback"

    @pytest.mark.unit
    def test_init_missing_client_id_raises(self, mock_settings):
        mock_settings.GOOGLE_CLIENT_ID = None
        with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
            from application.parser.connectors.google_drive.auth import GoogleDriveAuth
            with pytest.raises(ValueError, match="Google OAuth credentials not configured"):
                GoogleDriveAuth()

    @pytest.mark.unit
    def test_init_missing_client_secret_raises(self, mock_settings):
        mock_settings.GOOGLE_CLIENT_SECRET = None
        with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
            from application.parser.connectors.google_drive.auth import GoogleDriveAuth
            with pytest.raises(ValueError, match="Google OAuth credentials not configured"):
                GoogleDriveAuth()
class TestGetAuthorizationUrl:
    """get_authorization_url: OAuth flow construction and error propagation."""

    @pytest.mark.unit
    def test_returns_authorization_url(self, auth):
        mock_flow = MagicMock()
        # authorization_url returns a (url, state) tuple; only the url is used.
        mock_flow.authorization_url.return_value = ("https://accounts.google.com/auth?state=s1", "s1")
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.return_value = mock_flow
            url = auth.get_authorization_url(state="s1")
            assert url == "https://accounts.google.com/auth?state=s1"
            # offline access + consent prompt are required to obtain a
            # refresh token from Google.
            mock_flow.authorization_url.assert_called_once_with(
                access_type='offline',
                prompt='consent',
                include_granted_scopes='false',
                state="s1"
            )

    @pytest.mark.unit
    def test_raises_on_flow_error(self, auth):
        # Flow construction failures must propagate, not be swallowed.
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.side_effect = Exception("flow error")
            with pytest.raises(Exception, match="flow error"):
                auth.get_authorization_url()
class TestExchangeCodeForTokens:
    """exchange_code_for_tokens: token extraction, validation and defaults."""

    @pytest.mark.unit
    def test_successful_exchange(self, auth):
        mock_creds = MagicMock()
        mock_creds.token = "access_tok"
        mock_creds.refresh_token = "refresh_tok"
        mock_creds.token_uri = "https://oauth2.googleapis.com/token"
        mock_creds.client_id = "test-client-id"
        mock_creds.client_secret = "test-client-secret"
        mock_creds.scopes = ["https://www.googleapis.com/auth/drive.file"]
        mock_creds.expiry = datetime.datetime(2025, 1, 1, 12, 0, 0)
        mock_flow = MagicMock()
        mock_flow.credentials = mock_creds
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.return_value = mock_flow
            result = auth.exchange_code_for_tokens("auth_code_123")
            assert result["access_token"] == "access_tok"
            assert result["refresh_token"] == "refresh_tok"
            assert result["token_uri"] == "https://oauth2.googleapis.com/token"
            assert result["client_id"] == "test-client-id"
            assert result["client_secret"] == "test-client-secret"
            # Expiry is serialized to an ISO-8601 string.
            assert result["expiry"] == "2025-01-01T12:00:00"

    @pytest.mark.unit
    def test_empty_code_raises(self, auth):
        with pytest.raises(ValueError, match="Authorization code is required"):
            auth.exchange_code_for_tokens("")

    @pytest.mark.unit
    def test_no_access_token_raises(self, auth):
        mock_creds = MagicMock()
        mock_creds.token = None  # simulate Google returning no access token
        mock_creds.refresh_token = "rt"
        mock_flow = MagicMock()
        mock_flow.credentials = mock_creds
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.return_value = mock_flow
            with pytest.raises(ValueError, match="did not return an access token"):
                auth.exchange_code_for_tokens("code")

    @pytest.mark.unit
    def test_no_refresh_token_raises(self, auth):
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = None  # happens when consent was not re-prompted
        mock_flow = MagicMock()
        mock_flow.credentials = mock_creds
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.return_value = mock_flow
            with pytest.raises(ValueError, match="No refresh token received"):
                auth.exchange_code_for_tokens("code")

    @pytest.mark.unit
    def test_fills_in_missing_token_uri(self, auth):
        # When Google omits token_uri/client fields, the auth class must
        # backfill them from its own configuration.
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = "rt"
        mock_creds.token_uri = None
        mock_creds.client_id = None
        mock_creds.client_secret = None
        mock_creds.scopes = []
        mock_creds.expiry = None
        mock_flow = MagicMock()
        mock_flow.credentials = mock_creds
        with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
            MockFlow.from_client_config.return_value = mock_flow
            result = auth.exchange_code_for_tokens("code")
            assert result["token_uri"] == "https://oauth2.googleapis.com/token"
            assert result["client_id"] == "test-client-id"
            assert result["client_secret"] == "test-client-secret"
class TestRefreshAccessToken:
    """refresh_access_token: success path, empty input and refresh failures."""

    @pytest.mark.unit
    def test_successful_refresh(self, auth):
        mock_request_cls = MagicMock()
        with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
                patch("google.auth.transport.requests.Request", mock_request_cls):
            mock_cred_instance = MagicMock()
            mock_cred_instance.token = "new_access"
            mock_cred_instance.token_uri = "https://oauth2.googleapis.com/token"
            mock_cred_instance.client_id = "cid"
            mock_cred_instance.client_secret = "cs"
            mock_cred_instance.scopes = []
            mock_cred_instance.expiry = datetime.datetime(2025, 6, 1, 0, 0, 0)
            MockCreds.return_value = mock_cred_instance
            result = auth.refresh_access_token("old_refresh")
            assert result["access_token"] == "new_access"
            # The original refresh token is preserved in the result.
            assert result["refresh_token"] == "old_refresh"
            mock_cred_instance.refresh.assert_called_once()

    @pytest.mark.unit
    def test_empty_refresh_token_raises(self, auth):
        with pytest.raises(ValueError, match="Refresh token is required"):
            auth.refresh_access_token("")

    @pytest.mark.unit
    def test_refresh_failure_raises(self, auth):
        # A failure inside Credentials.refresh must propagate to the caller.
        with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
                patch("google.auth.transport.requests.Request"):
            mock_cred_instance = MagicMock()
            mock_cred_instance.refresh.side_effect = Exception("refresh failed")
            MockCreds.return_value = mock_cred_instance
            with pytest.raises(Exception, match="refresh failed"):
                auth.refresh_access_token("rt")
class TestCreateCredentialsFromTokenInfo:
    """create_credentials_from_token_info: construction and validation errors."""

    @pytest.mark.unit
    def test_creates_credentials(self, auth, mock_settings):
        with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
                patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
            mock_cred = MagicMock()
            mock_cred.token = "at"
            MockCreds.return_value = mock_cred
            creds = auth.create_credentials_from_token_info({
                "access_token": "at",
                "refresh_token": "rt",
                "scopes": ["scope1"],
            })
            assert creds.token == "at"

    @pytest.mark.unit
    def test_missing_access_token_raises(self, auth, mock_settings):
        with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
            with pytest.raises(ValueError, match="No access token found"):
                auth.create_credentials_from_token_info({})

    @pytest.mark.unit
    def test_credentials_without_valid_token_raises(self, auth, mock_settings):
        # Even with an access_token input, a Credentials object that ends up
        # token-less must be rejected.
        with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
                patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
            mock_cred = MagicMock()
            mock_cred.token = None
            MockCreds.return_value = mock_cred
            with pytest.raises(ValueError, match="Credentials created without valid access token"):
                auth.create_credentials_from_token_info({"access_token": "at"})
class TestBuildDriveService:
    """Tests for GoogleDriveAuth.build_drive_service (Drive v3 client factory)."""

    @pytest.mark.unit
    def test_builds_service(self, auth):
        """Valid, unexpired credentials produce a Drive v3 service via build()."""
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = "rt"
        mock_creds.expired = False
        with patch("application.parser.connectors.google_drive.auth.build") as mock_build:
            mock_build.return_value = MagicMock()
            service = auth.build_drive_service(mock_creds)
            mock_build.assert_called_once_with('drive', 'v3', credentials=mock_creds)
            assert service is not None

    @pytest.mark.unit
    def test_no_credentials_raises(self, auth):
        """Passing None for credentials is rejected."""
        with pytest.raises(ValueError, match="No credentials provided"):
            auth.build_drive_service(None)

    @pytest.mark.unit
    def test_no_token_no_refresh_raises(self, auth):
        """Credentials with neither access nor refresh token are unusable."""
        mock_creds = MagicMock()
        mock_creds.token = None
        mock_creds.refresh_token = None
        with pytest.raises(ValueError, match="No access token or refresh token"):
            auth.build_drive_service(mock_creds)

    @pytest.mark.unit
    def test_expired_token_refreshes(self, auth):
        """Expired credentials with a refresh token are refreshed before building."""
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = "rt"
        mock_creds.expired = True
        with patch("application.parser.connectors.google_drive.auth.build") as mock_build, \
                patch("google.auth.transport.requests.Request"):
            mock_build.return_value = MagicMock()
            auth.build_drive_service(mock_creds)
            mock_creds.refresh.assert_called_once()

    @pytest.mark.unit
    def test_expired_no_refresh_token_raises(self, auth):
        """Expired credentials without a refresh token cannot be recovered."""
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = None
        mock_creds.expired = True
        with pytest.raises(ValueError, match="No access token or refresh token"):
            auth.build_drive_service(mock_creds)

    @pytest.mark.unit
    def test_refresh_failure_raises(self, auth):
        """A failing refresh is surfaced as a ValueError."""
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = "rt"
        mock_creds.expired = True
        with patch("google.auth.transport.requests.Request"):
            mock_creds.refresh.side_effect = Exception("Cannot refresh")
            with pytest.raises(ValueError, match="Failed to refresh credentials"):
                auth.build_drive_service(mock_creds)

    @pytest.mark.unit
    def test_http_error_raises(self, auth):
        """An HttpError from build() is wrapped with the HTTP status in the message."""
        from googleapiclient.errors import HttpError
        mock_creds = MagicMock()
        mock_creds.token = "at"
        mock_creds.refresh_token = "rt"
        mock_creds.expired = False
        mock_resp = MagicMock()
        mock_resp.status = 500
        with patch("application.parser.connectors.google_drive.auth.build") as mock_build:
            mock_build.side_effect = HttpError(mock_resp, b"error")
            with pytest.raises(ValueError, match="HTTP 500"):
                auth.build_drive_service(mock_creds)
class TestIsTokenExpired:
    """Tests for GoogleDriveAuth.is_token_expired."""

    @staticmethod
    def _utc_offset_iso(**delta):
        """ISO-8601 timestamp offset from the current UTC time by *delta*."""
        moment = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(**delta)
        return moment.isoformat()

    @pytest.mark.unit
    def test_expired_token(self, auth):
        """An expiry one hour in the past reads as expired."""
        assert auth.is_token_expired({"expiry": self._utc_offset_iso(hours=-1)}) is True

    @pytest.mark.unit
    def test_valid_token(self, auth):
        """An expiry one hour in the future reads as still valid."""
        assert auth.is_token_expired({"expiry": self._utc_offset_iso(hours=1)}) is False

    @pytest.mark.unit
    def test_token_within_buffer(self, auth):
        """An expiry only 30s away falls inside the safety buffer and counts as expired."""
        assert auth.is_token_expired({"expiry": self._utc_offset_iso(seconds=30)}) is True

    @pytest.mark.unit
    def test_no_expiry_with_access_token(self, auth):
        """No expiry but an access token present is treated as not expired."""
        assert auth.is_token_expired({"access_token": "at"}) is False

    @pytest.mark.unit
    def test_no_expiry_no_access_token(self, auth):
        """Neither expiry nor access token means the token is unusable."""
        assert auth.is_token_expired({}) is True

    @pytest.mark.unit
    def test_invalid_expiry_format_returns_true(self, auth):
        """An unparseable expiry string is treated conservatively as expired."""
        assert auth.is_token_expired({"expiry": "not-a-date"}) is True

    @pytest.mark.unit
    def test_none_expiry_with_access_token(self, auth):
        """An explicit None expiry with an access token is not expired."""
        assert auth.is_token_expired({"expiry": None, "access_token": "at"}) is False
class TestGetTokenInfoFromSession:
    """Tests for GoogleDriveAuth.get_token_info_from_session (MongoDB-backed lookup)."""

    def _mock_mongo(self, mock_settings, find_one_return):
        """Build a fake Mongo client whose collection.find_one returns *find_one_return*."""
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = find_one_return
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)
        # The client is indexed by DB name, so a plain dict suffices.
        return {mock_settings.MONGO_DB_NAME: mock_db}

    @pytest.mark.unit
    def test_valid_session(self, auth, mock_settings):
        """A stored session with complete token_info is returned, with token_uri filled in."""
        mock_client = self._mock_mongo(mock_settings, {
            "session_token": "st",
            "token_info": {"access_token": "at", "refresh_token": "rt"},
        })
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
                patch("application.core.settings.settings", mock_settings):
            result = auth.get_token_info_from_session("st")
            assert result["access_token"] == "at"
            assert result["token_uri"] == "https://oauth2.googleapis.com/token"

    @pytest.mark.unit
    def test_session_not_found_raises(self, auth, mock_settings):
        """An unknown session token raises a retrieval error."""
        mock_client = self._mock_mongo(mock_settings, None)
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
                patch("application.core.settings.settings", mock_settings):
            with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
                auth.get_token_info_from_session("bad_token")

    @pytest.mark.unit
    def test_session_missing_token_info_raises(self, auth, mock_settings):
        """A session document lacking the token_info field is rejected."""
        mock_client = self._mock_mongo(mock_settings, {"session_token": "st"})
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
                patch("application.core.settings.settings", mock_settings):
            with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
                auth.get_token_info_from_session("st")

    @pytest.mark.unit
    def test_missing_required_fields_raises(self, auth, mock_settings):
        """token_info without a refresh token is considered incomplete."""
        mock_client = self._mock_mongo(mock_settings, {
            "session_token": "st",
            "token_info": {"access_token": "at"},
        })
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
                patch("application.core.settings.settings", mock_settings):
            with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
                auth.get_token_info_from_session("st")

    @pytest.mark.unit
    def test_empty_token_info_raises(self, auth, mock_settings):
        """A None token_info field is rejected like a missing one."""
        mock_client = self._mock_mongo(mock_settings, {
            "session_token": "st",
            "token_info": None,
        })
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
                patch("application.core.settings.settings", mock_settings):
            with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
                auth.get_token_info_from_session("st")
class TestValidateCredentials:
    """Tests for GoogleDriveAuth.validate_credentials (about().get() probe)."""

    @pytest.mark.unit
    def test_valid_credentials(self, auth):
        """A successful about().get() call means the credentials are valid."""
        creds = MagicMock()
        creds.token = "at"
        creds.refresh_token = "rt"
        creds.expired = False
        drive = MagicMock()
        drive.about.return_value.get.return_value.execute.return_value = {"user": {}}
        with patch.object(auth, 'build_drive_service', return_value=drive):
            assert auth.validate_credentials(creds) is True

    @pytest.mark.unit
    def test_http_error_returns_false(self, auth):
        """A 401 from the probe call maps to False rather than raising."""
        from googleapiclient.errors import HttpError
        creds = MagicMock()
        response = MagicMock()
        response.status = 401
        drive = MagicMock()
        drive.about.return_value.get.return_value.execute.side_effect = HttpError(response, b"unauth")
        with patch.object(auth, 'build_drive_service', return_value=drive):
            assert auth.validate_credentials(creds) is False

    @pytest.mark.unit
    def test_general_error_returns_false(self, auth):
        """Any other failure while building the service also maps to False."""
        creds = MagicMock()
        with patch.object(auth, 'build_drive_service', side_effect=Exception("fail")):
            assert auth.validate_credentials(creds) is False

View File

@@ -0,0 +1,851 @@
"""Tests for GoogleDriveLoader."""
from unittest.mock import MagicMock, patch
import pytest
from application.parser.schema.base import Document
def _make_loader(service=None):
    """Construct a GoogleDriveLoader with its auth layer fully mocked.

    GoogleDriveAuth is patched so no real OAuth flow or Drive client
    construction happens; *service* (or a fresh MagicMock) is what
    build_drive_service hands back to the loader.
    """
    with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as auth_cls:
        fake_auth = MagicMock()
        fake_auth.get_token_info_from_session.return_value = {
            "access_token": "at",
            "refresh_token": "rt",
        }
        fake_creds = MagicMock()
        fake_creds.token = "at"
        fake_creds.expired = False
        fake_creds.refresh_token = "rt"
        fake_auth.create_credentials_from_token_info.return_value = fake_creds
        fake_auth.build_drive_service.return_value = service or MagicMock()
        auth_cls.return_value = fake_auth
        # Imported inside the patch context so the loader picks up the mock.
        from application.parser.connectors.google_drive.loader import GoogleDriveLoader
        return GoogleDriveLoader("session_tok")
@pytest.fixture
def mock_service():
    """Stand-in for the googleapiclient Drive service object."""
    fake_drive = MagicMock()
    return fake_drive
@pytest.fixture
def loader(mock_service):
    """A GoogleDriveLoader wired to the mocked Drive service."""
    return _make_loader(service=mock_service)
class TestGoogleDriveLoaderInit:
    """Tests for GoogleDriveLoader construction."""

    @pytest.mark.unit
    def test_init_sets_attributes(self, loader):
        """The constructor stores the session token, credentials, and service."""
        assert loader.session_token == "session_tok"
        assert loader.credentials is not None
        assert loader.service is not None
        assert loader.next_page_token is None

    @pytest.mark.unit
    def test_init_service_failure_sets_none(self):
        """If building the Drive service fails, the loader keeps service=None instead of raising."""
        with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as MockAuth:
            mock_auth = MagicMock()
            mock_auth.get_token_info_from_session.return_value = {
                "access_token": "at", "refresh_token": "rt"
            }
            mock_creds = MagicMock()
            mock_creds.token = "at"
            mock_auth.create_credentials_from_token_info.return_value = mock_creds
            mock_auth.build_drive_service.side_effect = Exception("service fail")
            MockAuth.return_value = mock_auth
            from application.parser.connectors.google_drive.loader import GoogleDriveLoader
            loader = GoogleDriveLoader("st")
            assert loader.service is None
class TestProcessFile:
    """Tests for GoogleDriveLoader._process_file (metadata -> Document)."""

    @pytest.mark.unit
    def test_supported_mime_type_with_content(self, loader):
        """A supported mime type with load_content=True yields a Document with text and metadata."""
        loader._download_file_content = MagicMock(return_value="file content")
        metadata = {
            "id": "f1",
            "name": "test.pdf",
            "mimeType": "application/pdf",
            "size": 1024,
            "createdTime": "2025-01-01T00:00:00Z",
            "modifiedTime": "2025-01-02T00:00:00Z",
            "parents": ["root"],
        }
        doc = loader._process_file(metadata, load_content=True)
        assert doc is not None
        assert doc.text == "file content"
        assert doc.doc_id == "f1"
        assert doc.extra_info["file_name"] == "test.pdf"
        assert doc.extra_info["source"] == "google_drive"

    @pytest.mark.unit
    def test_supported_mime_type_no_content(self, loader):
        """With load_content=False the Document is created with empty text."""
        metadata = {
            "id": "f1",
            "name": "test.pdf",
            "mimeType": "application/pdf",
        }
        doc = loader._process_file(metadata, load_content=False)
        assert doc is not None
        assert doc.text == ""
        assert doc.doc_id == "f1"

    @pytest.mark.unit
    def test_unsupported_mime_type_returns_none(self, loader):
        """Unsupported mime types are skipped (None returned)."""
        metadata = {
            "id": "f1",
            "name": "test.zip",
            "mimeType": "application/zip",
        }
        doc = loader._process_file(metadata)
        assert doc is None

    @pytest.mark.unit
    def test_download_failure_returns_none(self, loader):
        """A failed content download (None) means no Document is produced."""
        loader._download_file_content = MagicMock(return_value=None)
        metadata = {
            "id": "f1",
            "name": "test.txt",
            "mimeType": "text/plain",
        }
        doc = loader._process_file(metadata, load_content=True)
        assert doc is None

    @pytest.mark.unit
    def test_exception_returns_none(self, loader):
        """Exceptions during download are swallowed and None is returned."""
        loader._download_file_content = MagicMock(side_effect=Exception("fail"))
        metadata = {
            "id": "f1",
            "name": "test.txt",
            "mimeType": "text/plain",
        }
        doc = loader._process_file(metadata, load_content=True)
        assert doc is None
class TestLoadData:
    """Tests for GoogleDriveLoader.load_data (file-id mode and browse mode)."""

    @pytest.mark.unit
    def test_load_specific_files(self, loader):
        """Explicit file_ids are loaded individually via _load_file_by_id."""
        doc = Document(text="content", doc_id="f1", extra_info={"file_name": "test.pdf"})
        loader._load_file_by_id = MagicMock(return_value=doc)
        result = loader.load_data({"file_ids": ["f1"]})
        assert len(result) == 1
        assert result[0].doc_id == "f1"

    @pytest.mark.unit
    def test_load_files_with_search_filter(self, loader):
        """search_query filters loaded documents by file name."""
        doc = Document(text="c", doc_id="f1", extra_info={"file_name": "report.pdf"})
        loader._load_file_by_id = MagicMock(return_value=doc)
        result = loader.load_data({"file_ids": ["f1"], "search_query": "report"})
        assert len(result) == 1
        result = loader.load_data({"file_ids": ["f1"], "search_query": "other"})
        assert len(result) == 0

    @pytest.mark.unit
    def test_load_files_error_continues(self, loader):
        """Per-file load errors are skipped; remaining ids are still attempted."""
        loader._load_file_by_id = MagicMock(side_effect=Exception("fail"))
        result = loader.load_data({"file_ids": ["f1", "f2"]})
        assert len(result) == 0

    @pytest.mark.unit
    def test_browse_mode_uses_list_items(self, loader):
        """Without file_ids, load_data delegates to _list_items_in_parent."""
        docs = [Document(text="", doc_id="f1", extra_info={})]
        loader._list_items_in_parent = MagicMock(return_value=docs)
        result = loader.load_data({"folder_id": "folder1", "limit": 50})
        loader._list_items_in_parent.assert_called_once_with(
            "folder1", limit=50, load_content=True, page_token=None, search_query=None
        )
        assert len(result) == 1

    @pytest.mark.unit
    def test_browse_mode_defaults_to_root(self, loader):
        """An empty config browses "root" with limit=100."""
        loader._list_items_in_parent = MagicMock(return_value=[])
        loader.load_data({})
        loader._list_items_in_parent.assert_called_once_with(
            "root", limit=100, load_content=True, page_token=None, search_query=None
        )

    @pytest.mark.unit
    def test_session_token_mismatch_logs_warning(self, loader):
        """A config session_token differing from the loader's only logs a warning."""
        loader._list_items_in_parent = MagicMock(return_value=[])
        loader.load_data({"session_token": "different_token"})
        # Should not raise, just logs

    @pytest.mark.unit
    def test_load_data_with_list_only(self, loader):
        """list_only=True disables content loading during browsing."""
        loader._list_items_in_parent = MagicMock(return_value=[])
        loader.load_data({"list_only": True})
        loader._list_items_in_parent.assert_called_once_with(
            "root", limit=100, load_content=False, page_token=None, search_query=None
        )

    @pytest.mark.unit
    def test_load_data_with_page_token(self, loader):
        """A page_token in the config is forwarded for pagination."""
        loader._list_items_in_parent = MagicMock(return_value=[])
        loader.load_data({"page_token": "next_page"})
        loader._list_items_in_parent.assert_called_once_with(
            "root", limit=100, load_content=True, page_token="next_page", search_query=None
        )

    @pytest.mark.unit
    def test_credential_refresh_retry(self, loader):
        """When _load_file_by_id returns None and _credential_refreshed is set, retry."""
        loader._credential_refreshed = True
        call_count = [0]
        def side_effect(fid, load_content=True):
            call_count[0] += 1
            if call_count[0] == 1:
                return None
            return Document(text="c", doc_id=fid, extra_info={"file_name": "test.pdf"})
        loader._load_file_by_id = MagicMock(side_effect=side_effect)
        result = loader.load_data({"file_ids": ["f1"]})
        assert len(result) == 1

    @pytest.mark.unit
    def test_outer_exception_raises(self, loader):
        """The outer try/except in load_data re-raises unexpected errors."""
        loader._list_items_in_parent = MagicMock(side_effect=RuntimeError("unexpected"))
        with pytest.raises(RuntimeError, match="unexpected"):
            loader.load_data({})
class TestLoadFileById:
    """Tests for GoogleDriveLoader._load_file_by_id (single-file fetch + auth recovery)."""

    @pytest.mark.unit
    def test_loads_file_metadata_and_processes(self, loader, mock_service):
        """File metadata is fetched and handed to _process_file."""
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "id": "f1", "name": "test.txt", "mimeType": "text/plain"
        }
        loader._process_file = MagicMock(return_value=Document(text="t", doc_id="f1", extra_info={}))
        doc = loader._load_file_by_id("f1")
        assert doc is not None

    @pytest.mark.unit
    def test_http_401_refreshes_credentials(self, loader, mock_service):
        """A 401 triggers a credentials refresh and returns None for this attempt."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
        with patch("google.auth.transport.requests.Request"):
            result = loader._load_file_by_id("f1")
            assert result is None
            loader.credentials.refresh.assert_called_once()

    @pytest.mark.unit
    def test_http_401_no_refresh_token_raises(self, loader, mock_service):
        """A 401 without a refresh token is unrecoverable and raises."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._load_file_by_id("f1")

    @pytest.mark.unit
    def test_http_500_returns_none(self, loader, mock_service):
        """Non-auth HTTP errors (e.g. 500) degrade to None."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 500
        mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"server error")
        result = loader._load_file_by_id("f1")
        assert result is None

    @pytest.mark.unit
    def test_general_exception_returns_none(self, loader, mock_service):
        """Unexpected exceptions degrade to None."""
        mock_service.files.return_value.get.return_value.execute.side_effect = Exception("fail")
        result = loader._load_file_by_id("f1")
        assert result is None

    @pytest.mark.unit
    def test_ensure_service_called(self, loader):
        """If the loader has no service, one is (re)built via the auth helper."""
        loader.service = None
        loader.auth.build_drive_service.return_value = MagicMock()
        loader.auth.build_drive_service.return_value.files.return_value.get.return_value.execute.return_value = {
            "id": "f1", "name": "t.txt", "mimeType": "text/plain"
        }
        loader._process_file = MagicMock(return_value=None)
        loader._load_file_by_id("f1")
        loader.auth.build_drive_service.assert_called()

    @pytest.mark.unit
    def test_http_401_refresh_failure_raises(self, loader, mock_service):
        """A 401 whose follow-up refresh also fails raises a ValueError."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("refresh broke")
            with pytest.raises(ValueError, match="could not be refreshed"):
                loader._load_file_by_id("f1")
class TestListItemsInParent:
    """Tests for GoogleDriveLoader._list_items_in_parent (browse/pagination logic)."""

    @pytest.mark.unit
    def test_lists_files_and_folders(self, loader, mock_service):
        """Folders are returned as folder documents alongside processed files."""
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": "folder1", "name": "Docs", "mimeType": "application/vnd.google-apps.folder"},
                {"id": "file1", "name": "test.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._process_file = MagicMock(return_value=Document(text="", doc_id="file1", extra_info={}))
        docs = loader._list_items_in_parent("root", limit=100)
        assert len(docs) == 2
        assert docs[0].extra_info.get("is_folder") is True

    @pytest.mark.unit
    def test_search_query_modifies_drive_query(self, loader, mock_service):
        """search_query is translated into a Drive `name contains` clause."""
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [], "nextPageToken": None
        }
        loader._list_items_in_parent("root", search_query="report")
        call_args = mock_service.files.return_value.list.call_args
        assert "name contains 'report'" in call_args.kwargs.get("q", "")

    @pytest.mark.unit
    def test_limit_stops_early(self, loader, mock_service):
        """Only `limit` items are returned even when the page has more."""
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": f"f{i}", "name": f"file{i}.txt", "mimeType": "text/plain"} for i in range(10)
            ],
            "nextPageToken": "next",
        }
        loader._process_file = MagicMock(side_effect=lambda m, **kw: Document(text="", doc_id=m["id"], extra_info={}))
        docs = loader._list_items_in_parent("root", limit=3)
        assert len(docs) == 3

    @pytest.mark.unit
    def test_pagination_stores_next_page_token(self, loader, mock_service):
        """The API's nextPageToken is kept on the loader for follow-up pages."""
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [{"id": "f1", "name": "f.txt", "mimeType": "text/plain"}],
            "nextPageToken": "page2",
        }
        loader._process_file = MagicMock(return_value=Document(text="", doc_id="f1", extra_info={}))
        loader._list_items_in_parent("root", limit=1)
        assert loader.next_page_token == "page2"

    @pytest.mark.unit
    def test_limit_breaks_loop_when_remaining_zero(self, loader, mock_service):
        """When limit is exactly met after first page, the while loop breaks via remaining==0."""
        call_count = [0]
        def list_side_effect(**kw):
            call_count[0] += 1
            mock = MagicMock()
            if call_count[0] == 1:
                mock.execute.return_value = {
                    "files": [
                        {"id": "f1", "name": "f1.txt", "mimeType": "text/plain"},
                        {"id": "f2", "name": "f2.txt", "mimeType": "text/plain"},
                    ],
                    "nextPageToken": "page2",
                }
            else:
                # Should not reach here if break works correctly
                mock.execute.return_value = {"files": [], "nextPageToken": None}
            return mock
        mock_service.files.return_value.list.side_effect = list_side_effect
        loader._process_file = MagicMock(side_effect=lambda m, **kw: Document(text="", doc_id=m["id"], extra_info={}))
        # limit=2, first page returns exactly 2 files with nextPageToken
        # Since items don't hit the inner limit check (it checks after each item),
        # it should loop back and break at remaining==0
        docs = loader._list_items_in_parent("root", limit=2)
        assert len(docs) == 2

    @pytest.mark.unit
    def test_exception_returns_partial_results(self, loader, mock_service):
        """An API failure yields whatever was collected so far (here: nothing)."""
        mock_service.files.return_value.list.return_value.execute.side_effect = Exception("api error")
        docs = loader._list_items_in_parent("root")
        assert docs == []
class TestDownloadFileContent:
    """Tests for GoogleDriveLoader._download_file_content.

    Covers the regular download path, Google Workspace export, token
    refresh/expiry handling, and error degradation to None.
    """

    @pytest.mark.unit
    def test_download_regular_file(self, loader, mock_service):
        """A regular file is fetched via get_media and decoded to text."""
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            # Two chunks: not-done, then done.
            mock_dl.next_chunk.side_effect = [(None, False), (None, True)]
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = b"file content"
                MockBytesIO.return_value = mock_bio
                content = loader._download_file_content("f1", "text/plain")
                assert content == "file content"

    @pytest.mark.unit
    def test_download_google_workspace_file_uses_export(self, loader, mock_service):
        """Google Workspace mime types are exported (export_media) rather than downloaded."""
        mock_request = MagicMock()
        mock_service.files.return_value.export_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = b"exported"
                MockBytesIO.return_value = mock_bio
                content = loader._download_file_content("f1", "application/vnd.google-apps.document")
                mock_service.files.return_value.export_media.assert_called_once()
                assert content == "exported"

    @pytest.mark.unit
    def test_unicode_decode_error_returns_none(self, loader, mock_service):
        """Undecodable bytes degrade to None instead of raising."""
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = MagicMock()
                mock_bio.getvalue.return_value.decode.side_effect = UnicodeDecodeError("utf-8", b"", 0, 1, "bad")
                MockBytesIO.return_value = mock_bio
                result = loader._download_file_content("f1", "application/pdf")
                assert result is None

    @pytest.mark.unit
    def test_no_access_token_with_refresh(self, loader):
        """A missing access token is recovered via refresh before downloading."""
        loader.credentials.token = None
        loader.credentials.refresh_token = "rt"
        loader.credentials.expired = False
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()
            # After refresh, set token
            def set_token(req):
                loader.credentials.token = "new_at"
            loader.credentials.refresh.side_effect = set_token
            loader._ensure_service = MagicMock()
            loader.service.files.return_value.get_media.return_value = MagicMock()
            with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
                mock_dl = MagicMock()
                mock_dl.next_chunk.return_value = (None, True)
                MockDownload.return_value = mock_dl
                with patch("io.BytesIO") as MockBytesIO:
                    mock_bio = MagicMock()
                    mock_bio.getvalue.return_value = b"data"
                    MockBytesIO.return_value = mock_bio
                    content = loader._download_file_content("f1", "text/plain")
                    assert content == "data"

    @pytest.mark.unit
    def test_no_access_token_no_refresh_raises(self, loader):
        """No access token and no refresh token is unrecoverable."""
        loader.credentials.token = None
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_no_token_refresh_fails_raises(self, loader):
        """No token and a failing refresh raises with an invalid-token message."""
        loader.credentials.token = None
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("fail")
            with pytest.raises(ValueError, match="missing or invalid refresh_token"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_expired_credentials_refresh(self, loader):
        """Expired credentials are refreshed and the download then succeeds."""
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()
            def fix_expired(req):
                loader.credentials.expired = False
            loader.credentials.refresh.side_effect = fix_expired
            loader._ensure_service = MagicMock()
            loader.service.files.return_value.get_media.return_value = MagicMock()
            with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
                mock_dl = MagicMock()
                mock_dl.next_chunk.return_value = (None, True)
                MockDownload.return_value = mock_dl
                with patch("io.BytesIO") as MockBytesIO:
                    mock_bio = MagicMock()
                    mock_bio.getvalue.return_value = b"ok"
                    MockBytesIO.return_value = mock_bio
                    content = loader._download_file_content("f1", "text/plain")
                    assert content == "ok"

    @pytest.mark.unit
    def test_expired_no_refresh_token_raises(self, loader):
        """Expired credentials with no refresh token cannot be recovered."""
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_expired_refresh_fails_raises(self, loader):
        """Expired credentials whose refresh fails raise an expired-credentials error."""
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("fail")
            with pytest.raises(ValueError, match="expired credentials"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_error_during_download_chunk_returns_none(self, loader, mock_service):
        """A non-auth HttpError while reading chunks degrades to None."""
        from googleapiclient.errors import HttpError
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request
        resp = MagicMock()
        resp.status = 500
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.side_effect = HttpError(resp, b"server error")
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                MockBytesIO.return_value = MagicMock()
                result = loader._download_file_content("f1", "text/plain")
                assert result is None

    @pytest.mark.unit
    def test_general_error_during_download_chunk_returns_none(self, loader, mock_service):
        """A generic error while reading chunks degrades to None."""
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.side_effect = RuntimeError("chunk fail")
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                MockBytesIO.return_value = MagicMock()
                result = loader._download_file_content("f1", "text/plain")
                assert result is None

    @pytest.mark.unit
    def test_http_401_during_download_refreshes_and_returns_none(self, loader, mock_service):
        """A 401 at download start refreshes credentials, flags the retry marker, returns None."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()
            loader._ensure_service = MagicMock()
            result = loader._download_file_content("f1", "text/plain")
            assert result is None
            assert loader._credential_refreshed is True

    @pytest.mark.unit
    def test_http_401_no_refresh_token_raises(self, loader, mock_service):
        """A 401 with no refresh token available raises."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_401_refresh_fails_raises(self, loader, mock_service):
        """A 401 followed by a failed refresh raises a could-not-be-refreshed error."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("refresh fail")
            with pytest.raises(ValueError, match="could not be refreshed"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_500_returns_none(self, loader, mock_service):
        """A 500 when requesting the media degrades to None."""
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 500
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"error")
        result = loader._download_file_content("f1", "text/plain")
        assert result is None

    @pytest.mark.unit
    def test_general_exception_returns_none(self, loader, mock_service):
        """Any unexpected exception degrades to None."""
        mock_service.files.return_value.get_media.side_effect = RuntimeError("fail")
        result = loader._download_file_content("f1", "text/plain")
        assert result is None
class TestDownloadToDirectory:
    """Tests for GoogleDriveLoader.download_to_directory (bulk export to disk)."""

    @pytest.mark.unit
    def test_download_files(self, loader, tmp_path):
        """Each file id is downloaded and counted in the summary dict."""
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": ["f1", "f2"]})
        assert result["files_downloaded"] == 2
        assert result["source_type"] == "google_drive"
        assert result["empty_result"] is False

    @pytest.mark.unit
    def test_download_files_string_id(self, loader, tmp_path):
        """A single string file id is accepted as well as a list."""
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": "single_id"})
        assert result["files_downloaded"] == 1

    @pytest.mark.unit
    def test_download_folders(self, loader, tmp_path, mock_service):
        """Folder ids are resolved to names and downloaded recursively."""
        mock_service.files.return_value.get.return_value.execute.return_value = {"name": "MyFolder"}
        loader._download_folder_recursive = MagicMock(return_value=3)
        result = loader.download_to_directory(str(tmp_path), {"folder_ids": ["folder1"]})
        assert result["files_downloaded"] == 3

    @pytest.mark.unit
    def test_download_folders_string_id(self, loader, tmp_path, mock_service):
        """A single string folder id is accepted as well as a list."""
        mock_service.files.return_value.get.return_value.execute.return_value = {"name": "F"}
        loader._download_folder_recursive = MagicMock(return_value=1)
        result = loader.download_to_directory(str(tmp_path), {"folder_ids": "single_folder"})
        assert result["files_downloaded"] == 1

    @pytest.mark.unit
    def test_no_ids_returns_error(self, loader, tmp_path):
        """A config with neither file_ids nor folder_ids yields an error/empty result."""
        result = loader.download_to_directory(str(tmp_path), {})
        assert "error" in result
        assert result["empty_result"] is True

    @pytest.mark.unit
    def test_none_config_uses_empty(self, loader, tmp_path):
        """Omitting the config behaves like an empty config."""
        result = loader.download_to_directory(str(tmp_path))
        assert "error" in result

    @pytest.mark.unit
    def test_folder_error_continues(self, loader, tmp_path, mock_service):
        """A failing folder lookup is skipped; the file ids are still downloaded."""
        mock_service.files.return_value.get.return_value.execute.side_effect = Exception("fail")
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": ["f1"], "folder_ids": ["bad_folder"]})
        assert result["files_downloaded"] == 1
class TestDownloadSingleFile:
    """Tests for GoogleDriveLoader._download_single_file (one file to a directory)."""

    @pytest.mark.unit
    def test_downloads_supported_file(self, loader, mock_service, tmp_path):
        """A supported mime type is downloaded via get_media and reported True."""
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "test.txt", "mimeType": "text/plain"
        }
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDl:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDl.return_value = mock_dl
            result = loader._download_single_file("f1", str(tmp_path))
            assert result is True

    @pytest.mark.unit
    def test_unsupported_mime_returns_false(self, loader, mock_service, tmp_path):
        """Unsupported mime types are skipped and reported False."""
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "test.zip", "mimeType": "application/zip"
        }
        result = loader._download_single_file("f1", str(tmp_path))
        assert result is False

    @pytest.mark.unit
    def test_google_workspace_file_export(self, loader, mock_service, tmp_path):
        """Google Workspace documents are exported (export_media) instead of downloaded."""
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "doc", "mimeType": "application/vnd.google-apps.document"
        }
        mock_request = MagicMock()
        mock_service.files.return_value.export_media.return_value = mock_request
        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDl:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDl.return_value = mock_dl
            result = loader._download_single_file("f1", str(tmp_path))
            assert result is True
            mock_service.files.return_value.export_media.assert_called_once()
class TestDownloadFolderRecursive:
    """Behaviour of the loader's _download_folder_recursive helper."""

    @pytest.mark.unit
    def test_downloads_files_in_folder(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [{"id": "f1", "name": "file1.txt", "mimeType": "text/plain"}],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(return_value=True)
        assert loader._download_folder_recursive("folder1", str(tmp_path)) == 1

    @pytest.mark.unit
    def test_recurses_into_subfolders(self, loader, mock_service, tmp_path):
        """First listing yields a subfolder plus a file; later listings yield
        the subfolder's single file."""
        served_root = []

        def fake_list(**_kwargs):
            listing = MagicMock()
            if not served_root:
                served_root.append(True)
                listing.execute.return_value = {
                    "files": [
                        {
                            "id": "sub1",
                            "name": "subfolder",
                            "mimeType": "application/vnd.google-apps.folder",
                        },
                        {"id": "f1", "name": "file1.txt", "mimeType": "text/plain"},
                    ],
                    "nextPageToken": None,
                }
            else:
                listing.execute.return_value = {
                    "files": [{"id": "f2", "name": "file2.txt", "mimeType": "text/plain"}],
                    "nextPageToken": None,
                }
            return listing

        mock_service.files.return_value.list.side_effect = fake_list
        loader._download_single_file = MagicMock(return_value=True)
        assert loader._download_folder_recursive("folder1", str(tmp_path), recursive=True) == 2

    @pytest.mark.unit
    def test_non_recursive_skips_subfolders(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {
                    "id": "sub1",
                    "name": "subfolder",
                    "mimeType": "application/vnd.google-apps.folder",
                },
                {"id": "f1", "name": "file1.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(return_value=True)
        assert loader._download_folder_recursive("folder1", str(tmp_path), recursive=False) == 1

    @pytest.mark.unit
    def test_download_failure_continues(self, loader, mock_service, tmp_path):
        """One failed file must not abort the rest of the folder."""
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": "f1", "name": "fail.txt", "mimeType": "text/plain"},
                {"id": "f2", "name": "ok.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(side_effect=[False, True])
        assert loader._download_folder_recursive("folder1", str(tmp_path)) == 1

    @pytest.mark.unit
    def test_exception_returns_partial_count(self, loader, mock_service, tmp_path):
        """A listing failure returns whatever was counted so far (zero here)."""
        mock_service.files.return_value.list.return_value.execute.side_effect = Exception("fail")
        assert loader._download_folder_recursive("folder1", str(tmp_path)) == 0
class TestDownloadFileToDirectory:
    """_download_file_to_directory is a thin guard around _download_single_file."""

    @pytest.mark.unit
    def test_delegates_to_download_single(self, loader, tmp_path):
        loader._download_single_file = MagicMock(return_value=True)
        assert loader._download_file_to_directory("f1", str(tmp_path)) is True

    @pytest.mark.unit
    def test_exception_returns_false(self, loader, tmp_path):
        """Any error from the delegate is swallowed and reported as False."""
        loader._download_single_file = MagicMock(side_effect=Exception("fail"))
        assert loader._download_file_to_directory("f1", str(tmp_path)) is False
class TestDownloadFolderContents:
    """_download_folder_contents wraps _download_folder_recursive with a guard."""

    @pytest.mark.unit
    def test_delegates_to_recursive(self, loader, tmp_path):
        loader._download_folder_recursive = MagicMock(return_value=5)
        assert loader._download_folder_contents("folder1", str(tmp_path)) == 5

    @pytest.mark.unit
    def test_exception_returns_zero(self, loader, tmp_path):
        # Force the lazy service build to fail inside the wrapper.
        loader.service = None
        loader.auth.build_drive_service.side_effect = Exception("fail")
        assert loader._download_folder_contents("folder1", str(tmp_path)) == 0
class TestEnsureService:
    """Lazy construction of the Drive service client."""

    @pytest.mark.unit
    def test_builds_service_when_none(self, loader):
        loader.service = None
        built = MagicMock()
        loader.auth.build_drive_service.return_value = built
        loader._ensure_service()
        assert loader.service == built

    @pytest.mark.unit
    def test_noop_when_service_exists(self, loader, mock_service):
        """An already-built service is kept as-is."""
        loader._ensure_service()
        assert loader.service == mock_service

    @pytest.mark.unit
    def test_build_failure_raises(self, loader):
        loader.service = None
        loader.auth.build_drive_service.side_effect = Exception("fail")
        with pytest.raises(ValueError, match="Cannot access Google Drive"):
            loader._ensure_service()
class TestGetExtensionForMimeType:
    """Mapping of MIME types to on-disk file extensions."""

    @pytest.mark.unit
    def test_known_mime_types(self, loader):
        expected = {
            "application/pdf": ".pdf",
            "text/plain": ".txt",
            "text/html": ".html",
            "text/markdown": ".md",
        }
        for mime, extension in expected.items():
            assert loader._get_extension_for_mime_type(mime) == extension

    @pytest.mark.unit
    def test_unknown_returns_bin(self, loader):
        """Anything unrecognised falls back to a generic .bin extension."""
        assert loader._get_extension_for_mime_type("application/unknown") == ".bin"

View File

@@ -0,0 +1,326 @@
"""Tests for SharePointAuth."""
import datetime
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def mock_settings():
    """Settings stub carrying the Microsoft credentials SharePointAuth reads."""
    stub = MagicMock()
    stub.MICROSOFT_CLIENT_ID = "ms-client-id"
    stub.MICROSOFT_CLIENT_SECRET = "ms-client-secret"
    stub.MICROSOFT_TENANT_ID = "tenant-id-123"
    stub.CONNECTOR_REDIRECT_BASE_URI = "https://redirect.example.com/callback"
    stub.MONGO_DB_NAME = "test_db"
    # Remove MICROSOFT_AUTHORITY so getattr(..., default) takes the fallback path.
    del stub.MICROSOFT_AUTHORITY
    return stub
@pytest.fixture
def mock_msal():
    """Patch MSAL's ConfidentialClientApplication; yield the instance mock."""
    target = "application.parser.connectors.share_point.auth.ConfidentialClientApplication"
    with patch(target) as app_cls:
        instance = MagicMock()
        app_cls.return_value = instance
        yield instance
@pytest.fixture
def auth(mock_settings, mock_msal):
    """Construct a SharePointAuth against the stubbed settings and MSAL app."""
    with patch("application.parser.connectors.share_point.auth.settings", mock_settings):
        from application.parser.connectors.share_point.auth import SharePointAuth

        return SharePointAuth()
class TestSharePointAuthInit:
    """Constructor wiring and required-credential validation."""

    @pytest.mark.unit
    def test_init_sets_attributes(self, auth, mock_settings):
        assert auth.client_id == "ms-client-id"
        assert auth.client_secret == "ms-client-secret"
        assert auth.redirect_uri == "https://redirect.example.com/callback"
        assert auth.tenant_id == "tenant-id-123"

    @pytest.mark.unit
    def test_missing_client_id_raises(self, mock_settings):
        mock_settings.MICROSOFT_CLIENT_ID = None
        with patch("application.parser.connectors.share_point.auth.settings", mock_settings), patch(
            "application.parser.connectors.share_point.auth.ConfidentialClientApplication"
        ):
            from application.parser.connectors.share_point.auth import SharePointAuth

            with pytest.raises(ValueError, match="MICROSOFT_CLIENT_ID"):
                SharePointAuth()

    @pytest.mark.unit
    def test_missing_client_secret_raises(self, mock_settings):
        mock_settings.MICROSOFT_CLIENT_SECRET = None
        with patch("application.parser.connectors.share_point.auth.settings", mock_settings), patch(
            "application.parser.connectors.share_point.auth.ConfidentialClientApplication"
        ):
            from application.parser.connectors.share_point.auth import SharePointAuth

            with pytest.raises(ValueError, match="MICROSOFT_CLIENT_SECRET"):
                SharePointAuth()

    @pytest.mark.unit
    def test_default_authority(self, auth):
        """With MICROSOFT_AUTHORITY unset the default login host is used."""
        assert "login.microsoftonline.com" in auth.authority
        assert "tenant-id-123" in auth.authority
class TestGetAuthorizationUrl:
    """The OAuth authorization URL comes straight from MSAL."""

    @pytest.mark.unit
    def test_returns_url(self, auth, mock_msal):
        expected = "https://login.microsoftonline.com/auth?state=s1"
        mock_msal.get_authorization_request_url.return_value = expected
        assert auth.get_authorization_url(state="s1") == expected
        mock_msal.get_authorization_request_url.assert_called_once()
class TestExchangeCodeForTokens:
    """Authorization-code exchange and MSAL error propagation."""

    @pytest.mark.unit
    def test_successful_exchange(self, auth, mock_msal):
        msal_response = {
            "access_token": "at",
            "refresh_token": "rt",
            "scope": ["Files.Read"],
            "id_token_claims": {
                "iss": "https://login.microsoftonline.com/tid/v2.0",
                "exp": 1700000000,
                "tid": "work-tenant-id",
                "name": "Test User",
                "preferred_username": "test@example.com",
            },
        }
        mock_msal.acquire_token_by_authorization_code.return_value = msal_response
        tokens = auth.exchange_code_for_tokens("auth_code")
        assert tokens["access_token"] == "at"
        assert tokens["refresh_token"] == "rt"
        assert tokens["user_info"]["name"] == "Test User"
        assert tokens["user_info"]["email"] == "test@example.com"
        # A non-consumer tenant id means shared content is available.
        assert tokens["allows_shared_content"] is True

    @pytest.mark.unit
    def test_error_in_response_raises(self, auth, mock_msal):
        mock_msal.acquire_token_by_authorization_code.return_value = {
            "error": "invalid_grant",
            "error_description": "Code expired",
        }
        with pytest.raises(ValueError, match="Code expired"):
            auth.exchange_code_for_tokens("bad_code")
class TestRefreshAccessToken:
    """Refresh-token flow and MSAL error propagation."""

    @pytest.mark.unit
    def test_successful_refresh(self, auth, mock_msal):
        msal_response = {
            "access_token": "new_at",
            "refresh_token": "new_rt",
            "scope": ["Files.Read"],
            "id_token_claims": {
                "iss": "https://issuer",
                "exp": 1700001000,
                "tid": "work-tid",
                "name": "User",
                "preferred_username": "u@example.com",
            },
        }
        mock_msal.acquire_token_by_refresh_token.return_value = msal_response
        refreshed = auth.refresh_access_token("old_rt")
        assert refreshed["access_token"] == "new_at"
        assert refreshed["refresh_token"] == "new_rt"

    @pytest.mark.unit
    def test_error_in_response_raises(self, auth, mock_msal):
        mock_msal.acquire_token_by_refresh_token.return_value = {
            "error": "invalid_grant",
            "error_description": "Token revoked",
        }
        with pytest.raises(ValueError, match="Token revoked"):
            auth.refresh_access_token("bad_rt")
class TestIsTokenExpired:
    """Expiry checks, including the safety buffer and degenerate inputs."""

    @staticmethod
    def _expiry(delta):
        """Unix timestamp offset from now by delta.

        NOTE(review): uses naive local time like the original tests —
        presumably the implementation compares against the same clock; confirm.
        """
        return int((datetime.datetime.now() + delta).timestamp())

    @pytest.mark.unit
    def test_expired_token(self, auth):
        stamp = self._expiry(datetime.timedelta(hours=-1))
        assert auth.is_token_expired({"expiry": stamp}) is True

    @pytest.mark.unit
    def test_valid_token(self, auth):
        stamp = self._expiry(datetime.timedelta(hours=1))
        assert auth.is_token_expired({"expiry": stamp}) is False

    @pytest.mark.unit
    def test_within_buffer(self, auth):
        # 30 seconds out is inside the refresh buffer, so it counts as expired.
        stamp = self._expiry(datetime.timedelta(seconds=30))
        assert auth.is_token_expired({"expiry": stamp}) is True

    @pytest.mark.unit
    def test_none_token_info(self, auth):
        assert auth.is_token_expired(None) is True

    @pytest.mark.unit
    def test_missing_expiry(self, auth):
        assert auth.is_token_expired({}) is True

    @pytest.mark.unit
    def test_none_expiry(self, auth):
        assert auth.is_token_expired({"expiry": None}) is True
class TestSanitizeTokenInfo:
    """sanitize_token_info keeps credentials and normalises optional flags."""

    @staticmethod
    def _token(**extra):
        """Baseline token dict, optionally extended with extra keys."""
        base = {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://uri",
            "expiry": 123,
        }
        base.update(extra)
        return base

    @pytest.mark.unit
    def test_includes_allows_shared_content(self, auth):
        sanitized = auth.sanitize_token_info(self._token(allows_shared_content=True))
        assert sanitized["allows_shared_content"] is True
        assert sanitized["access_token"] == "at"

    @pytest.mark.unit
    def test_defaults_allows_shared_content_to_false(self, auth):
        sanitized = auth.sanitize_token_info(self._token())
        assert sanitized["allows_shared_content"] is False

    @pytest.mark.unit
    def test_with_extra_fields(self, auth):
        """Extra keyword arguments are carried through to the result."""
        sanitized = auth.sanitize_token_info(
            self._token(allows_shared_content=True), custom="val"
        )
        assert sanitized["custom"] == "val"
class TestAllowsSharedContent:
    """Shared content is only available to work/school (tenant) accounts."""

    @pytest.mark.unit
    def test_work_account_returns_true(self, auth):
        assert auth._allows_shared_content({"tid": "some-work-tenant-id"}) is True

    @pytest.mark.unit
    def test_personal_account_returns_false(self, auth):
        # Well-known tenant id used by Microsoft consumer (personal) accounts.
        personal = {"tid": "9188040d-6c67-4c5b-b112-36a304b66dad"}
        assert auth._allows_shared_content(personal) is False

    @pytest.mark.unit
    def test_empty_tid_returns_false(self, auth):
        for claims in ({"tid": ""}, {}):
            assert auth._allows_shared_content(claims) is False
class TestMapTokenResponse:
    """Normalisation of a raw MSAL result into the stored token shape."""

    @pytest.mark.unit
    def test_maps_all_fields(self, auth):
        raw = {
            "access_token": "at",
            "refresh_token": "rt",
            "scope": ["Files.Read"],
            "id_token_claims": {
                "iss": "https://issuer",
                "exp": 1700000000,
                "tid": "work-tid",
                "name": "User Name",
                "preferred_username": "user@example.com",
            },
        }
        mapped = auth.map_token_response(raw)
        assert mapped["access_token"] == "at"
        assert mapped["refresh_token"] == "rt"
        assert mapped["token_uri"] == "https://issuer"
        assert mapped["scopes"] == ["Files.Read"]
        assert mapped["expiry"] == 1700000000
        assert mapped["user_info"]["name"] == "User Name"
        assert mapped["user_info"]["email"] == "user@example.com"
        assert mapped["allows_shared_content"] is True

    @pytest.mark.unit
    def test_missing_claims_uses_defaults(self, auth):
        """Absent id_token_claims degrade to None/False defaults."""
        mapped = auth.map_token_response({"access_token": "at", "refresh_token": "rt"})
        assert mapped["token_uri"] is None
        assert mapped["expiry"] is None
        assert mapped["user_info"]["name"] is None
        assert mapped["allows_shared_content"] is False
class TestGetTokenInfoFromSession:
    """Session-token lookup against the Mongo-backed session store."""

    @staticmethod
    def _mongo_client(settings_stub, stored_doc):
        """Fake Mongo client whose collection.find_one returns stored_doc."""
        collection = MagicMock()
        collection.find_one.return_value = stored_doc
        database = MagicMock()
        database.__getitem__ = MagicMock(return_value=collection)
        return {settings_stub.MONGO_DB_NAME: database}

    def _assert_lookup_fails(self, auth, mock_settings, stored_doc, session="st"):
        """Assert the lookup raises the generic retrieval error for stored_doc."""
        client = self._mongo_client(mock_settings, stored_doc)
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=client), patch(
            "application.core.settings.settings", mock_settings
        ):
            with pytest.raises(ValueError, match="Failed to retrieve SharePoint token"):
                auth.get_token_info_from_session(session)

    @pytest.mark.unit
    def test_valid_session(self, auth, mock_settings):
        stored = {
            "session_token": "st",
            "token_info": {"access_token": "at", "refresh_token": "rt"},
        }
        client = self._mongo_client(mock_settings, stored)
        with patch("application.core.mongo_db.MongoDB.get_client", return_value=client), patch(
            "application.core.settings.settings", mock_settings
        ):
            token_info = auth.get_token_info_from_session("st")
        assert token_info["access_token"] == "at"
        assert "token_uri" in token_info

    @pytest.mark.unit
    def test_session_not_found_raises(self, auth, mock_settings):
        self._assert_lookup_fails(auth, mock_settings, None, session="bad")

    @pytest.mark.unit
    def test_missing_token_info_raises(self, auth, mock_settings):
        self._assert_lookup_fails(auth, mock_settings, {"session_token": "st"})

    @pytest.mark.unit
    def test_empty_token_info_raises(self, auth, mock_settings):
        self._assert_lookup_fails(
            auth, mock_settings, {"session_token": "st", "token_info": None}
        )

    @pytest.mark.unit
    def test_missing_required_fields_raises(self, auth, mock_settings):
        # A refresh token is required alongside the access token.
        self._assert_lookup_fails(
            auth, mock_settings, {"session_token": "st", "token_info": {"access_token": "at"}}
        )

File diff suppressed because it is too large Load Diff

115
tests/test_app_routes.py Normal file
View File

@@ -0,0 +1,115 @@
"""Tests for application/app.py route handlers."""
import json
from unittest.mock import patch
import pytest
@pytest.fixture
def app():
    """Import the Flask app with handle_auth stubbed so JWT setup is skipped."""
    with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
        from application.app import app as flask_app

        flask_app.config["TESTING"] = True
        yield flask_app
@pytest.fixture
def client(app):
    """Flask test client bound to the app fixture."""
    return app.test_client()
class TestHomeRoute:
    """Root path behaviour."""

    @pytest.mark.unit
    def test_root_returns_200(self, client):
        """The root path serves the Flask-RESTX Swagger UI."""
        assert client.get("/").status_code == 200
class TestHealthRoute:
    """Liveness endpoint."""

    @pytest.mark.unit
    def test_returns_ok(self, client):
        reply = client.get("/api/health")
        assert reply.status_code == 200
        assert json.loads(reply.data)["status"] == "ok"
class TestConfigRoute:
    """Public auth configuration endpoint."""

    @pytest.mark.unit
    def test_returns_auth_config(self, client):
        reply = client.get("/api/config")
        assert reply.status_code == 200
        payload = json.loads(reply.data)
        assert "auth_type" in payload
        assert "requires_auth" in payload
class TestGenerateTokenRoute:
    """Token generation is only available under the session_jwt auth type."""

    @pytest.mark.unit
    def test_session_jwt_generates_token(self, client, app):
        with patch("application.app.settings") as fake_settings:
            fake_settings.AUTH_TYPE = "session_jwt"
            fake_settings.JWT_SECRET_KEY = "test_secret"
            reply = client.get("/api/generate_token")
        assert reply.status_code == 200
        assert "token" in json.loads(reply.data)

    @pytest.mark.unit
    def test_non_session_jwt_returns_error(self, client, app):
        with patch("application.app.settings") as fake_settings:
            fake_settings.AUTH_TYPE = "none"
            assert client.get("/api/generate_token").status_code == 400
class TestSttRequestSizeLimits:
    """Request-size gating applied only to speech-to-text uploads."""

    @pytest.mark.unit
    def test_non_stt_request_passes(self, client):
        assert client.get("/api/health").status_code == 200

    @pytest.mark.unit
    def test_oversized_stt_request_rejected(self, client):
        with patch("application.app.should_reject_stt_request", return_value=True), patch(
            "application.app.build_stt_file_size_limit_message", return_value="Too large"
        ):
            reply = client.post("/api/stt/upload", data=b"x" * 100)
        assert reply.status_code == 413
class TestAuthenticateRequest:
    """Behaviour of the before-request authentication hook."""

    @pytest.mark.unit
    def test_options_returns_200(self, client):
        # CORS preflight requests must bypass authentication.
        assert client.options("/api/health").status_code == 200

    @pytest.mark.unit
    def test_auth_error_returns_401(self, client, app):
        with patch("application.app.handle_auth", return_value={"error": "Invalid token"}):
            assert client.get("/api/health").status_code == 401

    @pytest.mark.unit
    def test_no_token_sets_none(self, client, app):
        """A missing token is tolerated; the request proceeds unauthenticated."""
        with patch("application.app.handle_auth", return_value=None):
            assert client.get("/api/health").status_code == 200
class TestAfterRequest:
    """CORS headers attached by the after-request hook."""

    @pytest.mark.unit
    def test_cors_headers(self, client):
        headers = client.get("/api/health").headers
        assert headers.get("Access-Control-Allow-Origin") == "*"
        assert "Content-Type" in headers.get("Access-Control-Allow-Headers", "")
        assert "GET" in headers.get("Access-Control-Allow-Methods", "")

279
tests/test_namespaces.py Normal file
View File

@@ -0,0 +1,279 @@
from datetime import datetime, timezone
from unittest.mock import patch
import pytest
from application.templates.namespaces import (
NamespaceBuilder,
NamespaceManager,
PassthroughNamespace,
SourceNamespace,
SystemNamespace,
ToolsNamespace,
)
# ── SystemNamespace ────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestSystemNamespace:
    """Keys and defaults produced by SystemNamespace.build()."""

    def test_namespace_name(self):
        assert SystemNamespace().namespace_name == "system"

    def test_build_returns_expected_keys(self):
        built = SystemNamespace().build()
        for key in ("date", "time", "timestamp", "request_id", "user_id"):
            assert key in built

    def test_build_with_request_id(self):
        assert SystemNamespace().build(request_id="req-123")["request_id"] == "req-123"

    def test_build_with_user_id(self):
        assert SystemNamespace().build(user_id="user-456")["user_id"] == "user-456"

    def test_build_generates_uuid_when_no_request_id(self):
        # Canonical UUID string representation is 36 characters long.
        assert len(SystemNamespace().build()["request_id"]) == 36

    def test_user_id_defaults_to_none(self):
        assert SystemNamespace().build()["user_id"] is None

    def test_date_format(self):
        frozen = datetime(2026, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
        with patch("application.templates.namespaces.datetime") as dt_mock:
            dt_mock.now.return_value = frozen
            # Direct datetime(...) construction inside the module still works.
            dt_mock.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs)
            built = SystemNamespace().build()
        assert built["date"] == "2026-01-15"
        assert built["time"] == "10:30:45"

    def test_extra_kwargs_ignored(self):
        assert "date" in SystemNamespace().build(unknown_param="whatever")
# ── PassthroughNamespace ───────────────────────────────────────────────────────
@pytest.mark.unit
class TestPassthroughNamespace:
    """Filtering rules for user-supplied passthrough data."""

    def test_namespace_name(self):
        assert PassthroughNamespace().namespace_name == "passthrough"

    def test_none_data_returns_empty(self):
        assert PassthroughNamespace().build(passthrough_data=None) == {}

    def test_no_data_returns_empty(self):
        assert PassthroughNamespace().build() == {}

    def test_safe_types_pass_through(self):
        payload = {"s": "string", "i": 42, "f": 3.14, "b": True, "n": None}
        assert PassthroughNamespace().build(passthrough_data=payload) == payload

    def test_non_serializable_types_filtered(self):
        built = PassthroughNamespace().build(passthrough_data={"good": "ok", "bad": object()})
        assert built == {"good": "ok"}

    def test_list_value_filtered(self):
        # Containers are rejected here, unlike in the tools namespace.
        assert PassthroughNamespace().build(passthrough_data={"list_val": [1, 2, 3]}) == {}

    def test_dict_value_filtered(self):
        assert PassthroughNamespace().build(passthrough_data={"nested": {"key": "val"}}) == {}
# ── SourceNamespace ────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestSourceNamespace:
    """Document context exposed under the source namespace."""

    def test_namespace_name(self):
        assert SourceNamespace().namespace_name == "source"

    def test_no_docs_returns_empty(self):
        assert SourceNamespace().build() == {}

    def test_with_docs(self):
        docs = [{"text": "doc1"}, {"text": "doc2"}]
        built = SourceNamespace().build(docs=docs)
        assert built["documents"] == docs
        assert built["count"] == 2

    def test_with_docs_together(self):
        built = SourceNamespace().build(docs_together="all content together")
        # The combined text is published under three aliases.
        for alias in ("content", "docs_together", "summaries"):
            assert built[alias] == "all content together"

    def test_with_both_docs_and_docs_together(self):
        docs = [{"text": "doc1"}]
        built = SourceNamespace().build(docs=docs, docs_together="combined")
        assert built["documents"] == docs
        assert built["count"] == 1
        assert built["content"] == "combined"

    def test_empty_docs_list_returns_empty(self):
        assert "documents" not in SourceNamespace().build(docs=[])
# ── ToolsNamespace ─────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestToolsNamespace:
    """Tool results exposed under the tools namespace."""

    def test_namespace_name(self):
        assert ToolsNamespace().namespace_name == "tools"

    def test_none_data_returns_empty(self):
        assert ToolsNamespace().build(tools_data=None) == {}

    def test_no_data_returns_empty(self):
        assert ToolsNamespace().build() == {}

    def test_safe_types_pass_through(self):
        # Unlike passthrough, containers (dict/list) are allowed here.
        payload = {
            "str_tool": "result",
            "dict_tool": {"key": "val"},
            "list_tool": [1, 2],
            "int_tool": 42,
            "float_tool": 3.14,
            "bool_tool": True,
            "none_tool": None,
        }
        assert ToolsNamespace().build(tools_data=payload) == payload

    def test_non_serializable_filtered(self):
        built = ToolsNamespace().build(tools_data={"good": "ok", "bad": object()})
        assert built == {"good": "ok"}
# ── NamespaceBuilder ABC ──────────────────────────────────────────────────────
@pytest.mark.unit
class TestNamespaceBuilderABC:
    """Abstract-base contract of NamespaceBuilder."""

    def test_cannot_instantiate(self):
        with pytest.raises(TypeError):
            NamespaceBuilder()

    def test_subclass_must_implement_both(self):
        class Partial(NamespaceBuilder):
            pass

        with pytest.raises(TypeError):
            Partial()

    def test_concrete_subclass_works(self):
        class Full(NamespaceBuilder):
            @property
            def namespace_name(self):
                return "test"

            def build(self, **kwargs):
                return {"ok": True}

        builder = Full()
        assert builder.namespace_name == "test"
        assert builder.build() == {"ok": True}
# ── NamespaceManager ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestNamespaceManager:
    """Aggregation of all namespace builders into one context dict."""

    NAMESPACES = ("system", "passthrough", "source", "tools")

    def test_build_context_contains_all_namespaces(self):
        context = NamespaceManager().build_context()
        for name in self.NAMESPACES:
            assert name in context

    def test_system_namespace_populated(self):
        context = NamespaceManager().build_context(request_id="r1", user_id="u1")
        assert context["system"]["request_id"] == "r1"
        assert context["system"]["user_id"] == "u1"

    def test_passthrough_namespace_populated(self):
        context = NamespaceManager().build_context(passthrough_data={"key": "val"})
        assert context["passthrough"] == {"key": "val"}

    def test_source_namespace_populated(self):
        context = NamespaceManager().build_context(docs=[{"text": "doc"}])
        assert context["source"]["count"] == 1

    def test_tools_namespace_populated(self):
        context = NamespaceManager().build_context(tools_data={"search": "results"})
        assert context["tools"] == {"search": "results"}

    def test_empty_kwargs_all_namespaces_present(self):
        context = NamespaceManager().build_context()
        for name in self.NAMESPACES:
            assert name in context
            assert isinstance(context[name], dict)

    def test_builder_exception_returns_empty_namespace(self):
        """A failing builder degrades to {} without breaking the others."""
        manager = NamespaceManager()
        with patch.object(
            manager._builders["system"], "build", side_effect=RuntimeError("boom")
        ):
            context = manager.build_context()
        assert context["system"] == {}
        assert "passthrough" in context

    def test_get_builder_existing(self):
        assert isinstance(NamespaceManager().get_builder("system"), SystemNamespace)

    def test_get_builder_nonexistent(self):
        assert NamespaceManager().get_builder("nonexistent") is None

353
tests/test_retriever.py Normal file
View File

@@ -0,0 +1,353 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
from application.retriever.base import BaseRetriever
from application.retriever.retriever_creator import RetrieverCreator
# ── BaseRetriever ──────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestBaseRetriever:
    """Abstract contract of BaseRetriever."""

    def test_cannot_instantiate_directly(self):
        with pytest.raises(TypeError):
            BaseRetriever()

    def test_subclass_must_implement_search(self):
        class Partial(BaseRetriever):
            pass

        with pytest.raises(TypeError):
            Partial()

    def test_concrete_subclass_works(self):
        class Working(BaseRetriever):
            def search(self, *args, **kwargs):
                return "ok"

        assert Working().search() == "ok"
# ── RetrieverCreator ───────────────────────────────────────────────────────────
@pytest.mark.unit
class TestRetrieverCreator:
    """Factory lookup in RetrieverCreator.create_retriever.

    Each test swaps a mock class into the shared ``retrievers`` registry and
    restores it afterwards.  Restoration uses clear()+update() so that keys
    added during the test are removed too — a plain ``update(snapshot)`` would
    leave a test-added key behind if it did not exist beforehand, leaking
    state into later tests.
    """

    @staticmethod
    def _install(name, fake_cls):
        """Snapshot the registry, register fake_cls under name, return snapshot."""
        snapshot = RetrieverCreator.retrievers.copy()
        RetrieverCreator.retrievers[name] = fake_cls
        return snapshot

    @staticmethod
    def _restore(snapshot):
        """Put the registry back exactly as it was before the test."""
        RetrieverCreator.retrievers.clear()
        RetrieverCreator.retrievers.update(snapshot)

    def test_create_classic(self):
        fake_cls = Mock(return_value="rag_instance")
        snapshot = self._install("classic", fake_cls)
        try:
            created = RetrieverCreator.create_retriever("classic", "arg1", key="val")
            fake_cls.assert_called_once_with("arg1", key="val")
            assert created == "rag_instance"
        finally:
            self._restore(snapshot)

    def test_create_default(self):
        fake_cls = Mock(return_value="rag_instance")
        snapshot = self._install("default", fake_cls)
        try:
            created = RetrieverCreator.create_retriever("default")
            fake_cls.assert_called_once_with()
            assert created == "rag_instance"
        finally:
            self._restore(snapshot)

    def test_create_none_type_uses_default(self):
        fake_cls = Mock(return_value="rag_instance")
        snapshot = self._install("default", fake_cls)
        try:
            created = RetrieverCreator.create_retriever(None)
            fake_cls.assert_called_once()
            assert created == "rag_instance"
        finally:
            self._restore(snapshot)

    def test_case_insensitive(self):
        fake_cls = Mock(return_value="rag_instance")
        snapshot = self._install("classic", fake_cls)
        try:
            RetrieverCreator.create_retriever("CLASSIC")
            fake_cls.assert_called_once()
        finally:
            self._restore(snapshot)

    def test_invalid_type_raises(self):
        # The "retievers" typo mirrors the application's actual error message.
        with pytest.raises(ValueError, match="No retievers class found"):
            RetrieverCreator.create_retriever("nonexistent")
# ── ClassicRAG ─────────────────────────────────────────────────────────────────
@pytest.fixture
def _patch_llm_creator(mock_llm, monkeypatch):
    """Patch LLMCreator.create_llm to return the shared mock_llm fixture."""
    # monkeypatch automatically undoes this setattr at test teardown.
    monkeypatch.setattr(
        "application.retriever.classic_rag.LLMCreator.create_llm",
        Mock(return_value=mock_llm),
    )
    return mock_llm
def _make_rag(source=None, _patch_llm_creator=None, **overrides):
    """Build a ClassicRAG with sensible defaults; keyword overrides win."""
    from application.retriever.classic_rag import ClassicRAG

    kwargs = {
        "source": source or {"question": "hello"},
        "chat_history": None,
        "prompt": "",
        "chunks": 2,
        "doc_token_limit": 50000,
        "model_id": "test-model",
        "user_api_key": None,
        "agent_id": None,
        "llm_name": "openai",
        "api_key": "fake",
        "decoded_token": {"sub": "user1"},
    }
    kwargs.update(overrides)
    return ClassicRAG(**kwargs)
@pytest.mark.unit
class TestClassicRAGInit:
    """Constructor normalisation of sources, chunk counts and tokens."""

    def test_basic_init(self, _patch_llm_creator):
        rag = _make_rag()
        assert rag.original_question == "hello"
        assert rag.chunks == 2
        assert rag.vectorstores == []

    def test_active_docs_as_list(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": ["a", "b"]})
        assert rag.vectorstores == ["a", "b"]

    def test_active_docs_as_string(self, _patch_llm_creator):
        # A bare string is wrapped into a single-element list.
        rag = _make_rag(source={"question": "q", "active_docs": "single"})
        assert rag.vectorstores == ["single"]

    def test_active_docs_none(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": None})
        assert rag.vectorstores == []

    def test_chunks_string_converted(self, _patch_llm_creator):
        assert _make_rag(chunks="5").chunks == 5

    def test_chunks_invalid_string_defaults(self, _patch_llm_creator):
        assert _make_rag(chunks="abc").chunks == 2

    def test_decoded_token_none(self, _patch_llm_creator):
        assert _make_rag(decoded_token=None).decoded_token is None
@pytest.mark.unit
class TestClassicRAGValidateVectorstore:
    """Vectorstore id validation at construction time."""

    def test_removes_empty_ids(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": ["ok", "", " ", "good"]})
        assert rag.vectorstores == ["ok", "good"]

    def test_empty_vectorstores_no_error(self, _patch_llm_creator):
        assert _make_rag(source={"question": "q"}).vectorstores == []
@pytest.mark.unit
class TestClassicRAGRephraseQuery:
    """Rephrasing only happens when history, stores and chunks all allow it."""

    @staticmethod
    def _history():
        """Fresh one-turn chat history for each test."""
        return [{"prompt": "hi", "response": "hello"}]

    def test_no_history_returns_original(self, _patch_llm_creator):
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[],
        )
        assert rag.question == "original"

    def test_no_vectorstores_returns_original(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "original"}, chat_history=self._history())
        assert rag.question == "original"

    def test_chunks_zero_returns_original(self, _patch_llm_creator):
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=self._history(),
            chunks=0,
        )
        assert rag.question == "original"

    def test_rephrase_called_with_history(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(return_value="rephrased question")
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=self._history(),
        )
        assert rag.question == "rephrased question"
        mock_llm.gen.assert_called_once()

    def test_rephrase_llm_returns_empty_falls_back(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(return_value="")
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=self._history(),
        )
        assert rag.question == "original"

    def test_rephrase_llm_exception_falls_back(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(side_effect=RuntimeError("boom"))
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=self._history(),
        )
        assert rag.question == "original"
@pytest.mark.unit
class TestClassicRAGGetData:
    """_get_data(): document retrieval across configured vectorstores."""

    def test_chunks_zero_returns_empty(self, _patch_llm_creator):
        """chunks=0 disables retrieval entirely."""
        rag = _make_rag(chunks=0)
        assert rag._get_data() == []

    def test_no_vectorstores_returns_empty(self, _patch_llm_creator):
        """No active_docs means there is nothing to search."""
        rag = _make_rag(source={"question": "q"})
        assert rag._get_data() == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_returns_docs_with_metadata(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Doc metadata is normalised: title/filename keep only the basename."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content here"
        mock_doc.metadata = {
            "title": "path/to/Title",
            "filename": "/docs/file.txt",
            "source": "http://example.com",
        }
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert len(docs) == 1
        assert docs[0]["text"] == "content here"
        assert docs[0]["title"] == "Title"
        assert docs[0]["filename"] == "file.txt"
        assert docs[0]["source"] == "http://example.com"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_dict_style_docs(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Plain-dict search results are supported alongside Document objects."""
        mock_docsearch = MagicMock()
        mock_docsearch.search.return_value = [
            {"text": "dict content", "metadata": {"title": "Dict Title"}}
        ]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert len(docs) == 1
        assert docs[0]["text"] == "dict content"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=100000)
    def test_token_budget_respected(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Docs whose token count exceeds the remaining budget are dropped."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "big content"
        mock_doc.metadata = {"title": "t"}
        mock_docsearch.search.return_value = [mock_doc, mock_doc, mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(
            source={"question": "q", "active_docs": ["vs1"]},
            doc_token_limit=100,
        )
        docs = rag._get_data()
        # tokens (100000) exceed budget (90), so no docs should be added
        assert len(docs) == 0

    @patch("application.retriever.classic_rag.VectorCreator")
    def test_vectorstore_error_continues(self, mock_vc, _patch_llm_creator):
        """A failing vectorstore is skipped rather than raising."""
        mock_vc.create_vectorstore.side_effect = RuntimeError("connection failed")
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_multiple_vectorstores(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Each configured vectorstore contributes its search results."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": "t", "source": "s"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1", "vs2"]})
        docs = rag._get_data()
        assert len(docs) == 2

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_doc_missing_filename_uses_title(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Missing filename metadata falls back to the title."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": "MyTitle"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs[0]["filename"] == "MyTitle"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_non_string_title_converted(self, mock_tokens, mock_vc, _patch_llm_creator):
        """Non-string titles are stringified before use."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": 42}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs[0]["title"] == "42"
@pytest.mark.unit
class TestClassicRAGSearch:
    """search(): the public entry point combining rephrase and retrieval."""

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_search_with_query_override(self, mock_tokens, mock_vc, _patch_llm_creator, mock_llm):
        """An explicit query argument overrides the question from the source."""
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "result"
        mock_doc.metadata = {"title": "t"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        mock_llm.gen = Mock(return_value="")
        rag = _make_rag(source={"question": "original", "active_docs": ["vs1"]})
        docs = rag.search(query="override query")
        assert rag.original_question == "override query"
        assert len(docs) == 1

    def test_search_without_query_uses_default(self, _patch_llm_creator):
        """Without an override the stored question is used (no stores -> [])."""
        rag = _make_rag(source={"question": "q"})
        docs = rag.search()
        assert docs == []

492
tests/test_seeder.py Normal file
View File

@@ -0,0 +1,492 @@
from unittest.mock import MagicMock, patch, mock_open
import mongomock
import pytest
from bson import ObjectId
from application.seed.seeder import DatabaseSeeder
@pytest.fixture
def mock_db():
    """In-memory MongoDB database backed by mongomock."""
    return mongomock.MongoClient()["test_docsgpt"]
@pytest.fixture
def seeder(mock_db):
    """DatabaseSeeder wired to the in-memory mock database."""
    return DatabaseSeeder(mock_db)
# ── __init__ ───────────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestDatabaseSeederInit:
    """Constructor wiring of collection handles and the system user id."""

    def test_collections_set(self, seeder, mock_db):
        """Each collection attribute points at the expected mock collection."""
        assert seeder.db is mock_db
        expected = {
            "tools_collection": "user_tools",
            "sources_collection": "sources",
            "agents_collection": "agents",
            "prompts_collection": "prompts",
        }
        for attr, collection_name in expected.items():
            assert getattr(seeder, attr) == mock_db[collection_name]
        assert seeder.system_user_id == "system"
# ── _is_already_seeded ─────────────────────────────────────────────────────────
@pytest.mark.unit
class TestIsAlreadySeeded:
    """_is_already_seeded(): presence check for system-owned agents."""

    def test_not_seeded(self, seeder):
        """A fresh database reports not seeded."""
        seeded = seeder._is_already_seeded()
        assert seeded is False

    def test_already_seeded(self, seeder, mock_db):
        """Any agent owned by 'system' marks the database as seeded."""
        mock_db["agents"].insert_one({"user": "system", "name": "test"})
        seeded = seeder._is_already_seeded()
        assert seeded is True
# ── _process_config ────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestProcessConfig:
    """_process_config(): ${ENV_VAR} expansion inside config values."""

    def test_env_var_substitution(self, seeder, monkeypatch):
        """${VAR} is replaced by the environment value."""
        monkeypatch.setenv("MY_SECRET", "secret_value")
        processed = seeder._process_config({"key": "${MY_SECRET}"})
        assert processed["key"] == "secret_value"

    def test_missing_env_var_defaults_empty(self, seeder, monkeypatch):
        """An unset variable expands to the empty string."""
        monkeypatch.delenv("NONEXISTENT_VAR", raising=False)
        processed = seeder._process_config({"key": "${NONEXISTENT_VAR}"})
        assert processed["key"] == ""

    def test_non_env_value_unchanged(self, seeder):
        """Plain values pass through untouched."""
        processed = seeder._process_config({"key": "plain_value", "num": 42})
        assert processed == {"key": "plain_value", "num": 42}

    def test_partial_env_syntax_unchanged(self, seeder):
        """A value that only starts like ${... is left as-is."""
        processed = seeder._process_config({"key": "${INCOMPLETE"})
        assert processed["key"] == "${INCOMPLETE"

    def test_empty_config(self, seeder):
        """An empty config maps to an empty dict."""
        assert seeder._process_config({}) == {}
# ── _handle_prompt ─────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestHandlePrompt:
    """_handle_prompt(): creation and reuse of seeded prompts."""

    def test_no_prompt_returns_none(self, seeder):
        """Configs without a prompt section yield None."""
        assert seeder._handle_prompt({"name": "agent1"}) is None

    def test_empty_content_returns_none(self, seeder):
        """A prompt with empty content is not created."""
        config = {"name": "agent1", "prompt": {"name": "p", "content": ""}}
        assert seeder._handle_prompt(config) is None

    def test_creates_prompt(self, seeder, mock_db):
        """A valid prompt is inserted under the system user."""
        config = {
            "name": "agent1",
            "prompt": {"name": "My Prompt", "content": "You are helpful."},
        }
        result = seeder._handle_prompt(config)
        assert result is not None
        doc = mock_db["prompts"].find_one({"name": "My Prompt"})
        assert doc is not None
        assert doc["content"] == "You are helpful."
        assert doc["user"] == "system"

    def test_duplicate_prompt_returns_existing(self, seeder, mock_db):
        """Seeding the same prompt twice reuses the existing document."""
        config = {
            "name": "agent1",
            "prompt": {"name": "Dup Prompt", "content": "content"},
        }
        id1 = seeder._handle_prompt(config)
        id2 = seeder._handle_prompt(config)
        assert id1 == id2
        assert mock_db["prompts"].count_documents({"name": "Dup Prompt"}) == 1

    def test_default_prompt_name(self, seeder, mock_db):
        """An unnamed prompt defaults to '<agent name> Prompt'."""
        config = {"name": "agent1", "prompt": {"content": "hello"}}
        seeder._handle_prompt(config)
        doc = mock_db["prompts"].find_one({"name": "agent1 Prompt"})
        assert doc is not None

    def test_exception_returns_none(self, seeder):
        """Database errors are swallowed and reported as None."""
        with patch.object(
            seeder.prompts_collection, "find_one", side_effect=RuntimeError("db error")
        ):
            config = {
                "name": "agent1",
                "prompt": {"name": "p", "content": "c"},
            }
            assert seeder._handle_prompt(config) is None
# ── _handle_tools ──────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestHandleTools:
    """_handle_tools(): creation and reuse of seeded tools."""

    def test_no_tools_returns_empty(self, seeder):
        """Configs without tools yield an empty id list."""
        assert seeder._handle_tools({"name": "agent1"}) == []

    @patch("application.seed.seeder.tool_manager")
    def test_creates_tool(self, mock_tm, seeder, mock_db):
        """A known tool is inserted under the system user."""
        mock_tool = MagicMock()
        mock_tool.get_actions_metadata.return_value = [{"name": "act1"}]
        mock_tm.tools = {"my_tool": mock_tool}
        config = {
            "name": "agent1",
            "tools": [{"name": "my_tool", "description": "desc"}],
        }
        ids = seeder._handle_tools(config)
        assert len(ids) == 1
        doc = mock_db["user_tools"].find_one({"name": "my_tool"})
        assert doc is not None
        assert doc["user"] == "system"

    @patch("application.seed.seeder.tool_manager")
    def test_duplicate_tool_returns_existing(self, mock_tm, seeder, mock_db):
        """Re-seeding the same tool reuses the existing document."""
        mock_tool = MagicMock()
        mock_tool.get_actions_metadata.return_value = []
        mock_tm.tools = {"my_tool": mock_tool}
        config = {
            "name": "agent1",
            "tools": [{"name": "my_tool"}],
        }
        ids1 = seeder._handle_tools(config)
        ids2 = seeder._handle_tools(config)
        assert ids1 == ids2
        assert mock_db["user_tools"].count_documents({"name": "my_tool"}) == 1

    def test_tool_exception_continues(self, seeder):
        """An unknown tool is skipped without raising."""
        config = {
            "name": "agent1",
            "tools": [{"name": "broken_tool"}],
        }
        # tool_manager.tools will KeyError on "broken_tool"
        ids = seeder._handle_tools(config)
        assert ids == []

    @patch("application.seed.seeder.tool_manager")
    def test_tool_config_env_expansion(self, mock_tm, seeder, monkeypatch):
        """${VAR} placeholders in tool config are expanded from the env."""
        monkeypatch.setenv("TOOL_KEY", "expanded_val")
        mock_tool = MagicMock()
        mock_tool.get_actions_metadata.return_value = []
        mock_tm.tools = {"my_tool": mock_tool}
        config = {
            "name": "agent1",
            "tools": [{"name": "my_tool", "config": {"api_key": "${TOOL_KEY}"}}],
        }
        seeder._handle_tools(config)
        doc = seeder.tools_collection.find_one({"name": "my_tool"})
        assert doc["config"]["api_key"] == "expanded_val"
# ── _handle_source ─────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestHandleSource:
    """_handle_source(): reuse of existing sources and remote ingestion."""

    def test_no_source_returns_none(self, seeder):
        """Configs without a source yield None."""
        assert seeder._handle_source({"name": "a"}) is None

    def test_existing_source_returns_id(self, seeder, mock_db):
        """A source matching by URL is reused instead of re-ingested."""
        inserted = mock_db["sources"].insert_one(
            {"user": "system", "remote_data": "http://example.com"}
        )
        config = {
            "name": "a",
            "source": {"url": "http://example.com", "name": "src"},
        }
        result = seeder._handle_source(config)
        assert result == inserted.inserted_id

    @patch("application.seed.seeder.ingest_remote")
    def test_new_source_ingestion(self, mock_ingest, seeder):
        """A new source triggers ingest_remote and returns the new id."""
        mock_task = MagicMock()
        mock_task.get.return_value = {"id": "new_source_id"}
        mock_task.successful.return_value = True
        mock_ingest.delay.return_value = mock_task
        config = {
            "name": "a",
            "source": {"url": "http://new.com", "name": "new_src", "loader": "web"},
        }
        result = seeder._handle_source(config)
        assert result == "new_source_id"
        mock_ingest.delay.assert_called_once_with(
            source_data="http://new.com",
            job_name="new_src",
            user="system",
            loader="web",
        )

    @patch("application.seed.seeder.ingest_remote")
    def test_source_ingestion_failure_returns_false(self, mock_ingest, seeder):
        """A failing ingestion task yields False (agent will be skipped)."""
        mock_task = MagicMock()
        mock_task.get.side_effect = RuntimeError("timeout")
        mock_ingest.delay.return_value = mock_task
        config = {
            "name": "a",
            "source": {"url": "http://fail.com", "name": "fail_src"},
        }
        result = seeder._handle_source(config)
        assert result is False

    @patch("application.seed.seeder.ingest_remote")
    def test_source_missing_id_returns_false(self, mock_ingest, seeder):
        """A task result without an id yields False."""
        mock_task = MagicMock()
        mock_task.get.return_value = {"no_id_key": True}
        mock_task.successful.return_value = True
        mock_ingest.delay.return_value = mock_task
        config = {
            "name": "a",
            "source": {"url": "http://bad.com", "name": "bad_src"},
        }
        result = seeder._handle_source(config)
        assert result is False

    @patch("application.seed.seeder.ingest_remote")
    def test_default_loader(self, mock_ingest, seeder):
        """The loader defaults to 'url' when unspecified."""
        mock_task = MagicMock()
        mock_task.get.return_value = {"id": "sid"}
        mock_task.successful.return_value = True
        mock_ingest.delay.return_value = mock_task
        config = {
            "name": "a",
            "source": {"url": "http://x.com", "name": "s"},
        }
        seeder._handle_source(config)
        call_kwargs = mock_ingest.delay.call_args[1]
        assert call_kwargs["loader"] == "url"
# ── seed_initial_data ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestSeedInitialData:
    """seed_initial_data(): idempotency, the force flag and config loading."""

    def test_already_seeded_skips(self, seeder, mock_db):
        """Seeding is a no-op when system agents already exist."""
        mock_db["agents"].insert_one({"user": "system", "name": "existing"})
        with patch.object(seeder, "_seed_from_config") as mock_seed:
            seeder.seed_initial_data()
            mock_seed.assert_not_called()

    def test_force_reseeds(self, seeder, mock_db):
        """force=True reseeds even when already seeded."""
        mock_db["agents"].insert_one({"user": "system", "name": "existing"})
        yaml_content = "agents: []"
        with patch("builtins.open", mock_open(read_data=yaml_content)):
            with patch.object(seeder, "_seed_from_config") as mock_seed:
                seeder.seed_initial_data(force=True)
                mock_seed.assert_called_once()

    def test_config_file_not_found_raises(self, seeder):
        """A missing config file propagates an exception."""
        with pytest.raises(Exception):
            seeder.seed_initial_data(config_path="/nonexistent/path.yaml")

    def test_custom_config_path(self, seeder):
        """A custom config path is opened and parsed."""
        yaml_content = "agents:\n - name: test_agent"
        with patch("builtins.open", mock_open(read_data=yaml_content)):
            with patch.object(seeder, "_seed_from_config") as mock_seed:
                seeder.seed_initial_data(config_path="/custom/path.yaml")
                mock_seed.assert_called_once()
# ── _seed_from_config ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestSeedFromConfig:
    """_seed_from_config(): agent creation, update and failure handling."""

    def test_no_agents_in_config(self, seeder):
        """A config without an agents key inserts nothing."""
        seeder._seed_from_config({})
        assert seeder.agents_collection.count_documents({}) == 0

    def test_empty_agents_list(self, seeder):
        """An empty agents list inserts nothing."""
        seeder._seed_from_config({"agents": []})
        assert seeder.agents_collection.count_documents({}) == 0

    @patch.object(DatabaseSeeder, "_handle_source", return_value=None)
    @patch.object(DatabaseSeeder, "_handle_tools", return_value=[])
    @patch.object(DatabaseSeeder, "_handle_prompt", return_value=None)
    def test_creates_agent(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
        """A minimal agent config creates a system-owned template agent."""
        config = {
            "agents": [
                {
                    "name": "TestAgent",
                    "description": "A test agent",
                    "agent_type": "classic",
                }
            ]
        }
        seeder._seed_from_config(config)
        agent = mock_db["agents"].find_one({"name": "TestAgent"})
        assert agent is not None
        assert agent["user"] == "system"
        assert agent["agent_type"] == "classic"
        assert agent["status"] == "template"

    @patch.object(DatabaseSeeder, "_handle_source", return_value=None)
    @patch.object(DatabaseSeeder, "_handle_tools", return_value=[])
    @patch.object(DatabaseSeeder, "_handle_prompt", return_value=None)
    def test_updates_existing_agent(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
        """Re-seeding an existing agent updates it in place (no duplicate)."""
        mock_db["agents"].insert_one(
            {"user": "system", "name": "TestAgent", "description": "old"}
        )
        config = {
            "agents": [
                {
                    "name": "TestAgent",
                    "description": "updated",
                    "agent_type": "classic",
                }
            ]
        }
        seeder._seed_from_config(config)
        assert mock_db["agents"].count_documents({"name": "TestAgent"}) == 1
        agent = mock_db["agents"].find_one({"name": "TestAgent"})
        assert agent["description"] == "updated"

    @patch.object(DatabaseSeeder, "_handle_source", return_value=False)
    def test_source_failure_skips_agent(self, mock_source, seeder, mock_db):
        """A failed source (False) causes the whole agent to be skipped."""
        config = {
            "agents": [
                {
                    "name": "SkippedAgent",
                    "description": "skip",
                    "agent_type": "classic",
                }
            ]
        }
        seeder._seed_from_config(config)
        assert mock_db["agents"].count_documents({"name": "SkippedAgent"}) == 0

    @patch.object(DatabaseSeeder, "_handle_source", side_effect=KeyError("name"))
    def test_agent_exception_continues(self, mock_source, seeder, mock_db):
        """An exception while seeding one agent does not abort the run."""
        config = {
            "agents": [
                {"name": "Bad", "description": "x", "agent_type": "y"},
                {"name": "Good", "description": "x", "agent_type": "y"},
            ]
        }
        with patch.object(seeder, "_handle_tools", return_value=[]):
            with patch.object(seeder, "_handle_prompt", return_value=None):
                seeder._seed_from_config(config)
        # Both agents should be attempted; first errors, second might too
        # Main assertion: no unhandled exception

    @patch.object(DatabaseSeeder, "_handle_source", return_value=None)
    @patch.object(DatabaseSeeder, "_handle_tools")
    @patch.object(DatabaseSeeder, "_handle_prompt", return_value="prompt_id_123")
    def test_agent_with_source_and_tools(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
        """Prompt id, tool ids and extra fields are carried onto the agent."""
        tool_id = ObjectId()
        mock_tools.return_value = [tool_id]
        source_id = ObjectId()
        mock_source.return_value = source_id
        config = {
            "agents": [
                {
                    "name": "FullAgent",
                    "description": "full",
                    "agent_type": "classic",
                    "chunks": "5",
                    "retriever": "classic",
                    "image": "img.png",
                }
            ]
        }
        seeder._seed_from_config(config)
        agent = mock_db["agents"].find_one({"name": "FullAgent"})
        assert agent is not None
        assert agent["prompt_id"] == "prompt_id_123"
        assert str(tool_id) in agent["tools"]
        assert agent["chunks"] == "5"
        assert agent["image"] == "img.png"
# ── initialize_from_env ────────────────────────────────────────────────────────
@pytest.mark.unit
class TestInitializeFromEnv:
    """initialize_from_env(): Mongo connection built from environment vars."""

    @patch("application.seed.seeder.MongoClient")
    def test_creates_seeder_from_env(self, mock_client_cls, monkeypatch):
        """MONGO_URI / MONGO_DB_NAME from the environment are honoured."""
        monkeypatch.setenv("MONGO_URI", "mongodb://test:27017")
        monkeypatch.setenv("MONGO_DB_NAME", "testdb")
        mock_db = mongomock.MongoClient()["testdb"]
        mock_client = MagicMock()
        mock_client.__getitem__ = MagicMock(return_value=mock_db)
        mock_client_cls.return_value = mock_client
        seeder = DatabaseSeeder.initialize_from_env()
        mock_client_cls.assert_called_once_with("mongodb://test:27017")
        assert isinstance(seeder, DatabaseSeeder)

    @patch("application.seed.seeder.MongoClient")
    def test_default_env_values(self, mock_client_cls, monkeypatch):
        """Without env vars, the localhost default URI is used."""
        monkeypatch.delenv("MONGO_URI", raising=False)
        monkeypatch.delenv("MONGO_DB_NAME", raising=False)
        mock_db = mongomock.MongoClient()["docsgpt"]
        mock_client = MagicMock()
        mock_client.__getitem__ = MagicMock(return_value=mock_db)
        mock_client_cls.return_value = mock_client
        DatabaseSeeder.initialize_from_env()
        mock_client_cls.assert_called_once_with("mongodb://localhost:27017")
# ── seed CLI commands ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestSeedCommands:
    """CLI `seed init` command behaviour.

    Both tests previously duplicated the same mock wiring and CLI
    invocation; that setup is factored into `_invoke_init`.
    """

    @staticmethod
    def _invoke_init(mock_settings, mock_mongodb, mock_seeder_cls, cli_args):
        """Wire the mocked settings/Mongo client, run `seed <cli_args>`,
        and return (CLI result, mocked seeder instance)."""
        from click.testing import CliRunner
        from application.seed.commands import seed

        mock_settings.MONGO_DB_NAME = "testdb"
        mock_client = MagicMock()
        mock_client.__getitem__ = MagicMock(return_value="mock_db")
        mock_mongodb.get_client.return_value = mock_client
        mock_seeder = MagicMock()
        mock_seeder_cls.return_value = mock_seeder
        result = CliRunner().invoke(seed, cli_args)
        return result, mock_seeder

    @patch("application.seed.commands.DatabaseSeeder")
    @patch("application.seed.commands.MongoDB")
    @patch("application.seed.commands.settings")
    def test_init_command(self, mock_settings, mock_mongodb, mock_seeder_cls):
        """`seed init` seeds with force=False by default."""
        result, mock_seeder = self._invoke_init(
            mock_settings, mock_mongodb, mock_seeder_cls, ["init"]
        )
        assert result.exit_code == 0
        mock_seeder.seed_initial_data.assert_called_once_with(force=False)

    @patch("application.seed.commands.DatabaseSeeder")
    @patch("application.seed.commands.MongoDB")
    @patch("application.seed.commands.settings")
    def test_init_command_with_force(self, mock_settings, mock_mongodb, mock_seeder_cls):
        """`seed init --force` passes force=True through to the seeder."""
        result, mock_seeder = self._invoke_init(
            mock_settings, mock_mongodb, mock_seeder_cls, ["init", "--force"]
        )
        assert result.exit_code == 0
        mock_seeder.seed_initial_data.assert_called_once_with(force=True)

View File

@@ -0,0 +1,186 @@
from unittest.mock import patch
import pytest
from application.templates.template_engine import TemplateEngine, TemplateRenderError
@pytest.fixture
def engine():
    """Fresh TemplateEngine instance for each test."""
    return TemplateEngine()
# ── render ─────────────────────────────────────────────────────────────────────
@pytest.mark.unit
class TestRender:
    """TemplateEngine.render(): Jinja rendering behaviour."""

    def test_simple_variable(self, engine):
        """A single variable is substituted."""
        result = engine.render("Hello {{ name }}", {"name": "World"})
        assert result == "Hello World"

    def test_empty_template_returns_empty(self, engine):
        """Empty template renders to the empty string."""
        assert engine.render("", {"x": 1}) == ""

    def test_none_template_returns_empty(self, engine):
        """None template renders to the empty string."""
        assert engine.render(None, {"x": 1}) == ""

    def test_no_variables(self, engine):
        """Templates with no placeholders pass through unchanged."""
        assert engine.render("plain text", {}) == "plain text"

    def test_multiple_variables(self, engine):
        tpl = "{{ a }} and {{ b }}"
        assert engine.render(tpl, {"a": "X", "b": "Y"}) == "X and Y"

    def test_nested_dict_access(self, engine):
        """Dotted access resolves nested dict keys."""
        tpl = "{{ data.key }}"
        assert engine.render(tpl, {"data": {"key": "value"}}) == "value"

    def test_loop(self, engine):
        tpl = "{% for i in items %}{{ i }} {% endfor %}"
        result = engine.render(tpl, {"items": ["a", "b", "c"]})
        assert result.strip() == "a b c"

    def test_conditional(self, engine):
        tpl = "{% if show %}yes{% else %}no{% endif %}"
        assert engine.render(tpl, {"show": True}) == "yes"
        assert engine.render(tpl, {"show": False}) == "no"

    def test_syntax_error_raises_template_render_error(self, engine):
        """Jinja syntax errors surface as TemplateRenderError."""
        with pytest.raises(TemplateRenderError, match="syntax error"):
            engine.render("{% if %}", {})

    def test_undefined_variable_chainable(self, engine):
        # ChainableUndefined should NOT raise; it silently returns empty
        result = engine.render("{{ missing }}", {})
        assert result == ""

    def test_autoescape_html(self, engine):
        """HTML in context values is autoescaped."""
        result = engine.render("{{ content }}", {"content": "<script>alert(1)</script>"})
        assert "<script>" not in result
        assert "&lt;script&gt;" in result

    def test_trim_blocks(self, engine):
        """Block tags do not leave stray newlines around output."""
        tpl = "{% if True %}\nyes\n{% endif %}"
        result = engine.render(tpl, {})
        assert result.strip() == "yes"

    def test_general_exception_raises_template_render_error(self, engine):
        """Non-syntax failures are wrapped in TemplateRenderError too."""
        with patch.object(engine._env, "from_string", side_effect=RuntimeError("boom")):
            with pytest.raises(TemplateRenderError, match="rendering failed"):
                engine.render("{{ x }}", {"x": 1})
# ── validate_template ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestValidateTemplate:
    """TemplateEngine.validate_template(): boolean syntax checking."""

    def test_valid_template(self, engine):
        """A well-formed variable template validates."""
        assert engine.validate_template("{{ name }}") is True

    def test_empty_template_valid(self, engine):
        """The empty template is trivially valid."""
        assert engine.validate_template("") is True

    def test_none_template_valid(self, engine):
        """None counts as an (empty) valid template."""
        assert engine.validate_template(None) is True

    def test_invalid_syntax(self, engine):
        """A broken tag fails validation."""
        assert engine.validate_template("{% if %}") is False

    def test_plain_text_valid(self, engine):
        """Templates with no Jinja syntax at all validate."""
        assert engine.validate_template("just plain text") is True

    def test_complex_valid_template(self, engine):
        """Loops with attribute access validate."""
        looped = "{% for x in items %}{{ x.name }}{% endfor %}"
        assert engine.validate_template(looped) is True

    def test_general_exception_returns_false(self, engine):
        """Unexpected parser failures report invalid rather than raising."""
        with patch.object(engine._env, "from_string", side_effect=RuntimeError("boom")):
            assert engine.validate_template("{{ x }}") is False
# ── extract_variables ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestExtractVariables:
    """TemplateEngine.extract_variables(): degenerate-input handling."""

    def test_empty_template(self, engine):
        """The empty template has no variables."""
        assert engine.extract_variables("") == set()

    def test_none_template(self, engine):
        """None input has no variables."""
        assert engine.extract_variables(None) == set()

    def test_syntax_error_returns_empty(self, engine):
        """Unparseable templates yield an empty set instead of raising."""
        assert engine.extract_variables("{% if %}") == set()

    def test_general_exception_returns_empty(self, engine):
        """Unexpected parser failures also yield an empty set."""
        with patch.object(engine._env, "parse", side_effect=RuntimeError("boom")):
            assert engine.extract_variables("{{ x }}") == set()
# ── extract_tool_usages ───────────────────────────────────────────────────────
@pytest.mark.unit
class TestExtractToolUsages:
    """TemplateEngine.extract_tool_usages(): tool/action discovery in templates."""

    def test_empty_template(self, engine):
        assert engine.extract_tool_usages("") == {}

    def test_none_template(self, engine):
        assert engine.extract_tool_usages(None) == {}

    def test_syntax_error_returns_empty(self, engine):
        """Unparseable templates yield an empty mapping."""
        assert engine.extract_tool_usages("{% if %}") == {}

    def test_getattr_single_tool(self, engine):
        """Dotted access tools.<tool>.<action> is detected."""
        tpl = "{{ tools.memory.notes }}"
        result = engine.extract_tool_usages(tpl)
        assert "memory" in result
        assert "notes" in result["memory"]

    def test_getattr_tool_without_action(self, engine):
        """A bare tool reference records None as its action."""
        tpl = "{{ tools.search }}"
        result = engine.extract_tool_usages(tpl)
        assert "search" in result
        assert None in result["search"]

    def test_getitem_bracket_notation(self, engine):
        """Bracket-style subscription access is detected too."""
        tpl = '{{ tools["calendar"]["events"] }}'
        result = engine.extract_tool_usages(tpl)
        assert "calendar" in result
        assert "events" in result["calendar"]

    def test_multiple_tools(self, engine):
        tpl = "{{ tools.memory.notes }} {{ tools.search.query }}"
        result = engine.extract_tool_usages(tpl)
        assert "memory" in result
        assert "search" in result

    def test_same_tool_multiple_actions(self, engine):
        """One tool may accumulate several actions."""
        tpl = "{{ tools.memory.notes }} {{ tools.memory.tasks }}"
        result = engine.extract_tool_usages(tpl)
        assert "memory" in result
        assert "notes" in result["memory"]
        assert "tasks" in result["memory"]

    def test_non_tools_getattr_ignored(self, engine):
        """Attribute access on names other than 'tools' is ignored."""
        tpl = "{{ data.something }}"
        result = engine.extract_tool_usages(tpl)
        assert result == {}

    def test_general_parse_error_returns_empty(self, engine):
        """Unexpected parser failures yield an empty mapping."""
        with patch.object(engine._env, "parse", side_effect=RuntimeError("boom")):
            assert engine.extract_tool_usages("{{ tools.x }}") == {}

    def test_tools_in_loop(self, engine):
        """Tool references inside for-loops are found."""
        tpl = "{% for item in tools.memory.notes %}{{ item }}{% endfor %}"
        result = engine.extract_tool_usages(tpl)
        assert "memory" in result
        assert "notes" in result["memory"]

    def test_tools_in_conditional(self, engine):
        """Tool references inside conditionals are found."""
        tpl = "{% if tools.search.results %}found{% endif %}"
        result = engine.extract_tool_usages(tpl)
        assert "search" in result
        assert "results" in result["search"]

533
tests/test_utils.py Normal file
View File

@@ -0,0 +1,533 @@
"""Tests for application/utils.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.utils import (
calculate_compression_threshold,
calculate_doc_token_budget,
check_required_fields,
clean_text_for_tts,
convert_pdf_to_images,
get_encoding,
get_field_validation_errors,
get_gpt_model,
get_hash,
get_missing_fields,
generate_image_url,
limit_chat_history,
num_tokens_from_object_or_list,
num_tokens_from_string,
safe_filename,
validate_function_name,
validate_required_fields,
)
class TestGetEncoding:
    """get_encoding(): cached encoding accessor."""

    @pytest.mark.unit
    def test_returns_encoding(self):
        """An encoding object is returned."""
        assert get_encoding() is not None

    @pytest.mark.unit
    def test_returns_same_instance(self):
        """Repeated calls return the same (cached) object."""
        first, second = get_encoding(), get_encoding()
        assert first is second
class TestGetGptModel:
    """get_gpt_model(): explicit LLM name vs provider default."""

    @pytest.mark.unit
    def test_returns_llm_name_when_set(self):
        """An explicit LLM_NAME takes precedence over the provider map."""
        with patch("application.utils.settings") as cfg:
            cfg.LLM_NAME = "my-model"
            cfg.LLM_PROVIDER = "openai"
            assert get_gpt_model() == "my-model"

    @pytest.mark.unit
    def test_falls_back_to_provider_map(self):
        """An empty LLM_NAME falls back to the provider's default model."""
        with patch("application.utils.settings") as cfg:
            cfg.LLM_NAME = ""
            cfg.LLM_PROVIDER = "openai"
            assert get_gpt_model() == "gpt-4o-mini"

    @pytest.mark.unit
    def test_unknown_provider_returns_empty(self):
        """Unknown providers yield an empty model name."""
        with patch("application.utils.settings") as cfg:
            cfg.LLM_NAME = ""
            cfg.LLM_PROVIDER = "unknown"
            assert get_gpt_model() == ""
class TestSafeFilename:
    """safe_filename(): sanitisation with generated-name fallback."""

    @pytest.mark.unit
    def test_normal_filename(self):
        """A plain ASCII filename passes through unchanged."""
        assert safe_filename("test.pdf") == "test.pdf"

    @pytest.mark.unit
    def test_empty_filename_returns_uuid(self):
        """An empty name is replaced with a generated identifier."""
        generated = safe_filename("")
        assert len(generated) > 10  # UUID

    @pytest.mark.unit
    def test_none_filename_returns_uuid(self):
        """None is likewise replaced with a generated identifier."""
        generated = safe_filename(None)
        assert len(generated) > 10

    @pytest.mark.unit
    def test_non_latin_filename(self):
        """Non-Latin names keep their extension after sanitisation."""
        sanitized = safe_filename("документ.pdf")
        assert sanitized.endswith(".pdf")
class TestNumTokens:
    """num_tokens_from_string(): token counting for plain strings."""

    @pytest.mark.unit
    def test_string_token_count(self):
        """Non-empty text has a positive token count."""
        assert num_tokens_from_string("hello world") > 0

    @pytest.mark.unit
    def test_non_string_returns_zero(self):
        """Non-string input counts as zero tokens."""
        assert num_tokens_from_string(123) == 0

    @pytest.mark.unit
    def test_empty_string(self):
        """The empty string has zero tokens."""
        assert num_tokens_from_string("") == 0
class TestNumTokensFromObjectOrList:
    """num_tokens_from_object_or_list(): counting over containers."""

    @pytest.mark.unit
    def test_list(self):
        """A list of strings produces a positive count."""
        assert num_tokens_from_object_or_list(["hello", "world"]) > 0

    @pytest.mark.unit
    def test_dict(self):
        """A dict produces a positive count."""
        assert num_tokens_from_object_or_list({"key": "value"}) > 0

    @pytest.mark.unit
    def test_string(self):
        """A bare string produces a positive count."""
        assert num_tokens_from_object_or_list("hello") > 0

    @pytest.mark.unit
    def test_number_returns_zero(self):
        """A bare number contributes no tokens."""
        assert num_tokens_from_object_or_list(42) == 0

    @pytest.mark.unit
    def test_nested(self):
        """Nested containers are counted recursively."""
        assert num_tokens_from_object_or_list({"a": ["b", "c"]}) > 0
class TestCountTokensDocs:
    """count_tokens_docs(): token counting over document page_content."""

    @pytest.mark.unit
    def test_counts_doc_tokens(self):
        """Multiple docs with text sum to a positive total."""
        from application.utils import count_tokens_docs

        first, second = MagicMock(), MagicMock()
        first.page_content = "hello world"
        second.page_content = " foo bar"
        assert count_tokens_docs([first, second]) > 0
class TestCalculateDocTokenBudget:
    """calculate_doc_token_budget(): context window minus reserved tokens."""

    @pytest.mark.unit
    def test_returns_budget(self):
        """Budget = token limit (128000) - system (500) - history (500)."""
        with patch("application.utils.get_token_limit", return_value=128000), \
                patch("application.utils.settings") as s:
            s.RESERVED_TOKENS = {"system": 500, "history": 500}
            result = calculate_doc_token_budget("gpt-4o")
            assert result == 127000

    @pytest.mark.unit
    def test_minimum_budget(self):
        """1000 - 500 - 500 would be 0, yet 1000 is returned —
        presumably a minimum-budget floor; verify against utils.py."""
        with patch("application.utils.get_token_limit", return_value=1000), \
                patch("application.utils.settings") as s:
            s.RESERVED_TOKENS = {"system": 500, "history": 500}
            result = calculate_doc_token_budget("small-model")
            assert result == 1000
class TestFieldValidation:
    """Field-validation helpers: missing/empty detection and Flask responses.

    The six response-producing tests previously each built their own Flask
    app and context; that boilerplate is factored into `_app_context`.
    """

    @staticmethod
    def _app_context():
        """Fresh Flask app context — the validators build Flask responses,
        which require an active application context."""
        from flask import Flask

        return Flask(__name__).app_context()

    @pytest.mark.unit
    def test_get_missing_fields(self):
        """Only absent keys are reported as missing."""
        assert get_missing_fields({"a": 1}, ["a", "b"]) == ["b"]
        assert get_missing_fields({"a": 1, "b": 2}, ["a", "b"]) == []

    @pytest.mark.unit
    def test_check_required_fields_pass(self):
        """All fields present -> no error response."""
        with self._app_context():
            assert check_required_fields({"a": 1, "b": 2}, ["a", "b"]) is None

    @pytest.mark.unit
    def test_check_required_fields_fail(self):
        """A missing field -> 400 error response."""
        with self._app_context():
            result = check_required_fields({"a": 1}, ["a", "b"])
            assert result is not None
            assert result.status_code == 400

    @pytest.mark.unit
    def test_get_field_validation_errors_none_when_valid(self):
        """Valid data yields no error report."""
        assert get_field_validation_errors({"a": 1}, ["a"]) is None

    @pytest.mark.unit
    def test_get_field_validation_errors_missing(self):
        """Absent keys are listed under missing_fields."""
        result = get_field_validation_errors({}, ["a"])
        assert result["missing_fields"] == ["a"]

    @pytest.mark.unit
    def test_get_field_validation_errors_empty(self):
        """Present-but-empty values are listed under empty_fields."""
        result = get_field_validation_errors({"a": ""}, ["a"])
        assert result["empty_fields"] == ["a"]

    @pytest.mark.unit
    def test_validate_required_fields_pass(self):
        """Present and non-empty -> no error response."""
        with self._app_context():
            assert validate_required_fields({"a": "v"}, ["a"]) is None

    @pytest.mark.unit
    def test_validate_required_fields_missing(self):
        """A missing field -> 400 error response."""
        with self._app_context():
            result = validate_required_fields({}, ["a"])
            assert result is not None
            assert result.status_code == 400

    @pytest.mark.unit
    def test_validate_required_fields_empty(self):
        """An empty value -> error response."""
        with self._app_context():
            assert validate_required_fields({"a": ""}, ["a"]) is not None

    @pytest.mark.unit
    def test_validate_required_fields_both_missing_and_empty(self):
        """Mixed missing and empty fields -> error response."""
        with self._app_context():
            assert validate_required_fields({"a": ""}, ["a", "b"]) is not None
class TestGetHash:
    """get_hash should behave like a 32-char lowercase hex digest."""

    @pytest.mark.unit
    def test_returns_hex_string(self):
        digest = get_hash("test")
        assert len(digest) == 32
        assert set(digest) <= set("0123456789abcdef")

    @pytest.mark.unit
    def test_deterministic(self):
        # Same input must always hash to the same value.
        assert get_hash("hello") == get_hash("hello")

    @pytest.mark.unit
    def test_different_inputs(self):
        assert get_hash("a") != get_hash("b")
class TestLimitChatHistory:
    """limit_chat_history trims conversation history to a token budget."""

    @pytest.mark.unit
    def test_empty_history(self):
        assert limit_chat_history([]) == []

    @pytest.mark.unit
    def test_none_history(self):
        # None is normalized to an empty list rather than raising.
        assert limit_chat_history(None) == []

    @pytest.mark.unit
    def test_keeps_recent_messages(self):
        messages = [
            {"prompt": "q1", "response": "a1"},
            {"prompt": "q2", "response": "a2"},
        ]
        kept = limit_chat_history(messages, max_token_limit=10000)
        assert len(kept) == 2

    @pytest.mark.unit
    def test_trims_old_messages(self):
        # A 10k-char first message cannot fit a 100-token budget, so at
        # most the recent tail survives.
        messages = [
            {"prompt": "x" * 5000, "response": "y" * 5000},
            {"prompt": "q", "response": "a"},
        ]
        kept = limit_chat_history(messages, max_token_limit=100)
        assert len(kept) <= 2

    @pytest.mark.unit
    def test_handles_tool_calls(self):
        # Entries carrying tool_calls payloads must survive trimming intact.
        messages = [
            {
                "prompt": "q",
                "response": "a",
                "tool_calls": [
                    {
                        "tool_name": "t",
                        "action_name": "a",
                        "arguments": "{}",
                        "result": "r",
                    }
                ],
            }
        ]
        kept = limit_chat_history(messages, max_token_limit=10000)
        assert len(kept) == 1
class TestValidateFunctionName:
    """validate_function_name accepts identifier-like names only."""

    @pytest.mark.unit
    def test_valid_names(self):
        # Letters, digits, underscores and hyphens are all acceptable.
        for name in ("hello", "hello_world", "hello-world", "test123"):
            assert validate_function_name(name) is True

    @pytest.mark.unit
    def test_invalid_names(self):
        # Spaces, punctuation and the empty string are rejected.
        for name in ("hello world", "hello!", ""):
            assert validate_function_name(name) is False
class TestGenerateImageUrl:
    """generate_image_url resolves storage paths per the configured strategy."""

    @pytest.mark.unit
    def test_http_url_passthrough(self):
        # Absolute http(s) URLs are returned untouched.
        for url in ("https://example.com/img.png", "http://example.com/img.png"):
            assert generate_image_url(url) == url

    @pytest.mark.unit
    def test_s3_strategy(self):
        with patch("application.utils.settings") as cfg:
            cfg.URL_STRATEGY = "s3"
            cfg.S3_BUCKET_NAME = "my-bucket"
            cfg.SAGEMAKER_REGION = "us-west-2"
            url = generate_image_url("path/to/img.png")
        # Bucket and region are baked into the virtual-hosted S3 URL.
        assert "my-bucket.s3.us-west-2" in url

    @pytest.mark.unit
    def test_backend_strategy(self):
        with patch("application.utils.settings") as cfg:
            cfg.URL_STRATEGY = "backend"
            cfg.API_URL = "http://localhost:7091"
            url = generate_image_url("path/to/img.png")
        # Backend strategy serves images through the API's /api/images route.
        assert url == "http://localhost:7091/api/images/path/to/img.png"
class TestCalculateCompressionThreshold:
    """calculate_compression_threshold is a fraction of the model token limit."""

    @pytest.mark.unit
    def test_default_threshold(self):
        # Default fraction: 100000 -> 80000, i.e. 80% of the limit.
        with patch("application.utils.get_token_limit", return_value=100000):
            assert calculate_compression_threshold("gpt-4o") == 80000

    @pytest.mark.unit
    def test_custom_percentage(self):
        # An explicit fraction overrides the default.
        with patch("application.utils.get_token_limit", return_value=100000):
            assert calculate_compression_threshold("gpt-4o", 0.5) == 50000
class TestConvertPdfToImages:
    """Tests for convert_pdf_to_images.

    convert_pdf_to_images imports pdf2image lazily inside the function body,
    so these tests patch builtins.__import__ to reroute that one import to a
    mock module. Fix over the original: the error-path tests reached for
    `__builtins__.__import__` guarded by hasattr — `__builtins__` is a CPython
    implementation detail (a module here, a dict in other contexts). All tests
    now share one helper built on the `__import__` builtin, matching the
    style already used by test_converts_from_path.
    """

    @staticmethod
    def _patch_pdf2image(mock_module):
        """Return a patch() context serving mock_module for `import pdf2image`.

        Captures the real __import__ before patching so every other import
        keeps working inside the context.
        """
        real_import = __import__

        def fake_import(name, *args, **kwargs):
            if name == "pdf2image":
                return mock_module
            return real_import(name, *args, **kwargs)

        return patch("builtins.__import__", side_effect=fake_import)

    @pytest.mark.unit
    def test_missing_pdf2image_raises(self):
        # Mapping the module to None in sys.modules makes the lazy import
        # raise ImportError; the message should name the missing package.
        with patch.dict("sys.modules", {"pdf2image": None}):
            with pytest.raises(ImportError, match="pdf2image"):
                convert_pdf_to_images("test.pdf")

    @pytest.mark.unit
    def test_converts_from_path(self):
        # Each rendered page is saved as PNG and described by a dict with
        # mime_type and a 1-based page number.
        page = MagicMock()
        page.save = MagicMock(side_effect=lambda buf, format: buf.write(b"PNG_DATA"))
        pdf2image = MagicMock()
        pdf2image.convert_from_path.return_value = [page]
        pdf2image.convert_from_bytes.return_value = [page]
        with self._patch_pdf2image(pdf2image):
            result = convert_pdf_to_images("/some/file.pdf")
        assert len(result) == 1
        assert result[0]["mime_type"] == "image/png"
        assert result[0]["page"] == 1

    @pytest.mark.unit
    def test_with_storage(self):
        # With a storage backend the file is read as bytes (context-managed)
        # and the bytes-based conversion path is used.
        page = MagicMock()
        page.save = MagicMock(side_effect=lambda buf, format: buf.write(b"IMG"))
        storage = MagicMock()
        pdf_file = MagicMock()
        pdf_file.read.return_value = b"pdf_bytes"
        pdf_file.__enter__ = MagicMock(return_value=pdf_file)
        pdf_file.__exit__ = MagicMock(return_value=False)
        storage.get_file.return_value = pdf_file
        pdf2image = MagicMock()
        pdf2image.convert_from_bytes.return_value = [page]
        with self._patch_pdf2image(pdf2image):
            result = convert_pdf_to_images("test.pdf", storage=storage)
        assert len(result) == 1
        pdf2image.convert_from_bytes.assert_called_once()

    @pytest.mark.unit
    def test_file_not_found_raises(self):
        # FileNotFoundError from the converter propagates unchanged.
        pdf2image = MagicMock()
        pdf2image.convert_from_path.side_effect = FileNotFoundError("not found")
        with self._patch_pdf2image(pdf2image):
            with pytest.raises(FileNotFoundError):
                convert_pdf_to_images("/nonexistent.pdf")

    @pytest.mark.unit
    def test_generic_error_raises(self):
        # Arbitrary conversion errors propagate unchanged as well.
        pdf2image = MagicMock()
        pdf2image.convert_from_path.side_effect = RuntimeError("conversion failed")
        with self._patch_pdf2image(pdf2image):
            with pytest.raises(RuntimeError, match="conversion failed"):
                convert_pdf_to_images("/some.pdf")
class TestCleanTextForTts:
    """clean_text_for_tts strips markup so text reads naturally when spoken."""

    @pytest.mark.unit
    def test_removes_code_blocks(self):
        cleaned = clean_text_for_tts("before ```python\ncode\n``` after")
        # Fenced code is replaced by a spoken placeholder, not read verbatim.
        assert "code block" in cleaned
        assert "python" not in cleaned

    @pytest.mark.unit
    def test_removes_mermaid_blocks(self):
        # Mermaid diagrams get their own "flowchart" placeholder.
        assert "flowchart" in clean_text_for_tts("```mermaid\ngraph TD\n```")

    @pytest.mark.unit
    def test_removes_markdown_links(self):
        cleaned = clean_text_for_tts("[click here](https://example.com)")
        assert "click here" in cleaned
        assert "https" not in cleaned

    @pytest.mark.unit
    def test_removes_images(self):
        assert "image.png" not in clean_text_for_tts("![alt text](image.png)")

    @pytest.mark.unit
    def test_removes_inline_code(self):
        cleaned = clean_text_for_tts("use `foo()` here")
        assert "foo()" in cleaned
        assert "`" not in cleaned

    @pytest.mark.unit
    def test_removes_bold_italic(self):
        cleaned = clean_text_for_tts("**bold** and *italic*")
        assert "bold" in cleaned
        assert "italic" in cleaned
        assert "*" not in cleaned

    @pytest.mark.unit
    def test_removes_headers(self):
        cleaned = clean_text_for_tts("# Header\ntext")
        assert "Header" in cleaned
        assert "#" not in cleaned

    @pytest.mark.unit
    def test_removes_blockquotes(self):
        cleaned = clean_text_for_tts("> quoted text")
        assert "quoted text" in cleaned
        assert ">" not in cleaned

    @pytest.mark.unit
    def test_removes_html_tags(self):
        cleaned = clean_text_for_tts("<div>content</div>")
        assert "content" in cleaned
        assert "<" not in cleaned

    @pytest.mark.unit
    def test_removes_arrows(self):
        cleaned = clean_text_for_tts("a --> b <-- c => d")
        for arrow in ("-->", "<--", "=>"):
            assert arrow not in cleaned

    @pytest.mark.unit
    def test_removes_horizontal_rules(self):
        assert "---" not in clean_text_for_tts("text\n---\nmore")

    @pytest.mark.unit
    def test_removes_list_markers(self):
        cleaned = clean_text_for_tts("- item1\n* item2\n1. item3")
        for item in ("item1", "item2", "item3"):
            assert item in cleaned

    @pytest.mark.unit
    def test_normalizes_whitespace(self):
        # Runs of whitespace collapse to single spaces.
        assert "  " not in clean_text_for_tts("  lots   of   spaces  ")

    @pytest.mark.unit
    def test_removes_braces(self):
        cleaned = clean_text_for_tts("{content} and [more]")
        assert "content" in cleaned
        assert "more" in cleaned
        assert "{" not in cleaned

    @pytest.mark.unit
    def test_removes_double_colons(self):
        assert "::" not in clean_text_for_tts("module::function")

1718
tests/test_worker.py Normal file

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,361 @@
from unittest.mock import Mock, patch
import pytest
from application.vectorstore.base import (
BaseVectorStore,
EmbeddingsSingleton,
RemoteEmbeddings,
)
# --- RemoteEmbeddings ---
@pytest.mark.unit
class TestRemoteEmbeddings:
    """Unit tests for RemoteEmbeddings, the HTTP client for a remote
    OpenAI-compatible embeddings endpoint. All network calls go through
    requests.post, which is mocked per test."""

    def test_init_sets_url_and_headers(self):
        # Trailing slash on api_url is stripped; api_key becomes a Bearer header.
        emb = RemoteEmbeddings(
            api_url="http://localhost:8080/", model_name="model-v1", api_key="sk-key"
        )
        assert emb.api_url == "http://localhost:8080"
        assert emb.model_name == "model-v1"
        assert emb.headers["Authorization"] == "Bearer sk-key"

    def test_init_no_api_key(self):
        # Without an api_key no Authorization header is set at all.
        emb = RemoteEmbeddings(api_url="http://host", model_name="m")
        assert "Authorization" not in emb.headers

    @patch("application.vectorstore.base.requests.post")
    def test_embed_sends_correct_payload(self, mock_post):
        # The request body follows the OpenAI embeddings schema:
        # {"input": ..., "model": ...}.
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [{"index": 0, "embedding": [0.1, 0.2]}]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "model-v1")
        result = emb._embed("test input")
        mock_post.assert_called_once()
        call_kwargs = mock_post.call_args
        assert call_kwargs[1]["json"]["input"] == "test input"
        assert call_kwargs[1]["json"]["model"] == "model-v1"
        assert result == [[0.1, 0.2]]

    @patch("application.vectorstore.base.requests.post")
    def test_embed_sorts_by_index(self, mock_post):
        # Server may return items out of order; _embed must sort by "index".
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [
                {"index": 1, "embedding": [0.3, 0.4]},
                {"index": 0, "embedding": [0.1, 0.2]},
            ]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        result = emb._embed(["a", "b"])
        assert result == [[0.1, 0.2], [0.3, 0.4]]

    @patch("application.vectorstore.base.requests.post")
    def test_embed_raises_on_error_response(self, mock_post):
        # An "error" key in the body becomes a ValueError with that message.
        mock_resp = Mock()
        mock_resp.json.return_value = {"error": "rate limit exceeded"}
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        with pytest.raises(ValueError, match="rate limit exceeded"):
            emb._embed("test")

    @patch("application.vectorstore.base.requests.post")
    def test_embed_raises_on_unexpected_format(self, mock_post):
        # A dict without "data" (and without "error") is rejected.
        mock_resp = Mock()
        mock_resp.json.return_value = {"unexpected": True}
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        with pytest.raises(ValueError, match="Unexpected response format"):
            emb._embed("test")

    @patch("application.vectorstore.base.requests.post")
    def test_embed_raises_on_non_dict_response(self, mock_post):
        # A non-dict JSON body (here a list) is also rejected.
        mock_resp = Mock()
        mock_resp.json.return_value = [1, 2, 3]
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        with pytest.raises(ValueError, match="Unexpected response format"):
            emb._embed("test")

    @patch("application.vectorstore.base.requests.post")
    def test_embed_query(self, mock_post):
        # embed_query returns a single vector and records its dimension.
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        emb.dimension = None  # Reset so it gets set from response
        result = emb.embed_query("hello")
        assert result == [0.1, 0.2, 0.3]
        assert emb.dimension == 3

    @patch("application.vectorstore.base.requests.post")
    def test_embed_query_raises_on_bad_structure(self, mock_post):
        mock_resp = Mock()
        # Return multiple embeddings for a single query
        mock_resp.json.return_value = {
            "data": [
                {"index": 0, "embedding": [0.1]},
                {"index": 1, "embedding": [0.2]},
            ]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        with pytest.raises(ValueError, match="Unexpected result structure"):
            emb.embed_query("hello")

    @patch("application.vectorstore.base.requests.post")
    def test_embed_documents(self, mock_post):
        # embed_documents returns one vector per input document.
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [
                {"index": 0, "embedding": [0.1, 0.2]},
                {"index": 1, "embedding": [0.3, 0.4]},
            ]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        emb.dimension = None  # Reset so it gets set from response
        result = emb.embed_documents(["doc1", "doc2"])
        assert result == [[0.1, 0.2], [0.3, 0.4]]
        assert emb.dimension == 2

    def test_embed_documents_empty(self):
        # Empty input short-circuits: no HTTP request is needed.
        emb = RemoteEmbeddings("http://host", "m")
        assert emb.embed_documents([]) == []

    @patch("application.vectorstore.base.requests.post")
    def test_call_with_string(self, mock_post):
        # Calling the instance with a str behaves like embed_query.
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [{"index": 0, "embedding": [0.5]}]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        result = emb("hello")
        assert result == [0.5]

    @patch("application.vectorstore.base.requests.post")
    def test_call_with_list(self, mock_post):
        # Calling the instance with a list behaves like embed_documents.
        mock_resp = Mock()
        mock_resp.json.return_value = {
            "data": [{"index": 0, "embedding": [0.5]}]
        }
        mock_resp.raise_for_status = Mock()
        mock_post.return_value = mock_resp
        emb = RemoteEmbeddings("http://host", "m")
        result = emb(["hello"])
        assert result == [[0.5]]

    def test_call_with_invalid_type(self):
        # Anything other than str/list is rejected with a ValueError.
        emb = RemoteEmbeddings("http://host", "m")
        with pytest.raises(ValueError, match="Input must be a string or a list"):
            emb(123)
# --- EmbeddingsSingleton ---
@pytest.mark.unit
class TestEmbeddingsSingleton:
    """Tests for EmbeddingsSingleton's per-model-name instance cache."""

    def setup_method(self):
        # Clear the class-level cache so each test starts from a clean slate.
        EmbeddingsSingleton._instances = {}

    @patch("application.vectorstore.base.OpenAIEmbeddings")
    def test_get_instance_openai(self, mock_openai_cls):
        # Names prefixed "openai_" are routed to OpenAIEmbeddings.
        mock_instance = Mock()
        mock_openai_cls.return_value = mock_instance
        result = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
        assert result is mock_instance

    @patch("application.vectorstore.base.OpenAIEmbeddings")
    def test_singleton_returns_same_instance(self, mock_openai_cls):
        # Repeated lookups for the same name hit the cache: the underlying
        # class is constructed exactly once.
        mock_instance = Mock()
        mock_openai_cls.return_value = mock_instance
        r1 = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
        r2 = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
        assert r1 is r2
        mock_openai_cls.assert_called_once()

    @patch("application.vectorstore.base._get_embeddings_wrapper")
    def test_get_instance_huggingface(self, mock_get_wrapper):
        # "huggingface_" names go through the local embeddings wrapper.
        mock_wrapper_cls = Mock()
        mock_instance = Mock()
        mock_wrapper_cls.return_value = mock_instance
        mock_get_wrapper.return_value = mock_wrapper_cls
        result = EmbeddingsSingleton.get_instance(
            "huggingface_sentence-transformers/all-mpnet-base-v2"
        )
        assert result is mock_instance

    @patch("application.vectorstore.base._get_embeddings_wrapper")
    def test_get_instance_unknown_falls_back_to_wrapper(self, mock_get_wrapper):
        # Unrecognized names fall back to the wrapper with the raw name.
        mock_wrapper_cls = Mock()
        mock_instance = Mock()
        mock_wrapper_cls.return_value = mock_instance
        mock_get_wrapper.return_value = mock_wrapper_cls
        result = EmbeddingsSingleton.get_instance("custom_model_name")
        mock_wrapper_cls.assert_called_once_with("custom_model_name")
        assert result is mock_instance
# --- BaseVectorStore ---
class ConcreteVectorStore(BaseVectorStore):
    """Concrete implementation for testing base class methods."""

    def search(self, *args, **kwargs):
        # Stub: satisfies the abstract interface; returns no hits.
        return []

    def add_texts(self, texts, metadatas=None, *args, **kwargs):
        # Stub: satisfies the abstract interface; stores nothing.
        return []
@pytest.mark.unit
class TestBaseVectorStore:
    """Tests for BaseVectorStore's concrete helpers via ConcreteVectorStore.

    Note on stacked @patch decorators: they apply bottom-up, so the mock for
    the decorator closest to the def is the first mock argument.
    """

    def setup_method(self):
        # Reset the embeddings cache so no instance leaks between tests.
        EmbeddingsSingleton._instances = {}

    def test_default_methods_are_noop(self):
        # Optional hooks default to no-ops returning None on the base class.
        store = ConcreteVectorStore()
        assert store.delete_index() is None
        assert store.save_local() is None
        assert store.get_chunks() is None
        assert store.add_chunk("text") is None
        assert store.delete_chunk("id") is None

    @patch("application.vectorstore.base.settings")
    def test_is_azure_configured_true(self, mock_settings):
        # Azure requires all three settings to be present.
        mock_settings.OPENAI_API_BASE = "https://azure.openai.com"
        mock_settings.OPENAI_API_VERSION = "2023-05-15"
        mock_settings.AZURE_DEPLOYMENT_NAME = "my-deploy"
        store = ConcreteVectorStore()
        assert store.is_azure_configured()

    @patch("application.vectorstore.base.settings")
    def test_is_azure_configured_false(self, mock_settings):
        mock_settings.OPENAI_API_BASE = None
        mock_settings.OPENAI_API_VERSION = None
        mock_settings.AZURE_DEPLOYMENT_NAME = None
        store = ConcreteVectorStore()
        assert not store.is_azure_configured()

    @patch("application.vectorstore.base.settings")
    def test_get_embeddings_remote(self, mock_settings):
        # EMBEDDINGS_BASE_URL set -> a RemoteEmbeddings client pointed at it.
        mock_settings.EMBEDDINGS_BASE_URL = "http://remote:8080"
        store = ConcreteVectorStore()
        result = store._get_embeddings("model-name", "api-key")
        assert isinstance(result, RemoteEmbeddings)
        assert result.api_url == "http://remote:8080"

    @patch("application.vectorstore.base.settings")
    @patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
    def test_get_embeddings_openai(self, mock_get_instance, mock_settings):
        # No remote URL, no Azure -> plain OpenAI path via the singleton.
        mock_settings.EMBEDDINGS_BASE_URL = None
        mock_settings.OPENAI_API_BASE = None
        mock_settings.OPENAI_API_VERSION = None
        mock_settings.AZURE_DEPLOYMENT_NAME = None
        mock_emb = Mock()
        mock_get_instance.return_value = mock_emb
        store = ConcreteVectorStore()
        result = store._get_embeddings("openai_text-embedding-ada-002", "sk-key")
        assert result is mock_emb

    @patch("application.vectorstore.base.settings")
    @patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
    def test_get_embeddings_openai_azure(self, mock_get_instance, mock_settings):
        # With Azure fully configured, the Azure embeddings deployment is used.
        mock_settings.EMBEDDINGS_BASE_URL = None
        mock_settings.OPENAI_API_BASE = "https://azure.openai.com"
        mock_settings.OPENAI_API_VERSION = "2023-05-15"
        mock_settings.AZURE_DEPLOYMENT_NAME = "deploy"
        mock_settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME = "embed-deploy"
        mock_emb = Mock()
        mock_get_instance.return_value = mock_emb
        store = ConcreteVectorStore()
        result = store._get_embeddings("openai_text-embedding-ada-002", "sk-key")
        assert result is mock_emb

    @patch("application.vectorstore.base.settings")
    @patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
    @patch("os.path.exists", return_value=False)
    def test_get_embeddings_huggingface_no_local_model(
        self, mock_exists, mock_get_instance, mock_settings
    ):
        # No local model directory -> the hub model name is used as-is.
        mock_settings.EMBEDDINGS_BASE_URL = None
        mock_emb = Mock()
        mock_get_instance.return_value = mock_emb
        store = ConcreteVectorStore()
        result = store._get_embeddings(
            "huggingface_sentence-transformers/all-mpnet-base-v2"
        )
        assert result is mock_emb

    @patch("application.vectorstore.base.settings")
    @patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
    @patch("os.path.exists")
    def test_get_embeddings_huggingface_local_model(
        self, mock_exists, mock_get_instance, mock_settings
    ):
        # If the baked-in /app/models/... path exists, it is preferred over
        # downloading the hub model.
        mock_settings.EMBEDDINGS_BASE_URL = None
        mock_exists.side_effect = lambda p: p == "/app/models/all-mpnet-base-v2"
        mock_emb = Mock()
        mock_get_instance.return_value = mock_emb
        store = ConcreteVectorStore()
        result = store._get_embeddings(
            "huggingface_sentence-transformers/all-mpnet-base-v2"
        )
        assert result is mock_emb
        mock_get_instance.assert_called_with("/app/models/all-mpnet-base-v2")

    @patch("application.vectorstore.base.settings")
    @patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
    def test_get_embeddings_generic(self, mock_get_instance, mock_settings):
        # Unknown model names are delegated to the singleton unchanged.
        mock_settings.EMBEDDINGS_BASE_URL = None
        mock_emb = Mock()
        mock_get_instance.return_value = mock_emb
        store = ConcreteVectorStore()
        result = store._get_embeddings("some_custom_embedding")
        assert result is mock_emb
        mock_get_instance.assert_called_with("some_custom_embedding")

View File

@@ -0,0 +1,39 @@
import pytest
from application.vectorstore.document_class import Document
@pytest.mark.unit
class TestDocument:
    """Document subclasses str: it equals its page_content and carries metadata."""

    def test_create_document(self):
        doc = Document(page_content="hello world", metadata={"source": "test"})
        assert doc.page_content == "hello world"
        assert doc.metadata == {"source": "test"}

    def test_document_is_string(self):
        # A Document participates anywhere a plain str is expected.
        doc = Document(page_content="hello world", metadata={})
        assert isinstance(doc, str)
        assert str(doc) == "hello world"

    def test_document_string_equality(self):
        assert Document(page_content="hello", metadata={"k": "v"}) == "hello"

    def test_document_empty_metadata(self):
        assert Document(page_content="text", metadata={}).metadata == {}

    def test_document_empty_content(self):
        doc = Document(page_content="", metadata={"a": 1})
        assert doc.page_content == ""
        assert doc == ""

    def test_document_preserves_complex_metadata(self):
        # Nested metadata is kept by reference, not flattened or copied away.
        meta = {"source": "file.txt", "page": 3, "nested": {"key": "val"}}
        doc = Document(page_content="content", metadata=meta)
        assert doc.metadata["nested"]["key"] == "val"

    def test_document_string_operations(self):
        # Inherited str methods operate on the page content.
        doc = Document(page_content="hello world", metadata={})
        assert doc.upper() == "HELLO WORLD"
        assert doc.split() == ["hello", "world"]
        assert "world" in doc

View File

@@ -0,0 +1,292 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_es_store(source_id="test-source"):
    """Helper to create an ElasticsearchStore with mocked deps.

    Returns (store, mock_es_client, mock_settings).

    NOTE(review): returning from inside the with-block exits both patches
    before the caller runs; the store keeps the mock client it was built
    with, which is all the callers rely on.
    """
    # Reset class-level connection
    from application.vectorstore.elasticsearch import ElasticsearchStore

    ElasticsearchStore._es_connection = None
    with patch(
        "application.vectorstore.elasticsearch.settings"
    ) as mock_settings, patch.dict(
        "sys.modules", {"elasticsearch": MagicMock(), "elasticsearch.helpers": MagicMock()}
    ):
        mock_settings.ELASTIC_URL = "http://localhost:9200"
        mock_settings.ELASTIC_USERNAME = "elastic"
        mock_settings.ELASTIC_PASSWORD = "password"
        mock_settings.ELASTIC_CLOUD_ID = None
        mock_settings.ELASTIC_INDEX = "test_index"
        mock_settings.EMBEDDINGS_NAME = "test_model"
        # Import resolves against the mocked sys.modules entry above.
        import elasticsearch

        mock_es = MagicMock()
        elasticsearch.Elasticsearch.return_value = mock_es
        store = ElasticsearchStore(
            source_id=source_id,
            embeddings_key="key",
            index_name="test_index",
        )
        return store, mock_es, mock_settings
@pytest.mark.unit
class TestElasticsearchStoreInit:
    """Constructor behaviour: source-id normalization, URL vs cloud-id
    connection paths, and class-level connection reuse."""

    def test_source_id_cleaned(self):
        # Path-style prefixes and trailing slashes are stripped from source_id.
        store, _, _ = _make_es_store(source_id="application/indexes/abc123/")
        assert store.source_id == "abc123"

    def test_init_with_url(self):
        store, mock_es, _ = _make_es_store()
        assert store.docsearch is mock_es
        assert store.index_name == "test_index"

    def test_init_with_cloud_id(self):
        # Cloud-id alone (no URL) is a valid connection configuration.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        ElasticsearchStore._es_connection = None
        with patch(
            "application.vectorstore.elasticsearch.settings"
        ) as mock_settings, patch.dict(
            "sys.modules", {"elasticsearch": MagicMock()}
        ):
            mock_settings.ELASTIC_URL = None
            mock_settings.ELASTIC_CLOUD_ID = "my-cloud-id"
            mock_settings.ELASTIC_USERNAME = "user"
            mock_settings.ELASTIC_PASSWORD = "pass"
            mock_settings.ELASTIC_INDEX = "idx"
            mock_settings.EMBEDDINGS_NAME = "model"
            store = ElasticsearchStore(
                source_id="src", embeddings_key="k", index_name="idx"
            )
            assert store.docsearch is not None

    def test_init_no_url_no_cloud_id_raises(self):
        # With neither URL nor cloud-id configured, construction must fail.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        ElasticsearchStore._es_connection = None
        with patch(
            "application.vectorstore.elasticsearch.settings"
        ) as mock_settings, patch.dict(
            "sys.modules", {"elasticsearch": MagicMock()}
        ):
            mock_settings.ELASTIC_URL = None
            mock_settings.ELASTIC_CLOUD_ID = None
            mock_settings.ELASTIC_INDEX = "idx"
            mock_settings.EMBEDDINGS_NAME = "model"
            with pytest.raises(ValueError, match="provide either"):
                ElasticsearchStore(source_id="src", embeddings_key="k")

    def test_reuses_class_connection(self):
        # The ES client is cached at class level: two stores share one client
        # and Elasticsearch() is constructed only once.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        ElasticsearchStore._es_connection = None
        with patch(
            "application.vectorstore.elasticsearch.settings"
        ) as mock_settings, patch.dict(
            "sys.modules", {"elasticsearch": MagicMock()}
        ):
            mock_settings.ELASTIC_URL = "http://localhost:9200"
            mock_settings.ELASTIC_USERNAME = "user"
            mock_settings.ELASTIC_PASSWORD = "pass"
            mock_settings.ELASTIC_CLOUD_ID = None
            mock_settings.ELASTIC_INDEX = "idx"
            mock_settings.EMBEDDINGS_NAME = "model"
            import elasticsearch

            mock_es = MagicMock()
            elasticsearch.Elasticsearch.return_value = mock_es
            store1 = ElasticsearchStore(source_id="src1", embeddings_key="k")
            store2 = ElasticsearchStore(source_id="src2", embeddings_key="k")
            assert store1.docsearch is store2.docsearch
            elasticsearch.Elasticsearch.assert_called_once()
@pytest.mark.unit
class TestElasticsearchStoreSearch:
    """search() should embed the query, run an ES search, and map hits to
    Document-like results."""

    def test_search_builds_query(self):
        store, mock_es, mock_settings = _make_es_store()
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        with patch.object(store, "_get_embeddings", return_value=mock_emb):
            # Minimal ES hit envelope: text + metadata per hit.
            mock_es.search.return_value = {
                "hits": {
                    "hits": [
                        {
                            "_source": {
                                "text": "doc1",
                                "metadata": {"source": "file.txt"},
                            }
                        },
                        {
                            "_source": {
                                "text": "doc2",
                                "metadata": {"source": "file2.txt"},
                            }
                        },
                    ]
                }
            }
            results = store.search("query", k=2)
            assert len(results) == 2
            assert results[0].page_content == "doc1"
            assert results[1].metadata == {"source": "file2.txt"}

    def test_search_empty_results(self):
        # No hits -> empty list, not an error.
        store, mock_es, _ = _make_es_store()
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1])
        with patch.object(store, "_get_embeddings", return_value=mock_emb):
            mock_es.search.return_value = {"hits": {"hits": []}}
            results = store.search("query")
            assert results == []
@pytest.mark.unit
class TestElasticsearchStoreAddTexts:
    """add_texts(): bulk-indexes embedded texts; edge case for empty input."""

    def test_add_texts(self):
        store, mock_es, mock_settings = _make_es_store()
        mock_emb = Mock()
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2], [0.3, 0.4]])
        # helpers.bulk returns (success_count, errors).
        mock_bulk = Mock(return_value=(2, 0))
        mock_helpers = MagicMock()
        mock_helpers.bulk = mock_bulk
        with patch.object(
            store, "_get_embeddings", return_value=mock_emb
        ), patch.object(
            store, "_create_index_if_not_exists"
        ), patch.dict(
            "sys.modules", {"elasticsearch.helpers": mock_helpers}
        ):
            ids = store.add_texts(
                ["text1", "text2"],
                metadatas=[{"a": 1}, {"b": 2}],
            )
            # One generated id per input text.
            assert len(ids) == 2

    def test_add_texts_empty_raises(self):
        """Empty texts causes IndexError because code accesses vectors[0] unconditionally."""
        store, _, _ = _make_es_store()
        mock_emb = Mock()
        mock_emb.embed_documents = Mock(return_value=[])
        with patch.object(store, "_get_embeddings", return_value=mock_emb):
            with pytest.raises(IndexError):
                store.add_texts([], metadatas=[])
@pytest.mark.unit
class TestElasticsearchStoreDeleteIndex:
    """delete_index removes only documents tagged with the store's source id."""

    def test_delete_index_calls_delete_by_query(self):
        store, es_client, _ = _make_es_store(source_id="src1")
        store.delete_index()
        # The deletion is scoped by a keyword match on metadata.source_id.
        es_client.delete_by_query.assert_called_once_with(
            index="test_index",
            query={"match": {"metadata.source_id.keyword": "src1"}},
        )
@pytest.mark.unit
class TestElasticsearchStoreIndex:
    """Index mapping generation and lazy index creation."""

    def test_index_returns_mapping(self):
        store, _, _ = _make_es_store()
        vector_props = store.index(dims_length=768)["mappings"]["properties"]["vector"]
        # Vectors are dense, cosine-similarity, with the requested dimension.
        assert vector_props["type"] == "dense_vector"
        assert vector_props["dims"] == 768
        assert vector_props["similarity"] == "cosine"

    def test_create_index_if_not_exists_existing(self):
        # An existing index must never be recreated.
        store, es_client, _ = _make_es_store()
        es_client.indices.exists.return_value = True
        store._create_index_if_not_exists("test_index", 768)
        es_client.indices.create.assert_not_called()

    def test_create_index_if_not_exists_new(self):
        # A missing index triggers exactly one create call.
        store, es_client, _ = _make_es_store()
        es_client.indices.exists.return_value = False
        store._create_index_if_not_exists("test_index", 768)
        es_client.indices.create.assert_called_once()
@pytest.mark.unit
class TestElasticsearchStoreConnectToElasticsearch:
    """connect_to_elasticsearch: exactly one of es_url / cloud_id is required;
    auth may be basic (username/password) or api_key."""

    def test_connect_with_url(self):
        from application.vectorstore.elasticsearch import ElasticsearchStore

        with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
            import elasticsearch

            mock_es = MagicMock()
            elasticsearch.Elasticsearch.return_value = mock_es
            result = ElasticsearchStore.connect_to_elasticsearch(
                es_url="http://localhost:9200",
                username="user",
                password="pass",
            )
            assert result is mock_es

    def test_connect_with_both_raises(self):
        # Supplying both es_url and cloud_id is ambiguous and rejected.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
            with pytest.raises(ValueError, match="Both es_url and cloud_id"):
                ElasticsearchStore.connect_to_elasticsearch(
                    es_url="http://localhost", cloud_id="cloud-123"
                )

    def test_connect_with_neither_raises(self):
        # Supplying neither is also rejected.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
            with pytest.raises(ValueError, match="provide either"):
                ElasticsearchStore.connect_to_elasticsearch()

    def test_connect_with_api_key(self):
        # api_key auth is accepted in place of username/password.
        from application.vectorstore.elasticsearch import ElasticsearchStore

        with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
            import elasticsearch

            mock_es = MagicMock()
            elasticsearch.Elasticsearch.return_value = mock_es
            result = ElasticsearchStore.connect_to_elasticsearch(
                es_url="http://localhost:9200",
                api_key="my-api-key",
            )
            assert result is mock_es

View File

@@ -0,0 +1,140 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
@pytest.mark.unit
class TestEmbeddingsWrapper:
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_init_success(self, mock_st_cls):
mock_model = MagicMock()
mock_model._first_module.return_value = MagicMock()
mock_model.get_sentence_embedding_dimension.return_value = 768
mock_st_cls.return_value = mock_model
from application.vectorstore.embeddings_local import EmbeddingsWrapper
wrapper = EmbeddingsWrapper("test-model")
mock_st_cls.assert_called_once()
assert wrapper.dimension == 768
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_init_failure(self, mock_st_cls):
mock_st_cls.side_effect = Exception("model not found")
from application.vectorstore.embeddings_local import EmbeddingsWrapper
with pytest.raises(Exception, match="model not found"):
EmbeddingsWrapper("bad-model")
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_init_none_model(self, mock_st_cls):
mock_st_cls.return_value = None
from application.vectorstore.embeddings_local import EmbeddingsWrapper
with pytest.raises((ValueError, AttributeError)):
EmbeddingsWrapper("bad-model")
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_init_null_first_module(self, mock_st_cls):
mock_model = MagicMock()
mock_model._first_module.return_value = None
mock_st_cls.return_value = mock_model
from application.vectorstore.embeddings_local import EmbeddingsWrapper
with pytest.raises(ValueError, match="failed to load properly"):
EmbeddingsWrapper("bad-model")
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_embed_query(self, mock_st_cls):
mock_model = MagicMock()
mock_model._first_module.return_value = MagicMock()
mock_model.get_sentence_embedding_dimension.return_value = 3
mock_model.encode.return_value = MagicMock(tolist=Mock(return_value=[0.1, 0.2, 0.3]))
mock_st_cls.return_value = mock_model
from application.vectorstore.embeddings_local import EmbeddingsWrapper
wrapper = EmbeddingsWrapper("model")
result = wrapper.embed_query("hello world")
mock_model.encode.assert_called_once_with("hello world")
assert result == [0.1, 0.2, 0.3]
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_embed_documents(self, mock_st_cls):
    """embed_documents() encodes the full list and returns nested lists of floats."""
    mock_model = MagicMock()
    mock_model._first_module.return_value = MagicMock()
    mock_model.get_sentence_embedding_dimension.return_value = 3
    mock_model.encode.return_value = MagicMock(
        tolist=Mock(return_value=[[0.1, 0.2], [0.3, 0.4]])
    )
    mock_st_cls.return_value = mock_model
    from application.vectorstore.embeddings_local import EmbeddingsWrapper
    wrapper = EmbeddingsWrapper("model")
    result = wrapper.embed_documents(["doc1", "doc2"])
    mock_model.encode.assert_called_with(["doc1", "doc2"])
    assert result == [[0.1, 0.2], [0.3, 0.4]]
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_call_with_string(self, mock_st_cls):
    """Calling the wrapper with a single string behaves like embed_query()."""
    mock_model = MagicMock()
    mock_model._first_module.return_value = MagicMock()
    mock_model.get_sentence_embedding_dimension.return_value = 3
    mock_model.encode.return_value = MagicMock(tolist=Mock(return_value=[0.1]))
    mock_st_cls.return_value = mock_model
    from application.vectorstore.embeddings_local import EmbeddingsWrapper
    wrapper = EmbeddingsWrapper("model")
    result = wrapper("hello")
    assert result == [0.1]
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_call_with_list(self, mock_st_cls):
    """Calling the wrapper with a list behaves like embed_documents()."""
    mock_model = MagicMock()
    mock_model._first_module.return_value = MagicMock()
    mock_model.get_sentence_embedding_dimension.return_value = 3
    mock_model.encode.return_value = MagicMock(
        tolist=Mock(return_value=[[0.1], [0.2]])
    )
    mock_st_cls.return_value = mock_model
    from application.vectorstore.embeddings_local import EmbeddingsWrapper
    wrapper = EmbeddingsWrapper("model")
    result = wrapper(["a", "b"])
    assert result == [[0.1], [0.2]]
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_call_with_invalid_type(self, mock_st_cls):
    """Non-string, non-list inputs (e.g. an int) raise a descriptive ValueError."""
    mock_model = MagicMock()
    mock_model._first_module.return_value = MagicMock()
    mock_model.get_sentence_embedding_dimension.return_value = 3
    mock_st_cls.return_value = mock_model
    from application.vectorstore.embeddings_local import EmbeddingsWrapper
    wrapper = EmbeddingsWrapper("model")
    with pytest.raises(ValueError, match="Input must be a string or a list"):
        wrapper(123)
@patch("application.vectorstore.embeddings_local.SentenceTransformer")
def test_trust_remote_code_default(self, mock_st_cls):
    """SentenceTransformer is constructed with trust_remote_code=True by default."""
    mock_model = MagicMock()
    mock_model._first_module.return_value = MagicMock()
    mock_model.get_sentence_embedding_dimension.return_value = 768
    mock_st_cls.return_value = mock_model
    from application.vectorstore.embeddings_local import EmbeddingsWrapper
    EmbeddingsWrapper("model")
    # Inspect keyword arguments of the SentenceTransformer(...) call.
    call_kwargs = mock_st_cls.call_args[1]
    assert call_kwargs["trust_remote_code"] is True

View File

@@ -0,0 +1,359 @@
import io
from unittest.mock import Mock, patch
import pytest
@pytest.fixture
def mock_embeddings():
    """Stub embeddings object returning fixed 3-dim vectors for queries and docs."""
    stub = Mock()
    stub.configure_mock(
        embed_query=Mock(return_value=[0.1, 0.2, 0.3]),
        embed_documents=Mock(return_value=[[0.1, 0.2, 0.3]]),
        dimension=3,
    )
    return stub
@pytest.fixture
def mock_storage():
    """Stub storage backend: every file exists and reads back b"fake data"."""
    backend = Mock(
        file_exists=Mock(return_value=True),
        get_file=Mock(return_value=io.BytesIO(b"fake data")),
        save_file=Mock(),
    )
    return backend
@pytest.fixture
def mock_docsearch():
    """Stub FAISS docsearch: 3-dim index plus a two-document docstore.

    NOTE(review): this fixture (like mock_embeddings / mock_storage above)
    does not appear to be referenced by the tests in this file, which build
    their own mocks inline — candidates for removal; verify before deleting.
    """
    ds = Mock()
    ds.similarity_search = Mock(return_value=[])
    ds.add_texts = Mock(return_value=["id1"])
    ds.add_documents = Mock(return_value=["id1"])
    ds.save_local = Mock()
    ds.delete = Mock()
    ds.index = Mock()
    ds.index.d = 3  # index dimensionality, matches mock_embeddings.dimension
    ds.docstore = Mock()
    ds.docstore._dict = {
        "doc1": Mock(page_content="text1", metadata={"source": "a"}),
        "doc2": Mock(page_content="text2", metadata={"source": "b"}),
    }
    return ds
@pytest.mark.unit
class TestFaissStoreInit:
    """Constructor behaviour of FaissStore: building from docs vs. load failure."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_init_with_docs(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
        """Passing docs_init builds the index via FAISS.from_documents."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)  # dimension matches the embeddings mock
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="test", embeddings_key="key", docs_init=[Mock()])
        mock_faiss.from_documents.assert_called_once()
        assert store.docsearch is mock_ds

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_init_missing_index_files(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """Without docs_init, missing index files in storage raise a load error."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_storage = Mock()
        mock_storage.file_exists.return_value = False
        mock_storage_creator.get_storage.return_value = mock_storage
        from application.vectorstore.faiss import FaissStore
        with pytest.raises(Exception, match="Error loading FAISS index"):
            FaissStore(source_id="test", embeddings_key="key")
@pytest.mark.unit
class TestFaissStoreSearch:
    """search() forwards to the underlying docsearch.similarity_search."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_search_delegates_to_docsearch(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """Query string and k are passed through unchanged; results are returned as-is."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_ds.similarity_search.return_value = ["doc1"]
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        result = store.search("query", k=5)
        mock_ds.similarity_search.assert_called_once_with("query", k=5)
        assert result == ["doc1"]
@pytest.mark.unit
class TestFaissStoreAddTexts:
    """add_texts() forwards to docsearch.add_texts and returns its ids."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_add_texts_delegates(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """The ids produced by the underlying store are returned unchanged."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_ds.add_texts.return_value = ["id1", "id2"]
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        result = store.add_texts(["text1", "text2"])
        assert result == ["id1", "id2"]
@pytest.mark.unit
class TestFaissStoreGetChunks:
    """get_chunks() reads documents out of the docstore's internal _dict."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_get_chunks(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
        """Each docstore entry yields one chunk dict carrying its page text."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        doc1 = Mock(page_content="text1", metadata={"source": "a"})
        doc2 = Mock(page_content="text2", metadata={"source": "b"})
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_ds.docstore._dict = {"id1": doc1, "id2": doc2}
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        chunks = store.get_chunks()
        assert len(chunks) == 2
        texts = {c["text"] for c in chunks}
        assert texts == {"text1", "text2"}

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_get_chunks_empty(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
        """An empty docstore yields an empty chunk list."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_ds.docstore._dict = {}
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        assert store.get_chunks() == []
@pytest.mark.unit
class TestFaissStoreSaveLocal:
    """save_local() writes the index to a path and pushes it to storage."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_save_local_with_path(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """docsearch.save_local gets the given path, then _save_to_storage runs."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage = Mock()
        mock_storage_creator.get_storage.return_value = mock_storage
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        # Mock _save_to_storage to avoid file I/O
        store._save_to_storage = Mock(return_value=True)
        with patch("os.makedirs"):
            result = store.save_local(path="/tmp/test_save")
        mock_ds.save_local.assert_called_once_with("/tmp/test_save")
        store._save_to_storage.assert_called_once()
        assert result is True
@pytest.mark.unit
class TestFaissStoreDeleteIndex:
    """delete_index() forwards the id list to docsearch.delete."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_delete_index_delegates(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """The id list is passed through to the underlying delete call."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        store.delete_index(["id1"])
        mock_ds.delete.assert_called_once_with(["id1"])
@pytest.mark.unit
class TestFaissStoreAssertEmbeddingDimensions:
    """Dimension validation between the embeddings object and the FAISS index."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_dimension_mismatch_raises(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """Embeddings dim 768 vs. index dim 512 raises a ValueError at construction."""
        # The huggingface embeddings name triggers the dimension check.
        mock_settings.EMBEDDINGS_NAME = (
            "huggingface_sentence-transformers/all-mpnet-base-v2"
        )
        mock_emb = Mock(dimension=768)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=512)  # Mismatched dimension
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        with pytest.raises(ValueError, match="Embedding dimension mismatch"):
            FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_missing_dimension_attr_raises(
        self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
    ):
        """An embeddings object lacking .dimension raises AttributeError."""
        mock_settings.EMBEDDINGS_NAME = (
            "huggingface_sentence-transformers/all-mpnet-base-v2"
        )
        mock_emb = Mock(spec=[])  # No dimension attribute
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=768)
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage_creator.get_storage.return_value = Mock()
        from application.vectorstore.faiss import FaissStore
        with pytest.raises(AttributeError, match="dimension"):
            FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
@pytest.mark.unit
class TestFaissStoreDeleteChunk:
    """delete_chunk() removes one id and persists the updated index."""

    @patch("application.vectorstore.faiss.StorageCreator")
    @patch("application.vectorstore.faiss.FAISS")
    @patch.object(
        __import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
        "_get_embeddings",
    )
    @patch("application.vectorstore.faiss.settings")
    def test_delete_chunk(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
        """The chunk id is wrapped in a list for docsearch.delete; storage is re-saved."""
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_emb = Mock(dimension=3)
        mock_get_emb.return_value = mock_emb
        mock_ds = Mock()
        mock_ds.index = Mock(d=3)
        mock_faiss.from_documents.return_value = mock_ds
        mock_storage = Mock()
        mock_storage_creator.get_storage.return_value = mock_storage
        from application.vectorstore.faiss import FaissStore
        store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
        store._save_to_storage = Mock(return_value=True)  # avoid file I/O
        result = store.delete_chunk("chunk_id")
        mock_ds.delete.assert_called_once_with(["chunk_id"])
        store._save_to_storage.assert_called_once()
        assert result is True
@pytest.mark.unit
class TestGetVectorstore:
    """Path construction performed by get_vectorstore()."""

    def test_with_path(self):
        """A non-empty id is appended under the indexes/ prefix."""
        from application.vectorstore.faiss import get_vectorstore

        assert get_vectorstore("abc123") == "indexes/abc123"

    def test_without_path(self):
        """Falsy ids ('' and None) both collapse to the bare indexes root."""
        from application.vectorstore.faiss import get_vectorstore

        for falsy_id in ("", None):
            assert get_vectorstore(falsy_id) == "indexes"

View File

@@ -0,0 +1,312 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_lancedb_store(source_id="test-source"):
    """Helper to create a LanceDBVectorStore with mocked deps.

    Returns (store, mock_settings); the settings patch is no longer active
    once this helper returns, so tests only read the returned objects.
    """
    with patch(
        "application.vectorstore.lancedb.settings"
    ) as mock_settings:
        mock_settings.LANCEDB_PATH = "/tmp/lancedb"
        mock_settings.LANCEDB_TABLE_NAME = "docs"
        mock_settings.EMBEDDINGS_NAME = "test_model"
        from application.vectorstore.lancedb import LanceDBVectorStore
        store = LanceDBVectorStore(
            path="/tmp/lancedb",
            table_name_prefix="docs",
            source_id=source_id,
            embeddings_key="key",
        )
        return store, mock_settings
@pytest.mark.unit
class TestLanceDBVectorStoreInit:
    """Table naming and default state after construction."""

    def test_table_name_with_source_id(self):
        """source_id is suffixed onto the table-name prefix."""
        store, _ = _make_lancedb_store(source_id="src1")
        assert store.table_name == "docs_src1"

    def test_table_name_without_source_id(self):
        """Without a source_id, the bare prefix is used as the table name."""
        with patch("application.vectorstore.lancedb.settings") as mock_settings:
            mock_settings.LANCEDB_PATH = "/tmp"
            mock_settings.LANCEDB_TABLE_NAME = "docs"
            from application.vectorstore.lancedb import LanceDBVectorStore
            store = LanceDBVectorStore(
                path="/tmp", table_name_prefix="docs", source_id=None
            )
            assert store.table_name == "docs"

    def test_init_defaults(self):
        """Connection and docsearch are lazily initialised to None."""
        store, _ = _make_lancedb_store()
        assert store.path == "/tmp/lancedb"
        assert store._lance_db is None
        assert store.docsearch is None
@pytest.mark.unit
class TestLanceDBVectorStoreLazyLoading:
    """Lazy import of pyarrow/lancedb and lazy connection/table opening."""

    def test_pa_lazy_load(self):
        """First access to .pa imports pyarrow via importlib."""
        store, _ = _make_lancedb_store()
        mock_pa = MagicMock()
        with patch("importlib.import_module", return_value=mock_pa) as mock_import:
            result = store.pa
            mock_import.assert_called_with("pyarrow")
            assert result is mock_pa

    def test_pa_cached(self):
        """A pre-set _pa is returned without importing again."""
        store, _ = _make_lancedb_store()
        mock_pa = MagicMock()
        store._pa = mock_pa
        assert store.pa is mock_pa

    def test_lancedb_lazy_load(self):
        """First access to .lancedb imports the lancedb module."""
        store, _ = _make_lancedb_store()
        mock_ldb = MagicMock()
        with patch("importlib.import_module", return_value=mock_ldb) as mock_import:
            result = store.lancedb
            mock_import.assert_called_with("lancedb")
            assert result is mock_ldb

    def test_lance_db_connection(self):
        """.lance_db connects to the configured path exactly once."""
        store, _ = _make_lancedb_store()
        mock_ldb_module = MagicMock()
        mock_conn = MagicMock()
        mock_ldb_module.connect.return_value = mock_conn
        store._lancedb_module = mock_ldb_module
        result = store.lance_db
        mock_ldb_module.connect.assert_called_once_with("/tmp/lancedb")
        assert result is mock_conn

    def test_lance_db_cached(self):
        """A pre-set _lance_db connection is reused."""
        store, _ = _make_lancedb_store()
        mock_conn = MagicMock()
        store._lance_db = mock_conn
        assert store.lance_db is mock_conn

    def test_table_opens_existing(self):
        """.table opens the table when its name is listed on the connection."""
        store, _ = _make_lancedb_store()
        mock_conn = MagicMock()
        mock_table = MagicMock()
        mock_conn.table_names.return_value = [store.table_name]
        mock_conn.open_table.return_value = mock_table
        store._lance_db = mock_conn
        result = store.table
        mock_conn.open_table.assert_called_once_with(store.table_name)
        assert result is mock_table

    def test_table_returns_none_for_missing(self):
        """.table is None when the table does not exist yet."""
        store, _ = _make_lancedb_store()
        mock_conn = MagicMock()
        mock_conn.table_names.return_value = []
        store._lance_db = mock_conn
        result = store.table
        assert result is None
@pytest.mark.unit
class TestLanceDBVectorStoreEnsureTableExists:
    """ensure_table_exists() creates the table only when it is missing."""

    def test_creates_table_when_missing(self):
        """With no existing table, create_table is invoked on the connection."""
        store, _ = _make_lancedb_store()
        mock_conn = MagicMock()
        mock_conn.table_names.return_value = []
        store._lance_db = mock_conn
        mock_emb = MagicMock()
        mock_emb.dimension = 768  # used to size the vector column schema
        mock_pa = MagicMock()
        store._pa = mock_pa
        with patch.object(store, "_get_embeddings", return_value=mock_emb):
            store.ensure_table_exists()
        mock_conn.create_table.assert_called_once()

    def test_noop_when_table_exists(self):
        """With an existing table and docsearch set, nothing is created."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_conn = MagicMock()
        mock_conn.table_names.return_value = [store.table_name]
        mock_conn.open_table.return_value = mock_table
        store._lance_db = mock_conn
        store.ensure_table_exists()
        mock_conn.create_table.assert_not_called()
@pytest.mark.unit
class TestLanceDBVectorStoreAddTexts:
    """add_texts() embeds the texts and writes row dicts via table.add()."""

    def test_add_texts(self):
        """Each text becomes one row carrying its text and embedding vector."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_emb = MagicMock()
        mock_emb.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
        with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
            store, "ensure_table_exists"
        ):
            store.add_texts(
                ["text1", "text2"],
                metadatas=[{"a": "1"}, {"b": "2"}],
                source_id="src1",
            )
        mock_table.add.assert_called_once()
        vectors = mock_table.add.call_args[0][0]
        assert len(vectors) == 2
        assert vectors[0]["text"] == "text1"
        assert vectors[0]["vector"] == [0.1, 0.2]

    def test_add_texts_with_source_id_in_metadata(self):
        """An explicit source_id is injected into each row's metadata entries."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_emb = MagicMock()
        mock_emb.embed_documents.return_value = [[0.1]]
        with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
            store, "ensure_table_exists"
        ):
            store.add_texts(["text1"], metadatas=[{"k": "v"}], source_id="src1")
        vectors = mock_table.add.call_args[0][0]
        # metadata is stored as a list of {"key": ..., "value": ...}-style entries
        metadata_keys = [m["key"] for m in vectors[0]["metadata"]]
        assert "source_id" in metadata_keys

    def test_add_texts_default_metadata(self):
        """Omitting metadatas/source_id still results in a single add() call."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_emb = MagicMock()
        mock_emb.embed_documents.return_value = [[0.1]]
        with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
            store, "ensure_table_exists"
        ):
            store.add_texts(["text1"])
        mock_table.add.assert_called_once()
@pytest.mark.unit
class TestLanceDBVectorStoreSearch:
    """search() embeds the query and limits the LanceDB result set to k."""

    def test_search(self):
        """Results are (distance-ish, text, ...) tuples; limit(k) is honoured."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_emb = MagicMock()
        mock_emb.embed_query.return_value = [0.1, 0.2]
        mock_result = MagicMock()
        mock_result.limit.return_value.to_list.return_value = [
            {"_distance": 0.1, "text": "result1", "metadata": {"k": "v"}},
        ]
        mock_table.search.return_value = mock_result
        with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
            store, "ensure_table_exists"
        ):
            results = store.search("query", k=3)
        assert len(results) == 1
        assert results[0][1] == "result1"
        mock_result.limit.assert_called_once_with(3)
@pytest.mark.unit
class TestLanceDBVectorStoreDeleteIndex:
    """delete_index() drops the table only when it exists."""

    def test_delete_index_drops_table(self):
        """An existing table is dropped by name on the connection."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        mock_conn = MagicMock()
        mock_conn.table_names.return_value = [store.table_name]
        mock_conn.open_table.return_value = mock_table
        store._lance_db = mock_conn
        store.delete_index()
        mock_conn.drop_table.assert_called_once_with(store.table_name)

    def test_delete_index_noop_when_no_table(self):
        """With no table present, drop_table is never called."""
        store, _ = _make_lancedb_store()
        store.docsearch = None
        mock_conn = MagicMock()
        mock_conn.table_names.return_value = []
        store._lance_db = mock_conn
        store.delete_index()
        mock_conn.drop_table.assert_not_called()
@pytest.mark.unit
class TestLanceDBVectorStoreAssertEmbeddingDimensions:
    """Vector-column width vs. embeddings dimension validation."""

    def test_matching_dimensions(self):
        """Equal widths (768 == 768) pass without raising."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        mock_table.schema = {"vector": MagicMock()}
        # len() of the vector column's value type stands in for its width.
        mock_table.schema["vector"].type.value_type.__len__ = Mock(return_value=768)
        store.docsearch = mock_table
        mock_emb = MagicMock()
        mock_emb.dimension = 768
        # Should not raise
        store.assert_embedding_dimensions(mock_emb)

    def test_mismatched_dimensions_raises(self):
        """A 512-wide column against 768-dim embeddings raises ValueError."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        type_mock = MagicMock()
        type_mock.__len__ = Mock(return_value=512)
        mock_table.schema.__getitem__ = Mock(return_value=MagicMock())
        mock_table.schema["vector"].type.value_type = type_mock
        mock_emb = MagicMock()
        mock_emb.dimension = 768
        with pytest.raises(ValueError, match="Embedding dimension mismatch"):
            store.assert_embedding_dimensions(mock_emb)
@pytest.mark.unit
class TestLanceDBVectorStoreFilterDocuments:
    """filter_documents() requires a source_id key in the filter dict."""

    def test_filter_documents(self):
        """A source_id filter is forwarded to table.filter and rows are returned."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        mock_table.filter.return_value.to_list.return_value = [{"text": "filtered"}]
        store.docsearch = mock_table
        with patch.object(store, "ensure_table_exists"):
            results = store.filter_documents({"source_id": "src1"})
        assert len(results) == 1

    def test_filter_documents_requires_source_id(self):
        """Filters missing 'source_id' raise a descriptive ValueError."""
        store, _ = _make_lancedb_store()
        mock_table = MagicMock()
        store.docsearch = mock_table
        with patch.object(store, "ensure_table_exists"):
            with pytest.raises(ValueError, match="must contain 'source_id'"):
                store.filter_documents({"other_key": "value"})

View File

@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID
import pytest
def _make_milvus_store(source_id="test-source"):
    """Helper to create a MilvusStore with mocked deps.

    langchain_milvus is replaced in sys.modules, so the Milvus class the
    store instantiates is a MagicMock; the mocked docsearch is also pinned
    onto store._docsearch directly. Returns (store, mock_docsearch).
    """
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.milvus.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "langchain_milvus": MagicMock(),
        },
    ):
        mock_emb = Mock()
        mock_get_emb.return_value = mock_emb
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.MILVUS_URI = "http://localhost:19530"
        mock_settings.MILVUS_TOKEN = "token"
        mock_settings.MILVUS_COLLECTION_NAME = "test_collection"
        # Resolves to an attribute of the mocked module inserted above.
        from langchain_milvus import Milvus
        mock_docsearch = MagicMock()
        Milvus.return_value = mock_docsearch
        from application.vectorstore.milvus import MilvusStore
        store = MilvusStore(source_id=source_id, embeddings_key="key")
        store._docsearch = mock_docsearch
        return store, mock_docsearch
@pytest.mark.unit
class TestMilvusStoreInit:
    """Construction records the source_id on the instance."""

    def test_source_id_stored(self):
        """The source_id given to the constructor ends up in _source_id."""
        milvus_store, _unused = _make_milvus_store(source_id="src1")
        assert milvus_store._source_id == "src1"
@pytest.mark.unit
class TestMilvusStoreSearch:
    """search() scopes similarity search to the store's source_id via expr."""

    def test_search(self):
        """query/k pass through; expr filters on source_id; results are returned."""
        store, mock_ds = _make_milvus_store(source_id="src1")
        mock_ds.similarity_search.return_value = ["doc1", "doc2"]
        results = store.search("query", k=3)
        mock_ds.similarity_search.assert_called_once()
        call_kwargs = mock_ds.similarity_search.call_args
        assert call_kwargs[1]["query"] == "query"
        assert call_kwargs[1]["k"] == 3
        assert call_kwargs[1]["expr"] == "source_id == 'src1'"
        assert results == ["doc1", "doc2"]
@pytest.mark.unit
class TestMilvusStoreAddTexts:
    """add_texts() forwards texts with freshly generated UUID ids."""

    def test_add_texts(self):
        """One UUID id per text is generated; underlying ids are returned."""
        store, mock_ds = _make_milvus_store()
        mock_ds.add_texts.return_value = ["id1", "id2"]
        result = store.add_texts(
            ["text1", "text2"], metadatas=[{"a": 1}, {"b": 2}]
        )
        mock_ds.add_texts.assert_called_once()
        call_kwargs = mock_ds.add_texts.call_args
        assert call_kwargs[1]["texts"] == ["text1", "text2"]
        # ids should be UUIDs
        ids = call_kwargs[1]["ids"]
        assert len(ids) == 2
        for uid in ids:
            UUID(uid)  # Validates it's a valid UUID
        assert result == ["id1", "id2"]
@pytest.mark.unit
class TestMilvusStoreSaveLocal:
    """save_local() is a no-op for the Milvus-backed store."""

    def test_save_local_is_noop(self):
        """Calling save_local() returns None and has no observable effect."""
        milvus_store, _unused = _make_milvus_store()
        outcome = milvus_store.save_local()
        assert outcome is None
@pytest.mark.unit
class TestMilvusStoreDeleteIndex:
    """delete_index() is a no-op for the Milvus-backed store."""

    def test_delete_index_is_noop(self):
        """Calling delete_index() returns None and has no observable effect."""
        milvus_store, _unused = _make_milvus_store()
        outcome = milvus_store.delete_index()
        assert outcome is None

View File

@@ -0,0 +1,259 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_mongodb_store(source_id="test-source"):
    """Helper to create a MongoDBVectorStore with all external deps mocked.

    pymongo is replaced in sys.modules and the store's _collection is
    swapped for a MagicMock. Returns (store, mock_collection, mock_emb).
    """
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.mongodb.settings"
    ) as mock_settings, patch.dict(
        "sys.modules", {"pymongo": MagicMock()}
    ):
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
        mock_get_emb.return_value = mock_emb
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.MONGO_URI = "mongodb://localhost:27017"
        from application.vectorstore.mongodb import MongoDBVectorStore
        store = MongoDBVectorStore(
            source_id=source_id,
            embeddings_key="key",
            collection="test_docs",
            database="test_db",
        )
        mock_collection = MagicMock()
        store._collection = mock_collection
        return store, mock_collection, mock_emb
@pytest.mark.unit
class TestMongoDBVectorStoreInit:
    """Constructor normalises path-like source ids."""

    def test_source_id_cleaned(self):
        """Leading path components and a trailing slash are stripped from source_id."""
        mongo_store, _coll, _emb = _make_mongodb_store(
            source_id="application/indexes/abc123/"
        )
        assert mongo_store._source_id == "abc123"
@pytest.mark.unit
class TestMongoDBVectorStoreSearch:
    """search() builds a $vectorSearch aggregation pipeline."""

    def test_search_builds_pipeline(self):
        """limit is k, numCandidates is 10*k, and documents come back as results."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        doc1 = {
            "_id": "id1",
            "text": "hello world",
            "embedding": [0.1, 0.2],
            "source": "test",
        }
        mock_collection.aggregate.return_value = iter([doc1])
        results = store.search("query", k=3)
        mock_emb.embed_query.assert_called_once_with("query")
        mock_collection.aggregate.assert_called_once()
        pipeline = mock_collection.aggregate.call_args[0][0]
        assert pipeline[0]["$vectorSearch"]["limit"] == 3
        assert pipeline[0]["$vectorSearch"]["numCandidates"] == 30
        assert len(results) == 1
        # Result objects stringify to their text content.
        assert str(results[0]) == "hello world"

    def test_search_removes_id_text_embedding_from_metadata(self):
        """Internal fields are stripped from metadata; custom keys survive."""
        store, mock_collection, _ = _make_mongodb_store()
        doc = {
            "_id": "id1",
            "text": "content",
            "embedding": [0.1],
            "custom_key": "custom_val",
        }
        mock_collection.aggregate.return_value = iter([doc])
        results = store.search("q", k=1)
        metadata = results[0].metadata
        assert "_id" not in metadata
        assert "text" not in metadata
        assert "embedding" not in metadata
        assert metadata["custom_key"] == "custom_val"
@pytest.mark.unit
class TestMongoDBVectorStoreAddTexts:
    """add_texts() inserts in batches of 100 documents."""

    def test_add_texts_batches(self):
        """150 texts produce two insert_many calls (100 + 50)."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        # Generate 150 texts to trigger batching at 100
        texts = [f"text_{i}" for i in range(150)]
        metadatas = [{"i": i} for i in range(150)]
        mock_emb.embed_documents.return_value = [[0.1]] * 100  # per batch
        mock_collection.insert_many.return_value = Mock(
            inserted_ids=list(range(100))
        )
        store.add_texts(texts, metadatas)
        # Should have been called twice: batch of 100, then batch of 50
        assert mock_collection.insert_many.call_count == 2

    def test_add_texts_default_metadata(self):
        """A single text without metadata still results in one insert_many."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_collection.insert_many.return_value = Mock(inserted_ids=["id1"])
        store.add_texts(["text1"])
        mock_collection.insert_many.assert_called_once()

    def test_add_texts_empty(self):
        """Empty input short-circuits to an empty id list."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = []
        mock_collection.insert_many.return_value = Mock(inserted_ids=[])
        result = store.add_texts([], [])
        # _insert_texts returns [] for empty input
        assert result == []
@pytest.mark.unit
class TestMongoDBVectorStoreInsertTexts:
    """_insert_texts() document construction and empty-input handling."""

    def test_insert_texts_empty_returns_empty(self):
        """Empty texts/metadatas return [] without touching the collection."""
        store, _, _ = _make_mongodb_store()
        result = store._insert_texts([], [])
        assert result == []

    def test_insert_texts_builds_correct_documents(self):
        """Each document carries text, embedding, and flattened metadata keys."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_collection.insert_many.return_value = Mock(inserted_ids=["id1"])
        store._insert_texts(["hello"], [{"source": "test"}])
        inserted_docs = mock_collection.insert_many.call_args[0][0]
        assert len(inserted_docs) == 1
        assert inserted_docs[0]["text"] == "hello"
        assert inserted_docs[0]["embedding"] == [0.1, 0.2]
        assert inserted_docs[0]["source"] == "test"
@pytest.mark.unit
class TestMongoDBVectorStoreDeleteIndex:
    """delete_index() removes every document tagged with the store's source_id."""

    def test_delete_index_calls_delete_many(self):
        """delete_many is invoked exactly once with a source_id filter."""
        mongo_store, collection_mock, _emb = _make_mongodb_store(source_id="src1")
        mongo_store.delete_index()
        collection_mock.delete_many.assert_called_once_with({"source_id": "src1"})
@pytest.mark.unit
class TestMongoDBVectorStoreGetChunks:
    """get_chunks() shapes collection documents into chunk dicts."""

    def test_get_chunks(self):
        """Internal fields (_id/text/embedding/source_id) are excluded from metadata."""
        store, mock_collection, _ = _make_mongodb_store()
        docs = [
            {
                "_id": "id1",
                "text": "chunk1",
                "embedding": [0.1],
                "source_id": "src",
                "extra": "val",
            },
            {
                "_id": "id2",
                "text": "chunk2",
                "embedding": [0.2],
                "source_id": "src",
            },
        ]
        mock_collection.find.return_value = iter(docs)
        chunks = store.get_chunks()
        assert len(chunks) == 2
        assert chunks[0]["doc_id"] == "id1"
        assert chunks[0]["text"] == "chunk1"
        assert chunks[0]["metadata"] == {"extra": "val"}
        assert "embedding" not in chunks[0]["metadata"]
        assert "source_id" not in chunks[0]["metadata"]

    def test_get_chunks_skips_empty_text(self):
        """Documents with falsy text are skipped entirely."""
        store, mock_collection, _ = _make_mongodb_store()
        docs = [
            {"_id": "id1", "text": None, "embedding": [0.1], "source_id": "src"},
        ]
        mock_collection.find.return_value = iter(docs)
        chunks = store.get_chunks()
        assert len(chunks) == 0

    def test_get_chunks_returns_empty_on_error(self):
        """Collection errors are swallowed and an empty list is returned."""
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.find.side_effect = Exception("connection error")
        assert store.get_chunks() == []
@pytest.mark.unit
class TestMongoDBVectorStoreAddChunk:
    """add_chunk() inserts a single embedded document."""

    def test_add_chunk(self):
        """The inserted doc carries text, the store's source_id, and metadata keys."""
        store, mock_collection, mock_emb = _make_mongodb_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_collection.insert_one.return_value = Mock(inserted_id="new_id")
        result = store.add_chunk("hello chunk", metadata={"key": "val"})
        assert result == "new_id"
        inserted = mock_collection.insert_one.call_args[0][0]
        assert inserted["text"] == "hello chunk"
        assert inserted["source_id"] == "src1"
        assert inserted["key"] == "val"

    def test_add_chunk_default_metadata(self):
        """Without metadata, the document still carries a source_id."""
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_collection.insert_one.return_value = Mock(inserted_id="id")
        store.add_chunk("text")
        inserted = mock_collection.insert_one.call_args[0][0]
        assert "source_id" in inserted

    def test_add_chunk_raises_on_empty_embedding(self):
        """An empty embedding result raises a descriptive ValueError."""
        store, _, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = []
        with pytest.raises(ValueError, match="Could not generate embedding"):
            store.add_chunk("text")
@pytest.mark.unit
class TestMongoDBVectorStoreDeleteChunk:
    """delete_chunk(): ObjectId-based deletion and error fallback.

    Fix: the success test re-imported MagicMock under the alias MM inside the
    test body, shadowing the module-level import for no reason; it now uses
    the module-level MagicMock directly and flattens the needlessly nested
    ``with`` blocks (same patches, same entry order).
    """

    def test_delete_chunk_success(self):
        """A delete_one reporting deleted_count=1 makes delete_chunk return True."""
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.delete_one.return_value = Mock(deleted_count=1)
        # bson is not installed in the test environment, so both the module
        # and the ObjectId symbol (possibly re-exported by the store module,
        # hence create=True) are mocked before delete_chunk runs.
        mock_oid = MagicMock()
        with patch("application.vectorstore.mongodb.ObjectId", create=True), patch.dict(
            "sys.modules", {"bson": MagicMock(), "bson.objectid": MagicMock()}
        ), patch("bson.objectid.ObjectId", return_value=mock_oid):
            result = store.delete_chunk("507f1f77bcf86cd799439011")
        assert result is True

    def test_delete_chunk_returns_false_on_error(self):
        """Any exception from delete_one is swallowed and reported as False."""
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.delete_one.side_effect = Exception("fail")
        result = store.delete_chunk("bad_id")
        assert result is False

View File

@@ -0,0 +1,297 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_store(
    source_id="test-source",
    embeddings_key="key",
    connection_string="postgresql://user:pass@localhost/db",
):
    """Helper to create a PGVectorStore with all external deps mocked.

    Returns a ``(store, mock_conn, mock_cursor, mock_emb)`` tuple so tests
    can inspect the fake connection, cursor and embeddings directly.
    """
    # All patches must be active while the module is imported and the store
    # is constructed: psycopg2/pgvector may not be installed in the test env.
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.pgvector.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "psycopg2": MagicMock(),
            "psycopg2.extras": MagicMock(),
            "pgvector": MagicMock(),
            "pgvector.psycopg2": MagicMock(),
        },
    ):
        # Embeddings mock: fixed 3-element vectors; reported dimension 768.
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
        mock_emb.dimension = 768
        mock_get_emb.return_value = mock_emb
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.PGVECTOR_CONNECTION_STRING = connection_string
        # Imported here so the sys.modules stubs above are in effect.
        from application.vectorstore.pgvector import PGVectorStore
        # Patch _ensure_table_exists to avoid DB calls during init
        with patch.object(PGVectorStore, "_ensure_table_exists"):
            store = PGVectorStore(
                source_id=source_id,
                embeddings_key=embeddings_key,
                connection_string=connection_string,
            )
        # Provide a mock connection (marked open via ``closed = False``).
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_conn.closed = False
        store._connection = mock_conn
        return store, mock_conn, mock_cursor, mock_emb
@pytest.mark.unit
class TestPGVectorStoreInit:
    """Constructor behaviour of PGVectorStore."""

    def test_source_id_cleaned(self):
        # Path-style prefixes and trailing slashes are stripped from the id.
        store, _, _, _ = _make_store(source_id="application/indexes/abc123/")
        assert store._source_id == "abc123"

    def test_missing_connection_string_raises(self):
        # Mirrors the _make_store mocking, but with a None connection string
        # so the constructor's validation path is exercised.
        with patch(
            "application.vectorstore.base.BaseVectorStore._get_embeddings"
        ) as mock_get_emb, patch(
            "application.vectorstore.pgvector.settings"
        ) as mock_settings, patch.dict(
            "sys.modules",
            {
                "psycopg2": MagicMock(),
                "psycopg2.extras": MagicMock(),
                "pgvector": MagicMock(),
                "pgvector.psycopg2": MagicMock(),
            },
        ):
            mock_get_emb.return_value = Mock(dimension=768)
            mock_settings.EMBEDDINGS_NAME = "test_model"
            # No connection string configured anywhere -> must raise.
            mock_settings.PGVECTOR_CONNECTION_STRING = None
            from application.vectorstore.pgvector import PGVectorStore
            with pytest.raises(ValueError, match="connection string is required"):
                PGVectorStore(
                    source_id="test", embeddings_key="key", connection_string=None
                )
@pytest.mark.unit
class TestPGVectorStoreSearch:
    """Tests for PGVectorStore.search."""

    def test_search_returns_documents(self):
        """Cursor rows become documents with page_content and metadata."""
        store, _, cursor, embeddings = _make_store()
        cursor.fetchall.return_value = [
            ("hello world", {"source": "test.txt"}, 0.1),
            ("foo bar", {"source": "test2.txt"}, 0.2),
        ]
        documents = store.search("query", k=2)
        embeddings.embed_query.assert_called_once_with("query")
        assert len(documents) == 2
        first = documents[0]
        assert first.page_content == "hello world"
        assert first.metadata == {"source": "test.txt"}

    def test_search_returns_empty_on_error(self):
        """Database failures are swallowed; an empty list comes back."""
        store, _, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("connection lost")
        assert store.search("query") == []

    def test_search_handles_null_metadata(self):
        """A NULL metadata column maps to an empty dict."""
        store, _, cursor, _ = _make_store()
        cursor.fetchall.return_value = [("text", None, 0.5)]
        documents = store.search("query")
        assert len(documents) == 1
        assert documents[0].metadata == {}
@pytest.mark.unit
class TestPGVectorStoreAddTexts:
    """Tests for PGVectorStore.add_texts."""

    def test_add_texts_inserts_and_returns_ids(self):
        """Each text gets its own INSERT; returned ids are stringified."""
        store, connection, cursor, embeddings = _make_store()
        embeddings.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
        cursor.fetchone.side_effect = [(1,), (2,)]
        returned = store.add_texts(["text1", "text2"], [{"a": 1}, {"b": 2}])
        assert returned == ["1", "2"]
        assert cursor.execute.call_count == 2
        connection.commit.assert_called_once()

    def test_add_texts_empty_returns_empty(self):
        """No texts means no work and an empty id list."""
        store, _, _, _ = _make_store()
        assert store.add_texts([]) == []

    def test_add_texts_default_metadatas(self):
        """Metadata is optional; inserts still succeed without it."""
        store, connection, cursor, embeddings = _make_store()
        embeddings.embed_documents.return_value = [[0.1, 0.2]]
        cursor.fetchone.return_value = (1,)
        assert store.add_texts(["text1"]) == ["1"]

    def test_add_texts_rolls_back_on_error(self):
        """An INSERT failure propagates and triggers a rollback."""
        store, connection, cursor, embeddings = _make_store()
        embeddings.embed_documents.return_value = [[0.1]]
        cursor.execute.side_effect = Exception("insert failed")
        with pytest.raises(Exception, match="insert failed"):
            store.add_texts(["text1"])
        connection.rollback.assert_called_once()
@pytest.mark.unit
class TestPGVectorStoreDeleteIndex:
    """Tests for PGVectorStore.delete_index."""

    def test_delete_index_deletes_by_source_id(self):
        """delete_index issues one DELETE filtered by the source id."""
        store, connection, cursor, _ = _make_store(source_id="src123")
        store.delete_index()
        cursor.execute.assert_called_once()
        statement = cursor.execute.call_args[0][0]
        params = cursor.execute.call_args[0][1]
        assert "DELETE FROM" in statement
        assert params == ("src123",)
        connection.commit.assert_called_once()

    def test_delete_index_rolls_back_on_error(self):
        """A failing DELETE propagates and triggers a rollback."""
        store, connection, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("fail")
        with pytest.raises(Exception):
            store.delete_index()
        connection.rollback.assert_called_once()
@pytest.mark.unit
class TestPGVectorStoreSaveLocal:
    """save_local has no meaning for a database-backed store."""

    def test_save_local_is_noop(self):
        store, _, _, _ = _make_store()
        result = store.save_local()
        assert result is None
@pytest.mark.unit
class TestPGVectorStoreGetChunks:
    """Tests for PGVectorStore.get_chunks."""

    def test_get_chunks(self):
        """Rows map to chunk dicts; NULL metadata becomes an empty dict."""
        store, _, cursor, _ = _make_store()
        cursor.fetchall.return_value = [
            (1, "text1", {"key": "val"}),
            (2, "text2", None),
        ]
        chunks = store.get_chunks()
        assert len(chunks) == 2
        first, second = chunks
        assert first == {"doc_id": "1", "text": "text1", "metadata": {"key": "val"}}
        assert second == {"doc_id": "2", "text": "text2", "metadata": {}}

    def test_get_chunks_returns_empty_on_error(self):
        """Query failures are swallowed and an empty list is returned."""
        store, _, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("fail")
        assert store.get_chunks() == []
@pytest.mark.unit
class TestPGVectorStoreAddChunk:
    """Tests for PGVectorStore.add_chunk."""

    def test_add_chunk(self):
        """The new row id comes back as a string and the insert is committed."""
        store, connection, cursor, embeddings = _make_store(source_id="src1")
        embeddings.embed_documents.return_value = [[0.1, 0.2]]
        cursor.fetchone.return_value = (42,)
        returned = store.add_chunk("hello", metadata={"key": "val"})
        assert returned == "42"
        connection.commit.assert_called_once()

    def test_add_chunk_raises_on_empty_embedding(self):
        """No embedding produced -> ValueError."""
        store, _, _, embeddings = _make_store()
        embeddings.embed_documents.return_value = []
        with pytest.raises(ValueError, match="Could not generate embedding"):
            store.add_chunk("text")

    def test_add_chunk_includes_source_id_in_metadata(self):
        """The source id travels as the fourth INSERT parameter."""
        store, connection, cursor, embeddings = _make_store(source_id="src1")
        embeddings.embed_documents.return_value = [[0.1, 0.2]]
        cursor.fetchone.return_value = (1,)
        store.add_chunk("hello", metadata={"key": "val"})
        # Verify source_id is passed as a parameter to the INSERT
        insert_params = cursor.execute.call_args[0][1]
        assert insert_params[3] == "src1"
@pytest.mark.unit
class TestPGVectorStoreDeleteChunk:
    """Tests for PGVectorStore.delete_chunk."""

    def test_delete_chunk_success(self):
        """rowcount == 1 -> True, and the deletion is committed."""
        store, connection, cursor, _ = _make_store()
        cursor.rowcount = 1
        assert store.delete_chunk("42") is True
        connection.commit.assert_called_once()

    def test_delete_chunk_not_found(self):
        """rowcount == 0 -> False (nothing matched)."""
        store, connection, cursor, _ = _make_store()
        cursor.rowcount = 0
        assert store.delete_chunk("999") is False

    def test_delete_chunk_returns_false_on_error(self):
        """Execution errors are swallowed and False is returned."""
        store, _, cursor, _ = _make_store()
        cursor.execute.side_effect = Exception("fail")
        assert store.delete_chunk("42") is False
@pytest.mark.unit
class TestPGVectorStoreConnection:
    """Connection lifecycle management of PGVectorStore."""

    def test_get_connection_creates_new_when_closed(self):
        """A closed connection triggers a fresh psycopg2.connect call."""
        store, stale_conn, _, _ = _make_store()
        stale_conn.closed = True
        fake_driver = MagicMock()
        replacement = MagicMock()
        fake_driver.connect.return_value = replacement
        store._psycopg2 = fake_driver
        result = store._get_connection()
        fake_driver.connect.assert_called_once()
        assert result is replacement

    def test_get_connection_reuses_open(self):
        """An open connection is returned as-is."""
        store, live_conn, _, _ = _make_store()
        live_conn.closed = False
        assert store._get_connection() is live_conn

    def test_ensure_table_exists(self):
        """Setup runs several DDL statements and commits."""
        store, connection, cursor, _ = _make_store()
        store._ensure_table_exists()
        # Expect at least CREATE EXTENSION, CREATE TABLE and CREATE INDEX.
        assert cursor.execute.call_count >= 3
        connection.commit.assert_called()

    def test_del_closes_connection(self):
        """Finalizer closes an open connection exactly once."""
        store, connection, _, _ = _make_store()
        connection.closed = False
        store.__del__()
        connection.close.assert_called_once()

View File

@@ -0,0 +1,230 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
def _make_qdrant_store(source_id="test-source"):
    """Helper to create a QdrantStore with all external deps mocked.

    Returns a ``(store, mock_docsearch, mock_settings)`` tuple; the fake
    langchain Qdrant wrapper (``mock_docsearch``) is what most tests assert on.
    """
    mock_models = MagicMock()
    mock_qdrant_langchain = MagicMock()
    # qdrant_client / langchain_community may not be installed in the test
    # env, so they are stubbed in sys.modules for the duration of the import.
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.qdrant.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "qdrant_client": MagicMock(),
            "qdrant_client.models": mock_models,
            "langchain_community": MagicMock(),
            "langchain_community.vectorstores": MagicMock(),
            "langchain_community.vectorstores.qdrant": mock_qdrant_langchain,
        },
    ):
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
        # The store reads the embedding dimension from the second element of
        # the sentence-transformers client list.
        mock_emb.client = [None, Mock(word_embedding_dimension=768)]
        mock_get_emb.return_value = mock_emb
        # Minimal Qdrant settings: in-memory location, no remote server.
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.QDRANT_COLLECTION_NAME = "test_collection"
        mock_settings.QDRANT_LOCATION = ":memory:"
        mock_settings.QDRANT_URL = None
        mock_settings.QDRANT_PORT = 6333
        mock_settings.QDRANT_GRPC_PORT = 6334
        mock_settings.QDRANT_HTTPS = False
        mock_settings.QDRANT_PREFER_GRPC = False
        mock_settings.QDRANT_API_KEY = None
        mock_settings.QDRANT_PREFIX = None
        mock_settings.QDRANT_TIMEOUT = None
        mock_settings.QDRANT_PATH = None
        mock_settings.QDRANT_DISTANCE_FUNC = "Cosine"
        mock_docsearch = MagicMock()
        # Pretend the target collection already exists on the server.
        mock_collections = MagicMock()
        mock_collections.collections = [MagicMock(name="test_collection")]
        mock_docsearch.client.get_collections.return_value = mock_collections
        mock_qdrant_langchain.Qdrant.construct_instance.return_value = mock_docsearch
        from application.vectorstore.qdrant import QdrantStore
        store = QdrantStore(source_id=source_id, embeddings_key="key")
        return store, mock_docsearch, mock_settings
@pytest.mark.unit
class TestQdrantStoreInit:
    """Constructor behaviour of QdrantStore."""

    def test_source_id_cleaned(self):
        """Path-style prefixes and trailing slashes are stripped."""
        store, _, _ = _make_qdrant_store(source_id="application/indexes/abc123/")
        assert store._source_id == "abc123"

    def test_filter_constructed(self):
        """A per-source filter object is built at construction time."""
        store, _, _ = _make_qdrant_store(source_id="src1")
        assert store._filter is not None
@pytest.mark.unit
class TestQdrantStoreSearch:
    """Tests for QdrantStore.search."""

    def test_search_delegates(self):
        """search forwards to the langchain wrapper's similarity_search."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.similarity_search.return_value = ["result1"]
        outcome = store.search("query", k=5)
        docsearch.similarity_search.assert_called_once()
        assert outcome == ["result1"]
@pytest.mark.unit
class TestQdrantStoreAddTexts:
    """Tests for QdrantStore.add_texts."""

    def test_add_texts_delegates(self):
        """add_texts forwards texts and metadatas to the wrapper unchanged."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.add_texts.return_value = ["id1"]
        outcome = store.add_texts(["text1"], metadatas=[{"a": 1}])
        docsearch.add_texts.assert_called_once_with(["text1"], metadatas=[{"a": 1}])
        assert outcome == ["id1"]
@pytest.mark.unit
class TestQdrantStoreSaveLocal:
    """save_local has no meaning for a Qdrant-backed store."""

    def test_save_local_is_noop(self):
        store, _, _ = _make_qdrant_store()
        result = store.save_local()
        assert result is None
@pytest.mark.unit
class TestQdrantStoreDeleteIndex:
    """Tests for QdrantStore.delete_index."""

    def test_delete_index(self):
        """delete_index deletes the store's points via its source filter."""
        store, docsearch, _ = _make_qdrant_store()
        with patch("application.vectorstore.qdrant.settings") as patched:
            patched.QDRANT_COLLECTION_NAME = "test_collection"
            store.delete_index()
            docsearch.client.delete.assert_called_once_with(
                collection_name="test_collection",
                points_selector=store._filter,
            )
@pytest.mark.unit
class TestQdrantStoreGetChunks:
    """Tests for QdrantStore.get_chunks (scroll-based pagination)."""

    @staticmethod
    def _record(point_id, text, metadata):
        """Build a fake Qdrant point record with the given payload."""
        rec = MagicMock()
        rec.id = point_id
        rec.payload = {"page_content": text, "metadata": metadata}
        return rec

    def test_get_chunks(self):
        """A single scroll page is mapped to chunk dicts."""
        store, docsearch, _ = _make_qdrant_store()
        page = [
            self._record("id1", "text1", {"source": "test"}),
            self._record("id2", "text2", {"source": "test2"}),
        ]
        # One page; the None offset terminates the scroll loop.
        docsearch.client.scroll.side_effect = [(page, None)]
        chunks = store.get_chunks()
        assert len(chunks) == 2
        assert chunks[0] == {
            "doc_id": "id1",
            "text": "text1",
            "metadata": {"source": "test"},
        }

    def test_get_chunks_pagination(self):
        """A non-None offset triggers a second scroll call."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.client.scroll.side_effect = [
            ([self._record("id1", "text1", {})], "offset_token"),
            ([self._record("id2", "text2", {})], None),
        ]
        chunks = store.get_chunks()
        assert len(chunks) == 2
        assert docsearch.client.scroll.call_count == 2

    def test_get_chunks_returns_empty_on_error(self):
        """Scroll failures are swallowed and an empty list is returned."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.client.scroll.side_effect = Exception("fail")
        assert store.get_chunks() == []
@pytest.mark.unit
class TestQdrantStoreAddChunk:
    """Tests for QdrantStore.add_chunk."""

    def test_add_chunk(self):
        """The document carries text, source_id and caller metadata."""
        store, docsearch, _ = _make_qdrant_store(source_id="src1")
        docsearch.add_documents.return_value = ["new-id"]
        returned = store.add_chunk("hello", metadata={"key": "val"})
        assert returned == "new-id"
        docsearch.add_documents.assert_called_once()
        document = docsearch.add_documents.call_args[0][0][0]
        assert document.page_content == "hello"
        assert document.metadata["source_id"] == "src1"
        assert document.metadata["key"] == "val"

    def test_add_chunk_default_metadata(self):
        """Without explicit metadata, source_id is still attached."""
        store, docsearch, _ = _make_qdrant_store(source_id="src1")
        docsearch.add_documents.return_value = ["id"]
        store.add_chunk("text")
        document = docsearch.add_documents.call_args[0][0][0]
        assert document.metadata["source_id"] == "src1"

    def test_add_chunk_fallback_id(self):
        """If add_documents returns no ids, a generated uuid string is used."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.add_documents.return_value = []
        returned = store.add_chunk("text")
        assert returned is not None
        assert isinstance(returned, str)
@pytest.mark.unit
class TestQdrantStoreDeleteChunk:
    """Tests for QdrantStore.delete_chunk."""

    def test_delete_chunk_success(self):
        """Deletion targets the configured collection with the given id."""
        store, docsearch, _ = _make_qdrant_store()
        with patch("application.vectorstore.qdrant.settings") as patched:
            patched.QDRANT_COLLECTION_NAME = "test_collection"
            outcome = store.delete_chunk("chunk-id")
            docsearch.client.delete.assert_called_once_with(
                collection_name="test_collection",
                points_selector=["chunk-id"],
            )
            assert outcome is True

    def test_delete_chunk_returns_false_on_error(self):
        """Delete failures are swallowed and False is returned."""
        store, docsearch, _ = _make_qdrant_store()
        docsearch.client.delete.side_effect = Exception("fail")
        with patch("application.vectorstore.qdrant.settings") as patched:
            patched.QDRANT_COLLECTION_NAME = "test_collection"
            assert store.delete_chunk("bad-id") is False

View File

@@ -0,0 +1,39 @@
from unittest.mock import patch
import pytest
from application.vectorstore.vector_creator import VectorCreator
@pytest.mark.unit
class TestVectorCreator:
    """Tests for the VectorCreator factory."""

    def test_registered_vectorstores(self):
        """Every supported backend key is present in the factory registry."""
        for backend in (
            "faiss",
            "elasticsearch",
            "mongodb",
            "qdrant",
            "milvus",
            "pgvector",
        ):
            assert backend in VectorCreator.vectorstores

    def test_create_vectorstore_invalid_type(self):
        """An unknown backend name raises a descriptive ValueError."""
        with pytest.raises(ValueError, match="No vectorstore class found for type"):
            VectorCreator.create_vectorstore("nonexistent")

    def test_create_vectorstore_case_insensitive(self):
        """Backend lookup ignores case ("FAISS" resolves to "faiss").

        ``patch.object(..., return_value=None)`` already neutralizes
        ``__init__``; the original's redundant ``mock_init.return_value =
        None`` re-assignment was removed.
        """
        with patch.object(
            VectorCreator.vectorstores["faiss"], "__init__", return_value=None
        ) as mock_init:
            VectorCreator.create_vectorstore(
                "FAISS", source_id="test", embeddings_key="key"
            )
            mock_init.assert_called_once_with(source_id="test", embeddings_key="key")

    def test_create_vectorstore_passes_args(self):
        """Extra keyword arguments reach the backend constructor unchanged."""
        with patch.object(
            VectorCreator.vectorstores["mongodb"], "__init__", return_value=None
        ) as mock_init:
            VectorCreator.create_vectorstore(
                "mongodb", source_id="src1", embeddings_key="ek", database="mydb"
            )
            mock_init.assert_called_once_with(
                source_id="src1", embeddings_key="ek", database="mydb"
            )