Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 06:30:03 +00:00)

Compare commits: chore/bump...tests-util (28 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 126fa01b14 |  |
|  | e06debad5f |  |
|  | 6492852f7d |  |
|  | 00a621f33a |  |
|  | e92ffc6fdc |  |
|  | fe185e5b8d |  |
|  | 9f3d9ab860 |  |
|  | 1c0adde380 |  |
|  | 3c56bd0d0b |  |
|  | 86664ebda2 |  |
|  | db18b743d1 |  |
|  | 9e85cc9065 |  |
|  | aaaa6f002d |  |
|  | 47dcbcb74b |  |
|  | ddbfd94193 |  |
|  | 8dec60ab8b |  |
|  | 84b2e4bab4 |  |
|  | 2afdd7f026 |  |
|  | f364475f64 |  |
|  | b254de6ed6 |  |
|  | 08dedcaf95 |  |
|  | c726eb8ebd |  |
|  | 5f0d39e5f1 |  |
|  | 8c82fc5495 |  |
|  | 6d81a15e97 |  |
|  | 5478e4234c |  |
|  | 4056278fef |  |
|  | eaf39bb15b |  |
@@ -3,6 +3,14 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>

# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>

# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
@@ -29,7 +29,7 @@

<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
</div>
<h3 align="left">
<strong>Key Features:</strong>
@@ -185,7 +185,10 @@ class ToolExecutor:
                target_dict[param] = value

        # Load tool (with caching)
        tool = self._get_or_load_tool(tool_data, tool_id, action_name)
        tool = self._get_or_load_tool(
            tool_data, tool_id, action_name,
            headers=headers, query_params=query_params,
        )

        resolved_arguments = (
            {"query_params": query_params, "headers": headers, "body": body}

@@ -238,7 +241,10 @@ class ToolExecutor:

        return result, call_id

    def _get_or_load_tool(self, tool_data: Dict, tool_id: str, action_name: str):
    def _get_or_load_tool(
        self, tool_data: Dict, tool_id: str, action_name: str,
        headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
    ):
        """Load a tool, using cache when possible."""
        cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
        if cache_key in self._loaded_tools:

@@ -251,8 +257,8 @@ class ToolExecutor:
        tool_config = {
            "url": action_config["url"],
            "method": action_config["method"],
            "headers": {},
            "query_params": {},
            "headers": headers or {},
            "query_params": query_params or {},
        }
        if "body_content_type" in action_config:
            tool_config["body_content_type"] = action_config.get(
@@ -27,6 +27,8 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS

OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS

NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS


OPENAI_MODELS = [
    AvailableModel(

@@ -193,6 +195,46 @@ OPENROUTER_MODELS = [
    ),
]

NOVITA_MODELS = [
    AvailableModel(
        id="moonshotai/kimi-k2.5",
        provider=ModelProvider.NOVITA,
        display_name="Kimi K2.5",
        description="MoE model with function calling, structured output, reasoning, and vision",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=NOVITA_ATTACHMENTS,
            context_window=262144,
        ),
    ),
    AvailableModel(
        id="zai-org/glm-5",
        provider=ModelProvider.NOVITA,
        display_name="GLM-5",
        description="MoE model with function calling, structured output, and reasoning",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=[],
            context_window=202800,
        ),
    ),
    AvailableModel(
        id="minimax/minimax-m2.5",
        provider=ModelProvider.NOVITA,
        display_name="MiniMax M2.5",
        description="MoE model with function calling, structured output, and reasoning",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=[],
            context_window=204800,
        ),
    ),
]


AZURE_OPENAI_MODELS = [
    AvailableModel(
        id="azure-gpt-4",
@@ -114,6 +114,10 @@ class ModelRegistry:
            settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
        ):
            self._add_openrouter_models(settings)
        if settings.NOVITA_API_KEY or (
            settings.LLM_PROVIDER == "novita" and settings.API_KEY
        ):
            self._add_novita_models(settings)
        if settings.HUGGINGFACE_API_KEY or (
            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
        ):

@@ -245,6 +249,21 @@ class ModelRegistry:
        for model in OPENROUTER_MODELS:
            self.models[model.id] = model

    def _add_novita_models(self, settings):
        from application.core.model_configs import NOVITA_MODELS

        if settings.NOVITA_API_KEY:
            for model in NOVITA_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
            for model in NOVITA_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
                    return
        for model in NOVITA_MODELS:
            self.models[model.id] = model

    def _add_docsgpt_models(self, settings):
        model_id = "docsgpt-local"
        model = AvailableModel(
@@ -10,6 +10,7 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
    provider_key_map = {
        "openai": settings.OPENAI_API_KEY,
        "openrouter": settings.OPEN_ROUTER_API_KEY,
        "novita": settings.NOVITA_API_KEY,
        "anthropic": settings.ANTHROPIC_API_KEY,
        "google": settings.GOOGLE_API_KEY,
        "groq": settings.GROQ_API_KEY,
@@ -5,9 +5,7 @@ from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class Settings(BaseSettings):

@@ -15,15 +13,11 @@ class Settings(BaseSettings):

    AUTH_TYPE: Optional[str] = None  # simple_jwt, session_jwt, or None
    LLM_PROVIDER: str = "docsgpt"
    LLM_NAME: Optional[str] = (
        None  # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
    )
    LLM_NAME: Optional[str] = None  # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
    EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
    EMBEDDINGS_BASE_URL: Optional[str] = None  # Remote embeddings API URL (OpenAI-compatible)
    EMBEDDINGS_KEY: Optional[str] = (
        None  # api key for embeddings (if using openai, just copy API_KEY)
    )

    EMBEDDINGS_KEY: Optional[str] = None  # api key for embeddings (if using openai, just copy API_KEY)

    CELERY_BROKER_URL: str = "redis://localhost:6379/0"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
    MONGO_URI: str = "mongodb://localhost:27017/docsgpt"

@@ -45,9 +39,7 @@ class Settings(BaseSettings):
    PARSE_IMAGE_REMOTE: bool = False
    DOCLING_OCR_ENABLED: bool = False  # Enable OCR for docling parsers (PDF, images)
    DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False  # Enable OCR for docling when parsing attachments
    VECTOR_STORE: str = (
        "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
    )
    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
    RETRIEVERS_ENABLED: list = ["classic_rag"]
    AGENT_NAME: str = "classic"
    FALLBACK_LLM_PROVIDER: Optional[str] = None  # provider for fallback llm

@@ -55,12 +47,8 @@ class Settings(BaseSettings):
    FALLBACK_LLM_API_KEY: Optional[str] = None  # api key for fallback llm

    # Google Drive integration
    GOOGLE_CLIENT_ID: Optional[str] = (
        None  # Replace with your actual Google OAuth client ID
    )
    GOOGLE_CLIENT_SECRET: Optional[str] = (
        None  # Replace with your actual Google OAuth client secret
    )
    GOOGLE_CLIENT_ID: Optional[str] = None  # Replace with your actual Google OAuth client ID
    GOOGLE_CLIENT_SECRET: Optional[str] = None  # Replace with your actual Google OAuth client secret
    CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
        "http://127.0.0.1:7091/api/connectors/callback"  ##add redirect url as it is to your provider's console(gcp)
    )

@@ -72,7 +60,7 @@ class Settings(BaseSettings):
    MICROSOFT_AUTHORITY: Optional[str] = None  # e.g., "https://login.microsoftonline.com/{tenant_id}"

    # GitHub source
    GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access
    GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

    # LLM Cache
    CACHE_REDIS_URL: str = "redis://localhost:6379/2"

@@ -90,16 +78,13 @@ class Settings(BaseSettings):
    GROQ_API_KEY: Optional[str] = None
    HUGGINGFACE_API_KEY: Optional[str] = None
    OPEN_ROUTER_API_KEY: Optional[str] = None
    NOVITA_API_KEY: Optional[str] = None

    OPENAI_API_BASE: Optional[str] = None  # azure openai api base url
    OPENAI_API_VERSION: Optional[str] = None  # azure openai api version
    AZURE_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for answering
    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
        None  # azure deployment name for embeddings
    )
    OPENAI_BASE_URL: Optional[str] = (
        None  # openai base url for open ai compatable models
    )
    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for embeddings
    OPENAI_BASE_URL: Optional[str] = None  # openai base url for open ai compatable models

    # elasticsearch
    ELASTIC_CLOUD_ID: Optional[str] = None  # cloud id for elasticsearch

@@ -141,9 +126,7 @@ class Settings(BaseSettings):

    # LanceDB vectorstore config
    LANCEDB_PATH: str = "./data/lancedb"  # Path where LanceDB stores its local data
    LANCEDB_TABLE_NAME: Optional[str] = (
        "docsgpts"  # Name of the table to use for storing vectors
    )
    LANCEDB_TABLE_NAME: Optional[str] = "docsgpts"  # Name of the table to use for storing vectors

    FLASK_DEBUG_MODE: bool = False
    STORAGE_TYPE: str = "local"  # local or s3

@@ -180,6 +163,7 @@ class Settings(BaseSettings):
        "GOOGLE_API_KEY",
        "GROQ_API_KEY",
        "HUGGINGFACE_API_KEY",
        "NOVITA_API_KEY",
        "EMBEDDINGS_KEY",
        "FALLBACK_LLM_API_KEY",
        "QDRANT_API_KEY",
@@ -7,6 +7,7 @@ class LLMHandlerCreator:
    handlers = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
        "novita": OpenAILLMHandler,  # Novita uses OpenAI-compatible API
        "default": OpenAILLMHandler,
    }

@@ -1,13 +1,13 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM

NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
NOVITA_BASE_URL = "https://api.novita.ai/openai"


class NovitaLLM(OpenAILLM):
    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
        super().__init__(
            api_key=api_key or settings.API_KEY,
            api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
            user_api_key=user_api_key,
            base_url=base_url or NOVITA_BASE_URL,
            *args,
@@ -44,36 +44,40 @@ The main set of instructions or system [prompt](/Guides/Customising-prompts) tha

## Understanding Agent Types

DocsGPT allows for different "types" of agents, each with a distinct way of processing information and generating responses. The code for these agent types can be found in the `application/agents/` directory.
DocsGPT supports several agent types, each with a distinct way of processing information. The code for these can be found in the `application/agents/` directory.

### 1. Classic Agent (`classic_agent.py`)
### 1. Classic Agent

**How it works:** The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach.
1. **Retrieve:** When a query is made, it first searches the selected Source documents for relevant information.
2. **Augment:** This retrieved data is then added to the context, along with the main Prompt and the user's query.
3. **Generate:** The LLM generates a response based on this augmented context. It can also utilize any configured tools if the LLM decides they are necessary.
The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach: it retrieves relevant document chunks, augments the prompt context with them, and generates a response. It can also use configured tools if the LLM decides they are necessary.

**Best for:**
* Direct question-answering over a specific set of documents.
* Tasks where the primary goal is to extract and synthesize information from the provided sources.
* Simpler tool integrations where the decision to use a tool is straightforward.
**Best for:** Direct question-answering over a specific set of documents and straightforward tool use.

### 2. ReAct Agent (`react_agent.py`)
### 2. Agentic Agent

**How it works:** The ReAct Agent employs a more sophisticated "Reason and Act" framework. This involves a multi-step process:
1. **Plan (Thought):** Based on the query, its prompt, and available tools/sources, the LLM first generates a plan or a sequence of thoughts on how to approach the problem. You might see this output as a "thought" process during generation.
2. **Act:** The agent then executes actions based on this plan. This might involve querying its sources, using a tool, or performing internal reasoning.
3. **Observe:** It gathers observations from the results of its actions (e.g., data from a tool, snippets from documents).
4. **Repeat (if necessary):** Steps 2 and 3 can be repeated as the agent refines its approach or gathers more information.
5. **Conclude:** Finally, it generates the final answer based on the initial query and all accumulated observations.
Unlike Classic which pre-fetches documents into the prompt, the Agentic Agent gives the LLM an `internal_search` tool so it can decide **when, what, and whether** to search. This means the LLM controls its own retrieval — it can search multiple times, refine queries, or skip retrieval entirely if the question doesn't need it.

**Best for:**
* More complex tasks that require multi-step reasoning or problem-solving.
* Scenarios where the agent needs to dynamically decide which tools to use and in what order, based on intermediate results.
* Interactive tasks where the agent needs to "think" through a problem.
**Best for:** Tasks where the agent needs to dynamically decide how to gather information, use multiple tools in sequence, or combine retrieval with external tool calls.
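
To make the difference concrete, here is a minimal sketch of the two retrieval styles. It is illustrative only, not DocsGPT's actual code: the `llm` and `vector_store` objects and their methods are hypothetical stand-ins, and only the `internal_search` tool name comes from the description above.

```python
# Illustrative sketch only; `llm` and `vector_store` are hypothetical stand-ins
# for a chat model and a vector index, not DocsGPT's real interfaces.

def classic_answer(llm, vector_store, question):
    # Classic: retrieve once, up front, then generate.
    chunks = vector_store.search(question, k=4)
    context = "\n\n".join(c.text for c in chunks)
    return llm.generate(f"Context:\n{context}\n\nQuestion: {question}")

def agentic_answer(llm, vector_store, question, max_turns=5):
    # Agentic: the model decides when (and whether) to call internal_search.
    tools = [{"name": "internal_search", "description": "Search the attached sources"}]
    messages = [{"role": "user", "content": question}]
    for _ in range(max_turns):
        reply = llm.generate(messages, tools=tools)
        if reply.tool_call is None:  # the model chose to answer directly
            return reply.text
        chunks = vector_store.search(reply.tool_call.arguments["query"], k=4)
        messages.append({"role": "tool", "content": "\n\n".join(c.text for c in chunks)})
    return llm.generate(messages).text  # turn budget exhausted; answer with what we have
```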

### 3. Research Agent

A multi-phase agent designed for in-depth research tasks:
1. **Clarification** — Determines if the question needs clarification before proceeding.
2. **Planning** — Decomposes the question into research steps with adaptive depth based on complexity.
3. **Research** — Executes each step, calling tools and refining queries as needed.
4. **Synthesis** — Compiles findings into a final cited report.

Includes budget controls for max steps, timeout, and token limits to keep research bounded.

**Best for:** Complex questions that require multi-step investigation, gathering information from multiple sources, and producing structured reports with citations.
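
The budget controls mentioned above can be pictured as a simple bounded loop; the sketch below is illustrative, with hypothetical names and defaults rather than DocsGPT's internals.

```python
# Illustrative sketch of bounded research; all names and limits are hypothetical.
import time

def run_research(plan, execute_step, max_steps=10, timeout_s=300, token_budget=50_000):
    findings, tokens_used, started = [], 0, time.monotonic()
    for step in plan[:max_steps]:  # max-steps budget
        if time.monotonic() - started > timeout_s or tokens_used >= token_budget:
            break  # stop early and synthesize with what has been gathered
        result = execute_step(step)  # calls tools, refines queries, etc.
        findings.append(result["summary"])
        tokens_used += result["tokens"]
    return findings  # handed to the synthesis phase for the final cited report
```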

### 4. Workflow Agent

Executes predefined workflows composed of connected nodes (AI Agent, Set State, Condition). See the [Workflow Nodes](/Agents/nodes) page for details on building workflows.

**Best for:** Structured, multi-step processes with branching logic and shared state between steps.

<Callout type="info">
Developers looking to introduce new agent architectures can explore the `application/agents/` directory. `classic_agent.py` and `react_agent.py` serve as excellent starting points, demonstrating how to inherit from `BaseAgent` and structure agent logic.
The legacy "ReAct" agent type is still accepted for backwards compatibility but maps to the Classic Agent internally. New agents should use Classic, Agentic, or Research instead.
</Callout>

## Navigating and Managing Agents in DocsGPT
@@ -70,9 +70,9 @@ Inside the DocsGPT folder create a `.env` file and copy the contents of `.env_sa
Make sure your `.env` file looks like this:

```
OPENAI_API_KEY=(Your OpenAI API key)
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```

To save the file, press CTRL+X, then Y, and then ENTER.
@@ -104,7 +104,7 @@ DocsGPT can transcribe audio in two places:
- Voice input in the chat.
- Audio file ingestion. Uploaded `.wav`, `.mp3`, `.m4a`, `.ogg`, and `.webm` files are transcribed first and then passed through the normal parser, chunking, embedding, and indexing pipeline.

For an end-to-end walkthrough, see the [Speech and Audio Guide](/Guides/speech-and-audio).
The settings below control speech-to-text behaviour for both voice input and audio file ingestion.

| Setting | Purpose | Typical values |
| --- | --- | --- |
@@ -214,6 +214,31 @@ If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt
  }}
/>

## S3 Storage Backend

By default DocsGPT stores files locally. Set `STORAGE_TYPE=s3` to use Amazon S3 instead.

| Setting | Description | Default |
| --- | --- | --- |
| `STORAGE_TYPE` | `local` or `s3` | `local` |
| `S3_BUCKET_NAME` | S3 bucket name | `docsgpt-test-bucket` |
| `SAGEMAKER_ACCESS_KEY` | AWS access key ID | — |
| `SAGEMAKER_SECRET_KEY` | AWS secret access key | — |
| `SAGEMAKER_REGION` | AWS region | — |
| `URL_STRATEGY` | `backend` (proxy through API) or `s3` (direct S3 URLs) | `backend` |

The S3 credentials use `SAGEMAKER_*` variable names because they are shared with the SageMaker integration.

```env
STORAGE_TYPE=s3
S3_BUCKET_NAME=your-bucket-name
SAGEMAKER_ACCESS_KEY=your-aws-access-key-id
SAGEMAKER_SECRET_KEY=your-aws-secret-access-key
SAGEMAKER_REGION=us-east-1
```

Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.
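
Before starting DocsGPT, you can sanity-check the credentials and bucket access with a few lines of boto3 (illustrative; it simply reuses the placeholder values from the `.env` example above):

```python
# Quick credential/permission check using the same placeholder values as the .env example above.
import boto3

s3 = boto3.client(
    "s3",
    aws_access_key_id="your-aws-access-key-id",
    aws_secret_access_key="your-aws-secret-access-key",
    region_name="us-east-1",
)
s3.head_bucket(Bucket="your-bucket-name")  # raises a botocore ClientError if access is denied
print("Bucket is reachable with these credentials")
```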

## Exploring More Settings

These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:
@@ -86,13 +86,9 @@ Make sure your `.env` file looks like this:

```

OPENAI_API_KEY=(Your OpenAI API key)

API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
VITE_API_STREAMING=true

SELF_HOSTED_MODEL=false

```
@@ -11,18 +11,18 @@ DocsGPT API keys are essential for developers and users who wish to integrate th

After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:

- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the Agents option, and press 'Create New' to generate a new agent (which includes an API key).
- **API Call:** Alternatively, you can use the `/api/create_agent` endpoint to create a new agent. An API key is automatically generated for each agent. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).

## Understanding Key Variables

Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
Upon creating your agent, you will encounter several key variables. Each serves a specific purpose:

- **Name:** Assign a name to your API key for easy identification.
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
- **Name:** Assign a name to your agent for easy identification.
- **Source:** Indicates the source document(s) linked to your agent, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your agent. You can view this by making a call to `/api/get_agents`.
- **Key:** The API key for the agent, which will be used in your application to authenticate API requests.

With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the agent, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
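
As a rough illustration, a request to the `/api/answer` endpoint from Python might look like the snippet below. The exact field names are an assumption here; confirm them against the DocsGPT API documentation linked above.

```python
# Illustrative only: the field names ("question", "api_key") are assumptions,
# not a documented schema; check the API docs before relying on them.
import requests

response = requests.post(
    "https://your-docsgpt-host/api/answer",
    json={"question": "How do I embed the widget?", "api_key": "your-agent-api-key"},
    timeout=60,
)
print(response.json())
```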

Congratulations on taking the first step towards enhancing your applications with DocsGPT!
@@ -64,7 +64,7 @@ flowchart LR
* **Technology:** Supports multiple vector databases.
* **Responsibility:** Vector Stores are used to store and retrieve vector embeddings of document chunks. This enables semantic search and retrieval of relevant document snippets in response to user queries.
* **Key Features:**
    * Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, and LanceDB.
    * Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, MongoDB Atlas Vector Search, and pgvector.
    * Provides storage and indexing of high-dimensional vector embeddings.
    * Enables editing and updating of vector indexes including specific chunks.
@@ -16,7 +16,7 @@ Training on other documentation sources can greatly enhance the versatility and
Make sure the document you want to train on is ready on the device you are using. You can also use links to the documentation to train on.

<Callout type="warning" emoji="⚠️">
Note: The document should be either of the given file formats .pdf, .txt, .rst, .docx, .md, .zip and limited to 25mb.You can also train using the link of the documentation.
Note: Supported file formats include .pdf, .txt, .rst, .docx, .md, .mdx, .csv, .epub, .html, .json, .xlsx, .pptx, .png, .jpg, .jpeg, and audio files (.wav, .mp3, .m4a, .ogg, .webm). You can also train using the link of the documentation.

</Callout>
@@ -35,8 +35,34 @@ Choose the LLM of your choice.
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
### Step 2
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
</Steps>

## Fallback LLM

DocsGPT can automatically switch to a fallback LLM when the primary model fails, including mid-stream. This works with both streaming and non-streaming requests.

**Fallback order:**
1. Per-agent backup models (other models configured on the same agent)
2. Global fallback (`FALLBACK_LLM_*` env vars below)
3. Error returned if all fail

| Setting | Description | Default |
| --- | --- | --- |
| `FALLBACK_LLM_PROVIDER` | Provider name (e.g., `openai`, `anthropic`, `google`) | — |
| `FALLBACK_LLM_NAME` | Model name (e.g., `gpt-4o`, `claude-sonnet-4-20250514`) | — |
| `FALLBACK_LLM_API_KEY` | API key for the fallback provider | Falls back to `API_KEY` |

All three (`FALLBACK_LLM_PROVIDER`, `FALLBACK_LLM_NAME`, and an API key) must resolve for the global fallback to activate.

```env
FALLBACK_LLM_PROVIDER=anthropic
FALLBACK_LLM_NAME=claude-sonnet-4-20250514
FALLBACK_LLM_API_KEY=sk-ant-your-anthropic-key
```

<Callout type="info">
For maximum resilience, use a fallback provider from a different cloud than your primary. Each agent can also have multiple models configured — the other models are tried first before the global fallback.
</Callout>
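
Conceptually, the fallback order described above behaves like the following resolution loop (a sketch under assumed names, not the actual DocsGPT code):

```python
# Sketch of the fallback order; the model objects and `generate` are assumed names.
def answer_with_fallback(question, primary, backup_models, global_fallback=None):
    candidates = [primary, *backup_models]      # 1. per-agent backup models
    if global_fallback is not None:
        candidates.append(global_fallback)       # 2. FALLBACK_LLM_* global fallback
    last_error = None
    for model in candidates:
        try:
            return model.generate(question)      # a failure here may also happen mid-stream
        except Exception as err:
            last_error = err
    raise RuntimeError("All configured models failed") from last_error  # 3. error if all fail
```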
@@ -2,5 +2,13 @@ export default {
  "google-drive-connector": {
    "title": "🔗 Google Drive",
    "href": "/Guides/Integrations/google-drive-connector"
  },
  "sharepoint-connector": {
    "title": "🔗 SharePoint / OneDrive",
    "href": "/Guides/Integrations/sharepoint-connector"
  },
  "mcp-tool-integration": {
    "title": "🔗 MCP Tools",
    "href": "/Guides/Integrations/mcp-tool-integration"
  }
}
docs/content/Guides/Integrations/mcp-tool-integration.mdx (new file, 66 lines)
@@ -0,0 +1,66 @@
---
title: MCP Tool Integration
description: Connect external tools to DocsGPT agents using the Model Context Protocol (MCP) standard.
---

import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'

# MCP Tool Integration

The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) integration lets you connect external tool servers to DocsGPT. Your agents can then discover and call tools provided by those servers during conversations — for example, querying a CRM, running code, or accessing a database.

## Setup

<Steps>

### Step 1: Configure Environment Variables (Optional)

Only needed if your MCP servers use OAuth authentication:

```env
MCP_OAUTH_REDIRECT_URI=https://yourdomain.com/api/mcp_server/callback
```

If not set, falls back to `API_URL/api/mcp_server/callback`.

### Step 2: Add an MCP Server

Go to **Settings** > **Tools** > **Add Tool** > **MCP Server**. Enter the server URL, select an auth type, and click **Test Connection** to verify, then **Save**.

### Step 3: Enable for Your Agent

In your agent configuration, enable the MCP tools you want the agent to use.

</Steps>

## Authentication Types

| Auth Type | Config Fields |
|-----------|---------------|
| **None** | — |
| **Bearer** | `bearer_token` |
| **API Key** | `api_key`, `api_key_header` (default: `X-API-Key`) |
| **Basic** | `username`, `password` |
| **OAuth** | `oauth_scopes` (optional) |

<Callout type="warning">
For OAuth in production, `MCP_OAUTH_REDIRECT_URI` must be a publicly accessible URL pointing to your DocsGPT backend.
</Callout>

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/api/mcp_server/test` | POST | Test a connection without saving |
| `/api/mcp_server/save` | POST | Save or update a server configuration |
| `/api/mcp_server/callback` | GET | OAuth callback handler |
| `/api/mcp_server/oauth_status/<task_id>` | GET | Poll OAuth flow status |
| `/api/mcp_server/auth_status` | GET | Batch check auth status for all MCP tools |
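
For example, testing a server configuration from a script might look like this; the payload field names below are assumptions for illustration (mirroring the auth-type fields listed above), so verify them against the actual endpoint before relying on them.

```python
# Illustrative only: the JSON field names are assumptions, not a documented schema.
import requests

resp = requests.post(
    "https://your-docsgpt-host/api/mcp_server/test",
    json={"url": "https://mcp.example.com", "auth_type": "bearer", "bearer_token": "..."},
    timeout=30,
)
print(resp.status_code, resp.json())
```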

## Troubleshooting

- **Connection refused** — Verify the URL and that the server is reachable from your backend.
- **403 Forbidden** — Check credentials and permissions.
- **Timed out** — Default is 30s; increase timeout in tool config (max 300s).
- **OAuth "needs_auth" persists** — Verify `MCP_OAUTH_REDIRECT_URI` is correct and Redis is running.
docs/content/Guides/Integrations/sharepoint-connector.mdx (new file, 63 lines)
@@ -0,0 +1,63 @@
---
title: SharePoint / OneDrive Connector
description: Connect your Microsoft SharePoint or OneDrive as an external knowledge base to upload and process files directly.
---

import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'

# SharePoint / OneDrive Connector

Connect your SharePoint or OneDrive account to upload and process files directly as an external knowledge base. Supports Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Microsoft Entra ID (Azure AD) with automatic token refresh.

## Setup

<Steps>

### Step 1: Create an App Registration in Azure

1. Go to the [Azure Portal](https://portal.azure.com/) > **Microsoft Entra ID** > **App registrations** > **New registration**
2. Set **Redirect URI** (Web) to:
   - Local: `http://localhost:7091/api/connectors/callback?provider=share_point`
   - Production: `https://yourdomain.com/api/connectors/callback?provider=share_point`

### Step 2: Configure API Permissions

In your App Registration, go to **API permissions** > **Add a permission** > **Microsoft Graph** > **Delegated permissions** and add: `Files.Read`, `Files.Read.All`, `Sites.Read.All`. Grant admin consent if possible.

### Step 3: Create a Client Secret

Go to **Certificates & secrets** > **New client secret**. Copy the secret value immediately (it won't be shown again).

### Step 4: Configure Environment Variables

Add to your `.env` file:

```env
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
```

| Variable | Description | Required | Default |
|----------|-------------|----------|---------|
| `MICROSOFT_CLIENT_ID` | Application (client) ID from App Registration overview | Yes | — |
| `MICROSOFT_CLIENT_SECRET` | Client secret value | Yes | — |
| `MICROSOFT_TENANT_ID` | Directory (tenant) ID | No | `common` |
| `MICROSOFT_AUTHORITY` | Login endpoint override | No | Auto-constructed |

<Callout type="warning">
`MICROSOFT_TENANT_ID=common` (the default) allows any Microsoft account to authenticate. Set this to your specific tenant ID in production.
</Callout>

### Step 5: Restart and Use

Restart your application, then go to the upload section in DocsGPT and select **SharePoint / OneDrive** as the source. You'll be redirected to Microsoft to sign in, then can browse and select files to process.

</Steps>

## Troubleshooting

- **Option not appearing** — Verify `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` are set, then restart.
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=share_point`.
- **Permission denied** — Ensure admin consent is granted and the user has access to the target files.
@@ -7,20 +7,10 @@ description:

If your AI uses external knowledge and is not explicit enough, it is ok, because we try to make DocsGPT friendly.

But if you want to adjust it, here is a simple way:-

- Got to `application/prompts/chat_combine_prompt.txt`

- And change it to
But if you want to adjust it, prompts are now managed through the UI and API using a template-based system. See the [Customising Prompts](/Guides/Customising-prompts) guide for details.

To make the AI stricter about staying on-topic, edit your active prompt template (via **Sidebar → Settings → Active Prompt**) to include instructions like:

```
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples, if possible.
Write an answer for the question below based on the provided context.
If the context provides insufficient information, reply "I cannot answer".
You have access to chat history and can use it to help answer the question.
----------------
{summaries}
```
@@ -29,7 +29,7 @@ export default {
    "title": "OCR",
    "href": "/Guides/ocr"
  },
  "Integrations": {
  "Integrations": {
    "title": "🔗 Integrations"
  }
}
@@ -70,7 +70,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
To stop DocsGPT, simply open a new terminal in the `DocsGPT` directory and run:

```bash
docker compose -f deployment/docker-compose.yaml down
docker compose -f deployment/docker-compose-hub.yaml down
```
(or the specific `docker compose` command shown at the end of the `setup.sh` execution, which may include optional compose files depending on your choices).
extensions/react-widget/package-lock.json (generated, 4 lines changed)
@@ -1,12 +1,12 @@
{
  "name": "docsgpt",
  "version": "0.5.1",
  "version": "0.6.3",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "docsgpt",
      "version": "0.5.1",
      "version": "0.6.3",
      "license": "Apache-2.0",
      "dependencies": {
        "@babel/plugin-transform-flow-strip-types": "^7.23.3",
@@ -1,6 +1,6 @@
{
  "name": "docsgpt",
  "version": "0.6.0",
  "version": "0.6.3",
  "private": false,
  "description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.",
  "source": "./src/index.html",
@@ -1,6 +1,6 @@
aiohttp>=3,<4
certifi==2024.7.4
h11==0.16.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7
@@ -40,9 +40,12 @@ import {
} from '@/components/ui/select';
import { Sheet, SheetContent } from '@/components/ui/sheet';

import { useSelector } from 'react-redux';

import modelService from '../api/services/modelService';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import { selectToken } from '../preferences/preferenceSlice';
import { WorkflowNode } from './types/workflow';
import {
  AgentNode,

@@ -77,6 +80,7 @@ interface UserTool {

function WorkflowBuilderInner() {
  const navigate = useNavigate();
  const token = useSelector(selectToken);
  const { agentId } = useParams<{ agentId?: string }>();
  const [searchParams] = useSearchParams();
  const folderId = searchParams.get('folder_id');

@@ -304,7 +308,7 @@ function WorkflowBuilderInner() {
        setAvailableModels(modelService.transformModels(modelsData.models));
      }

      const toolsResponse = await userService.getUserTools(null);
      const toolsResponse = await userService.getUserTools(token);
      if (toolsResponse.ok) {
        const toolsData = await toolsResponse.json();
        setAvailableTools(toolsData.tools);

@@ -54,7 +54,10 @@ import { FileUpload } from '../../components/FileUpload';
import AgentDetailsModal from '../../modals/AgentDetailsModal';
import ConfirmationModal from '../../modals/ConfirmationModal';
import { ActiveState } from '../../models/misc';
import { selectToken } from '../../preferences/preferenceSlice';
import {
  selectSourceDocs,
  selectToken,
} from '../../preferences/preferenceSlice';
import { getToolDisplayName } from '../../utils/toolUtils';
import { Agent } from '../types';
import { ConditionCase, WorkflowNode } from '../types/workflow';

@@ -300,6 +303,7 @@ function createWorkflowPayload(
function WorkflowBuilderInner() {
  const navigate = useNavigate();
  const token = useSelector(selectToken);
  const sourceDocs = useSelector(selectSourceDocs);
  const { agentId } = useParams<{ agentId?: string }>();
  const [searchParams] = useSearchParams();
  const folderId = searchParams.get('folder_id');

@@ -341,6 +345,14 @@ function WorkflowBuilderInner() {
  const [availableModels, setAvailableModels] = useState<Model[]>([]);
  const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
  const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
  const sourceOptions = useMemo(
    () =>
      (sourceDocs ?? []).map((doc) => ({
        value: doc.id ?? 'default',
        label: doc.name,
      })),
    [sourceDocs],
  );
  const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
    Record<string, string>
  >({});
@@ -387,31 +399,39 @@ function WorkflowBuilderInner() {
    [],
  );

  const onConnect = useCallback((params: Connection) => {
    setEdges((eds) => {
      const exists = eds.some(
        (e) =>
          e.source === params.source &&
          e.sourceHandle === params.sourceHandle &&
          e.target === params.target &&
          e.targetHandle === params.targetHandle,
      );
      if (exists) return eds;

      const filtered = eds.filter(
        (e) =>
          !(
            e.source === params.source &&
            e.sourceHandle === (params.sourceHandle ?? null)
          ) &&
          !(
            e.target === params.target &&
            e.targetHandle === (params.targetHandle ?? null)
          ),
      );
      return addEdge(params, filtered);
    });
  }, []);
  const onConnect = useCallback(
    (params: Connection) => {
      setEdges((eds) => {
        const exists = eds.some(
          (e) =>
            e.source === params.source &&
            e.sourceHandle === params.sourceHandle &&
            e.target === params.target &&
            e.targetHandle === params.targetHandle,
        );
        if (exists) return eds;

        const targetNode = nodes.find((n) => n.id === params.target);
        const isEndNode = targetNode?.type === 'end';

        const filtered = eds.filter(
          (e) =>
            !(
              e.source === params.source &&
              e.sourceHandle === (params.sourceHandle ?? null)
            ) &&
            // End nodes accept multiple incoming edges
            (isEndNode ||
              !(
                e.target === params.target &&
                e.targetHandle === (params.targetHandle ?? null)
              )),
        );
        return addEdge(params, filtered);
      });
    },
    [nodes],
  );

  const onEdgeClick = useCallback((_event: React.MouseEvent, edge: Edge) => {
    setEdges((eds) => eds.filter((e) => e.id !== edge.id));
@@ -701,7 +721,7 @@ function WorkflowBuilderInner() {
        setDefaultAgentModelId(preferredDefaultModel);
      }

      const toolsResponse = await userService.getUserTools(null);
      const toolsResponse = await userService.getUserTools(token);
      if (toolsResponse.ok) {
        const toolsData = await toolsResponse.json();
        setAvailableTools(toolsData.tools);

@@ -1271,8 +1291,8 @@ function WorkflowBuilderInner() {

  const handlePrimaryAction = useCallback(() => {
    if (isPrimaryActionDisabled) return;
    void persistWorkflow(!canManageAgent);
  }, [isPrimaryActionDisabled, persistWorkflow, canManageAgent]);
    void persistWorkflow(false);
  }, [isPrimaryActionDisabled, persistWorkflow]);

  const agentForDetails = useMemo<Agent>(
    () => ({
@@ -1910,6 +1930,28 @@ function WorkflowBuilderInner() {
          emptyText="No tools available"
        />
      </div>
      <div>
        <label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
          Sources
        </label>
        <MultiSelect
          options={sourceOptions}
          selected={
            selectedNode.data.config?.sources || []
          }
          onChange={(newSources) =>
            handleUpdateNodeData({
              config: {
                ...(selectedNode.data.config || {}),
                sources: newSources,
              },
            })
          }
          placeholder="Select sources..."
          searchPlaceholder="Search sources..."
          emptyText="No sources available"
        />
      </div>
      <div>
        <label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
          Structured Output (JSON Schema)
@@ -70,12 +70,12 @@ export function MultiSelect({
        role="combobox"
        aria-expanded={open}
        className={cn(
          'w-full justify-between border-[#E5E5E5] bg-white hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
          'h-auto min-h-[2.5rem] w-full justify-between border-[#E5E5E5] bg-white py-1.5 hover:bg-gray-50 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:hover:bg-[#383838]',
          !selected.length && 'text-gray-500 dark:text-gray-400',
          className,
        )}
      >
        <div className="flex flex-wrap gap-1">
        <div className="flex min-w-0 flex-wrap gap-1">
          {selected.length === 0 ? (
            placeholder
          ) : (

@@ -85,9 +85,9 @@ export function MultiSelect({
            return (
              <span
                key={option?.value || label}
                className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
                className="dark:bg-purple-30/30 bg-violets-are-blue/20 inline-flex max-w-[calc(100%-1rem)] items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium text-purple-700 dark:text-purple-300"
              >
                {label}
                <span className="truncate">{label}</span>
                <span
                  role="button"
                  tabIndex={0}
@@ -3,7 +3,7 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
addopts =
    -v
    --strict-markers
    --tb=short

@@ -11,6 +11,7 @@ addopts =
    --cov-report=html
    --cov-report=term-missing
    --cov-report=xml
    --ignore=tests/integration
markers =
    unit: Unit tests
    integration: Integration tests
@@ -977,8 +977,8 @@ function Connect-CloudAPIProvider {
        }
        "7" { # Novita
            $script:provider_name = "Novita"
            $script:llm_name = "novita"
            $script:model_name = "deepseek/deepseek-r1"
            $script:llm_provider = "novita"
            $script:model_name = "moonshotai/kimi-k2.5"
            Get-APIKey
            break
        }
setup.sh (2 lines changed)
@@ -704,7 +704,7 @@ connect_cloud_api_provider() {
        7) # Novita
            provider_name="Novita"
            llm_provider="novita"
            model_name="deepseek/deepseek-r1"
            model_name="moonshotai/kimi-k2.5"
            get_api_key
            break ;;
        b|B) clear; return 1 ;; # Clear screen and Back to Main Menu
tests/agents/test_api_body_serializer.py (new file, 202 lines)
@@ -0,0 +1,202 @@
"""Tests for application/agents/tools/api_body_serializer.py"""

import json

import pytest

from application.agents.tools.api_body_serializer import (
    ContentType,
    RequestBodySerializer,
)


@pytest.mark.unit
class TestContentTypeEnum:
    def test_json_value(self):
        assert ContentType.JSON == "application/json"

    def test_form_urlencoded_value(self):
        assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"

    def test_multipart_value(self):
        assert ContentType.MULTIPART_FORM_DATA == "multipart/form-data"

    def test_text_plain_value(self):
        assert ContentType.TEXT_PLAIN == "text/plain"

    def test_xml_value(self):
        assert ContentType.XML == "application/xml"

    def test_octet_stream_value(self):
        assert ContentType.OCTET_STREAM == "application/octet-stream"


@pytest.mark.unit
class TestSerializeJson:
    def test_basic_json(self):
        body, headers = RequestBodySerializer.serialize(
            {"key": "value"}, ContentType.JSON
        )
        assert json.loads(body) == {"key": "value"}
        assert headers["Content-Type"] == "application/json"

    def test_nested_json(self):
        data = {"user": {"name": "Alice", "age": 30}}
        body, headers = RequestBodySerializer.serialize(data, ContentType.JSON)
        assert json.loads(body) == data

    def test_empty_body_returns_none(self):
        body, headers = RequestBodySerializer.serialize({}, ContentType.JSON)
        assert body is None
        assert headers == {}

    def test_none_body(self):
        body, headers = RequestBodySerializer.serialize(None, ContentType.JSON)
        assert body is None

    def test_unknown_content_type_falls_back_to_json(self):
        body, headers = RequestBodySerializer.serialize(
            {"k": "v"}, "application/vnd.custom+json"
        )
        assert json.loads(body) == {"k": "v"}

    def test_content_type_with_charset_suffix(self):
        body, headers = RequestBodySerializer.serialize(
            {"k": "v"}, "application/json; charset=utf-8"
        )
        assert json.loads(body) == {"k": "v"}


@pytest.mark.unit
class TestSerializeFormUrlencoded:
    def test_basic_form(self):
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice", "age": "30"}, ContentType.FORM_URLENCODED
        )
        assert "name=Alice" in body
        assert "age=30" in body
        assert headers["Content-Type"] == "application/x-www-form-urlencoded"

    def test_none_values_skipped(self):
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice", "skip": None}, ContentType.FORM_URLENCODED
        )
        assert "name=Alice" in body
        assert "skip" not in body

    def test_list_explode_true(self):
        body, headers = RequestBodySerializer.serialize(
            {"tags": ["a", "b"]},
            ContentType.FORM_URLENCODED,
            encoding_rules={"tags": {"style": "form", "explode": True}},
        )
        assert "tags=a" in body
        assert "tags=b" in body

    def test_list_explode_false(self):
        body, headers = RequestBodySerializer.serialize(
            {"tags": ["a", "b"]},
            ContentType.FORM_URLENCODED,
            encoding_rules={"tags": {"style": "form", "explode": False}},
        )
        # Value is percent-encoded by _serialize_form_value then urlencoded again
        assert "tags=" in body
        assert "a" in body and "b" in body

    def test_dict_value_json_content_type(self):
        body, headers = RequestBodySerializer.serialize(
            {"metadata": {"key": "val"}},
            ContentType.FORM_URLENCODED,
            encoding_rules={"metadata": {"contentType": "application/json"}},
        )
        assert "metadata" in body


@pytest.mark.unit
class TestSerializeTextPlain:
    def test_single_value(self):
        body, headers = RequestBodySerializer.serialize(
            {"message": "hello"}, ContentType.TEXT_PLAIN
        )
        assert body == "hello"
        assert headers["Content-Type"] == "text/plain"

    def test_multiple_values(self):
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice", "age": 30}, ContentType.TEXT_PLAIN
        )
        assert "name: Alice" in body
        assert "age: 30" in body


@pytest.mark.unit
class TestSerializeXml:
    def test_basic_xml(self):
        body, headers = RequestBodySerializer.serialize(
            {"name": "Alice"}, ContentType.XML
        )
        assert '<?xml version="1.0"' in body
        assert "<name>Alice</name>" in body
        assert headers["Content-Type"] == "application/xml"

    def test_nested_xml(self):
        body, headers = RequestBodySerializer.serialize(
            {"user": {"name": "Alice"}}, ContentType.XML
        )
        assert "<user>" in body
        assert "<name>Alice</name>" in body

    def test_xml_escapes_special_chars(self):
        body, headers = RequestBodySerializer.serialize(
            {"data": "<script>alert('xss')</script>"}, ContentType.XML
        )
        assert "&lt;script&gt;" in body


@pytest.mark.unit
class TestSerializeOctetStream:
    def test_dict_body(self):
        body, headers = RequestBodySerializer.serialize(
            {"key": "val"}, ContentType.OCTET_STREAM
        )
        assert isinstance(body, bytes)
        assert headers["Content-Type"] == "application/octet-stream"


@pytest.mark.unit
class TestSerializeMultipartFormData:
    def test_basic_multipart(self):
        body, headers = RequestBodySerializer.serialize(
            {"field": "value"}, ContentType.MULTIPART_FORM_DATA
        )
        assert isinstance(body, bytes)
        assert "multipart/form-data" in headers["Content-Type"]
        assert "boundary=" in headers["Content-Type"]

    def test_none_values_skipped(self):
        body, headers = RequestBodySerializer.serialize(
            {"field": "value", "empty": None}, ContentType.MULTIPART_FORM_DATA
        )
        body_str = body.decode("utf-8", errors="replace")
        assert "field" in body_str
        assert "empty" not in body_str


@pytest.mark.unit
class TestHelpers:
    def test_percent_encode(self):
        assert RequestBodySerializer._percent_encode("hello world") == "hello%20world"
        assert RequestBodySerializer._percent_encode("a/b") == "a%2Fb"
        assert RequestBodySerializer._percent_encode("safe", safe_chars="/") == "safe"

    def test_escape_xml(self):
        assert "&amp;" in RequestBodySerializer._escape_xml("&")
        assert "&lt;" in RequestBodySerializer._escape_xml("<")
        assert "&gt;" in RequestBodySerializer._escape_xml(">")
        assert "&quot;" in RequestBodySerializer._escape_xml('"')
        assert "&apos;" in RequestBodySerializer._escape_xml("'")

    def test_dict_to_xml_list(self):
        xml = RequestBodySerializer._dict_to_xml({"items": [1, 2, 3]})
        assert "<item>1</item>" in xml
        assert "<item>2</item>" in xml
tests/agents/test_api_tool.py (new file, 282 lines)
@@ -0,0 +1,282 @@
"""Tests for application/agents/tools/api_tool.py"""

import json
from unittest.mock import MagicMock, patch

import pytest
import requests

from application.agents.tools.api_tool import APITool


@pytest.fixture
def tool():
    return APITool(
        config={
            "url": "https://api.example.com/data",
            "method": "GET",
            "headers": {"Accept": "application/json"},
            "query_params": {},
        }
    )


@pytest.fixture
def post_tool():
    return APITool(
        config={
            "url": "https://api.example.com/items",
            "method": "POST",
            "headers": {},
            "query_params": {},
        }
    )


@pytest.mark.unit
class TestAPIToolInit:
    def test_default_values(self):
        tool = APITool(config={})
        assert tool.url == ""
        assert tool.method == "GET"
        assert tool.headers == {}
        assert tool.query_params == {}


@pytest.mark.unit
class TestMakeApiCall:
    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_successful_get(self, mock_get, mock_validate, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"result": "ok"}
        mock_resp.content = b'{"result":"ok"}'
        mock_get.return_value = mock_resp

        result = tool.execute_action("any_action")

        assert result["status_code"] == 200
        assert result["data"] == {"result": "ok"}
        assert result["message"] == "API call successful."

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.post")
    def test_successful_post(self, mock_post, mock_validate, post_tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 201
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"id": 1}
        mock_resp.content = b'{"id":1}'
        mock_post.return_value = mock_resp

        result = post_tool.execute_action("create", name="test")

        assert result["status_code"] == 201

    @patch("application.agents.tools.api_tool.validate_url")
    def test_ssrf_blocked(self, mock_validate, tool):
        from application.core.url_validation import SSRFError

        mock_validate.side_effect = SSRFError("blocked")

        result = tool.execute_action("any")

        assert result["status_code"] is None
        assert "URL validation error" in result["message"]

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_timeout_error(self, mock_get, mock_validate, tool):
        mock_get.side_effect = requests.exceptions.Timeout()

        result = tool.execute_action("any")

        assert result["status_code"] is None
        assert "timeout" in result["message"].lower()

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_connection_error(self, mock_get, mock_validate, tool):
        mock_get.side_effect = requests.exceptions.ConnectionError("refused")

        result = tool.execute_action("any")

        assert result["status_code"] is None
        assert "Connection error" in result["message"]

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.get")
    def test_http_error(self, mock_get, mock_validate, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 404
        mock_resp.text = "Not Found"
        mock_resp.json.side_effect = json.JSONDecodeError("", "", 0)
        mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
            response=mock_resp
        )
        mock_get.return_value = mock_resp

        result = tool.execute_action("any")

        assert result["status_code"] == 404
        assert "HTTP Error" in result["message"]

    @patch("application.agents.tools.api_tool.validate_url")
    def test_unsupported_method(self, mock_validate):
        tool = APITool(
            config={"url": "https://example.com", "method": "CUSTOM"}
        )
        result = tool.execute_action("any")
        assert result["status_code"] is None
        assert "Unsupported" in result["message"]

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.put")
    def test_put_method(self, mock_put, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "PUT"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {}
        mock_resp.content = b'{}'
        mock_put.return_value = mock_resp

        result = tool.execute_action("update", name="new")
        assert result["status_code"] == 200

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.delete")
    def test_delete_method(self, mock_delete, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "DELETE"})
        mock_resp = MagicMock()
        mock_resp.status_code = 204
        mock_resp.headers = {"Content-Type": "text/plain"}
        mock_resp.content = b''
        mock_delete.return_value = mock_resp

        result = tool.execute_action("delete")
        assert result["status_code"] == 204

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.patch")
    def test_patch_method(self, mock_patch, mock_validate):
        tool = APITool(config={"url": "https://example.com/item/1", "method": "PATCH"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "application/json"}
        mock_resp.json.return_value = {"patched": True}
        mock_resp.content = b'{"patched":true}'
        mock_patch.return_value = mock_resp

        result = tool.execute_action("patch", field="val")
        assert result["status_code"] == 200

    @patch("application.agents.tools.api_tool.validate_url")
    @patch("application.agents.tools.api_tool.requests.head")
    def test_head_method(self, mock_head, mock_validate):
        tool = APITool(config={"url": "https://example.com", "method": "HEAD"})
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.headers = {"Content-Type": "text/html"}
mock_resp.content = b''
|
||||
mock_head.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("check")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.options")
|
||||
def test_options_method(self, mock_options, mock_validate):
|
||||
tool = APITool(config={"url": "https://example.com", "method": "OPTIONS"})
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.content = b''
|
||||
mock_options.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action("options")
|
||||
assert result["status_code"] == 200
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathParamSubstitution:
|
||||
@patch("application.agents.tools.api_tool.validate_url")
|
||||
@patch("application.agents.tools.api_tool.requests.get")
|
||||
def test_path_params_substituted(self, mock_get, mock_validate):
|
||||
tool = APITool(
|
||||
config={
|
||||
"url": "https://api.example.com/users/{user_id}/posts/{post_id}",
|
||||
"method": "GET",
|
||||
"query_params": {"user_id": "42", "post_id": "7"},
|
||||
}
|
||||
)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = []
|
||||
mock_resp.content = b'[]'
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
tool.execute_action("get")
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert "/users/42/posts/7" in called_url
|
||||
assert "{user_id}" not in called_url
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseResponse:
|
||||
def test_json_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.json.return_value = {"key": "val"}
|
||||
mock_resp.content = b'{"key":"val"}'
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result == {"key": "val"}
|
||||
|
||||
def test_text_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/plain"}
|
||||
mock_resp.text = "plain text"
|
||||
mock_resp.content = b"plain text"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result == "plain text"
|
||||
|
||||
def test_xml_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/xml"}
|
||||
mock_resp.text = "<root><item>1</item></root>"
|
||||
mock_resp.content = b"<root><item>1</item></root>"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert "<root>" in result
|
||||
|
||||
def test_empty_content(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "application/json"}
|
||||
mock_resp.content = b""
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert result is None
|
||||
|
||||
def test_html_response(self, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.headers = {"Content-Type": "text/html"}
|
||||
mock_resp.text = "<html><body>Hi</body></html>"
|
||||
mock_resp.content = b"<html><body>Hi</body></html>"
|
||||
|
||||
result = tool._parse_response(mock_resp)
|
||||
assert "<html>" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAPIToolMetadata:
|
||||
def test_actions_metadata_empty(self, tool):
|
||||
assert tool.get_actions_metadata() == []
|
||||
|
||||
def test_config_requirements_empty(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
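# Usage sketch (an assumption drawn from the behaviour exercised by the tests
# above, not part of this diff): APITool is configured with a URL template plus
# method/headers/query_params, execute_action returns a dict with
# "status_code", "data" and "message", and {placeholders} in the URL are filled
# from query_params. The action name is arbitrary, as in the tests.
from application.agents.tools.api_tool import APITool

users_api = APITool(
    config={
        "url": "https://api.example.com/users/{user_id}",
        "method": "GET",
        "headers": {"Accept": "application/json"},
        "query_params": {"user_id": "42"},
    }
)
response = users_api.execute_action("get_user")
if response["status_code"] == 200:
    print(response["data"])
else:
    print(response["message"])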
tests/agents/test_brave_tool.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""Tests for application/agents/tools/brave.py"""

from unittest.mock import MagicMock, patch

import pytest

from application.agents.tools.brave import BraveSearchTool


@pytest.fixture
def tool():
    return BraveSearchTool(config={"token": "test_api_key"})


@pytest.mark.unit
class TestBraveExecuteAction:
    def test_unknown_action_raises(self, tool):
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid")

    @patch("application.agents.tools.brave.requests.get")
    def test_web_search_success(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"web": {"results": [{"title": "Result"}]}}
        mock_get.return_value = mock_resp

        result = tool.execute_action("brave_web_search", query="python")

        assert result["status_code"] == 200
        assert "results" in result
        assert "successfully" in result["message"]

        call_kwargs = mock_get.call_args
        assert call_kwargs[1]["headers"]["X-Subscription-Token"] == "test_api_key"

    @patch("application.agents.tools.brave.requests.get")
    def test_web_search_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 429
        mock_get.return_value = mock_resp

        result = tool.execute_action("brave_web_search", query="test")

        assert result["status_code"] == 429
        assert "failed" in result["message"].lower()

    @patch("application.agents.tools.brave.requests.get")
    def test_image_search_success(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"results": [{"url": "https://img.com/1.jpg"}]}
        mock_get.return_value = mock_resp

        result = tool.execute_action("brave_image_search", query="cats")

        assert result["status_code"] == 200
        assert "results" in result

    @patch("application.agents.tools.brave.requests.get")
    def test_image_search_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_get.return_value = mock_resp

        result = tool.execute_action("brave_image_search", query="cats")

        assert result["status_code"] == 500

    @patch("application.agents.tools.brave.requests.get")
    def test_count_capped_at_20(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp

        tool.execute_action("brave_web_search", query="test", count=100)

        params = mock_get.call_args[1]["params"]
        assert params["count"] == 20

    @patch("application.agents.tools.brave.requests.get")
    def test_image_count_capped_at_100(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp

        tool.execute_action("brave_image_search", query="test", count=500)

        params = mock_get.call_args[1]["params"]
        assert params["count"] == 100

    @patch("application.agents.tools.brave.requests.get")
    def test_freshness_param(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp

        tool.execute_action("brave_web_search", query="news", freshness="pd")

        params = mock_get.call_args[1]["params"]
        assert params["freshness"] == "pd"

    @patch("application.agents.tools.brave.requests.get")
    def test_offset_capped(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {}
        mock_get.return_value = mock_resp

        tool.execute_action("brave_web_search", query="test", offset=100)

        params = mock_get.call_args[1]["params"]
        assert params["offset"] == 9


@pytest.mark.unit
class TestBraveMetadata:
    def test_actions_metadata(self, tool):
        meta = tool.get_actions_metadata()
        assert len(meta) == 2
        names = {a["name"] for a in meta}
        assert "brave_web_search" in names
        assert "brave_image_search" in names

    def test_config_requirements(self, tool):
        reqs = tool.get_config_requirements()
        assert "token" in reqs
        assert reqs["token"]["secret"] is True
        assert reqs["token"]["required"] is True

tests/agents/test_cel_evaluator.py (new file, 169 lines)
@@ -0,0 +1,169 @@
|
||||
"""Tests for application/agents/workflows/cel_evaluator.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.cel_evaluator import (
|
||||
CelEvaluationError,
|
||||
_convert_value,
|
||||
build_activation,
|
||||
cel_to_python,
|
||||
evaluate_cel,
|
||||
)
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class TestConvertValue:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool_true(self):
|
||||
result = _convert_value(True)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool_false(self):
|
||||
result = _convert_value(False)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_int(self):
|
||||
result = _convert_value(42)
|
||||
assert isinstance(result, celpy.celtypes.IntType)
|
||||
assert int(result) == 42
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_float(self):
|
||||
result = _convert_value(3.14)
|
||||
assert isinstance(result, celpy.celtypes.DoubleType)
|
||||
assert float(result) == pytest.approx(3.14)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string(self):
|
||||
result = _convert_value("hello")
|
||||
assert isinstance(result, celpy.celtypes.StringType)
|
||||
assert str(result) == "hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list(self):
|
||||
result = _convert_value([1, "two", 3.0])
|
||||
assert isinstance(result, celpy.celtypes.ListType)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_dict(self):
|
||||
result = _convert_value({"key": "value"})
|
||||
assert isinstance(result, celpy.celtypes.MapType)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none(self):
|
||||
result = _convert_value(None)
|
||||
assert isinstance(result, celpy.celtypes.BoolType)
|
||||
assert bool(result) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_other_type_converts_to_string(self):
|
||||
result = _convert_value(object())
|
||||
assert isinstance(result, celpy.celtypes.StringType)
|
||||
|
||||
|
||||
class TestBuildActivation:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_converts_dict_values(self):
|
||||
state = {"name": "Alice", "age": 30, "active": True}
|
||||
result = build_activation(state)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "active" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_state(self):
|
||||
assert build_activation({}) == {}
|
||||
|
||||
|
||||
class TestEvaluateCel:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_simple_comparison(self):
|
||||
assert evaluate_cel("x > 5", {"x": 10}) is True
|
||||
assert evaluate_cel("x > 5", {"x": 3}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_comparison(self):
|
||||
assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
|
||||
assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_arithmetic(self):
|
||||
assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_boolean_logic(self):
|
||||
assert evaluate_cel("a && b", {"a": True, "b": True}) is True
|
||||
assert evaluate_cel("a && b", {"a": True, "b": False}) is False
|
||||
assert evaluate_cel("a || b", {"a": False, "b": True}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError, match="Empty expression"):
|
||||
evaluate_cel("", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_whitespace_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError, match="Empty expression"):
|
||||
evaluate_cel(" ", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_expression_raises(self):
|
||||
with pytest.raises(CelEvaluationError):
|
||||
evaluate_cel("invalid!!!", {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_variable_raises(self):
|
||||
with pytest.raises(CelEvaluationError):
|
||||
evaluate_cel("undefined_var > 5", {})
|
||||
|
||||
|
||||
class TestCelToPython:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_bool(self):
|
||||
result = cel_to_python(celpy.celtypes.BoolType(True))
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_int(self):
|
||||
result = cel_to_python(celpy.celtypes.IntType(42))
|
||||
assert result == 42
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_double(self):
|
||||
result = cel_to_python(celpy.celtypes.DoubleType(3.14))
|
||||
assert result == pytest.approx(3.14)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string(self):
|
||||
result = cel_to_python(celpy.celtypes.StringType("hello"))
|
||||
assert result == "hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list(self):
|
||||
cel_list = celpy.celtypes.ListType([
|
||||
celpy.celtypes.IntType(1),
|
||||
celpy.celtypes.IntType(2),
|
||||
])
|
||||
result = cel_to_python(cel_list)
|
||||
assert result == [1, 2]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_map(self):
|
||||
cel_map = celpy.celtypes.MapType({
|
||||
celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
|
||||
})
|
||||
result = cel_to_python(cel_map)
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_type_passthrough(self):
|
||||
result = cel_to_python("raw_value")
|
||||
assert result == "raw_value"
|
||||
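# Usage sketch (an assumption drawn from the behaviour documented by the tests
# above, not part of this diff): evaluate_cel(expression, state) converts the
# state dict via build_activation, returns a plain Python value, and raises
# CelEvaluationError for empty or invalid expressions and for missing variables.
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel

state = {"score": 87, "passed": True}
try:
    if evaluate_cel("score > 80 && passed", state):
        print("condition satisfied")
except CelEvaluationError as err:
    print(f"could not evaluate condition: {err}")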
tests/agents/test_cryptoprice_tool.py (new file, 85 lines)
@@ -0,0 +1,85 @@
"""Tests for application/agents/tools/cryptoprice.py"""

from unittest.mock import MagicMock, patch

import pytest

from application.agents.tools.cryptoprice import CryptoPriceTool


@pytest.fixture
def tool():
    return CryptoPriceTool(config={})


@pytest.mark.unit
class TestCryptoPriceExecuteAction:
    def test_unknown_action_raises(self, tool):
        with pytest.raises(ValueError, match="Unknown action"):
            tool.execute_action("invalid_action")

    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_successful_price_fetch(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"USD": 65000}
        mock_get.return_value = mock_resp

        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")

        assert result["status_code"] == 200
        assert result["price"] == 65000
        assert "successfully" in result["message"]

    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_currency_not_found(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"EUR": 60000}
        mock_get.return_value = mock_resp

        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")

        assert result["status_code"] == 200
        assert "Couldn't find" in result["message"]
        assert "price" not in result

    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_api_failure(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_get.return_value = mock_resp

        result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")

        assert result["status_code"] == 500
        assert "Failed" in result["message"]

    @patch("application.agents.tools.cryptoprice.requests.get")
    def test_symbol_case_insensitive(self, mock_get, tool):
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"USD": 100}
        mock_get.return_value = mock_resp

        tool.execute_action("cryptoprice_get", symbol="btc", currency="usd")

        called_url = mock_get.call_args[0][0]
        assert "fsym=BTC" in called_url
        assert "tsyms=USD" in called_url


@pytest.mark.unit
class TestCryptoPriceMetadata:
    def test_actions_metadata(self, tool):
        meta = tool.get_actions_metadata()
        assert len(meta) == 1
        assert meta[0]["name"] == "cryptoprice_get"
        params = meta[0]["parameters"]
        assert "symbol" in params["properties"]
        assert "currency" in params["properties"]
        assert "symbol" in params["required"]
        assert "currency" in params["required"]

    def test_config_requirements(self, tool):
        assert tool.get_config_requirements() == {}

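# Usage sketch (an assumption based on the calls exercised above, not part of
# this diff): CryptoPriceTool takes an empty config and exposes a single
# "cryptoprice_get" action with required "symbol" and "currency" arguments,
# returning a dict with "status_code", "message" and, on success, "price".
from application.agents.tools.cryptoprice import CryptoPriceTool

price_tool = CryptoPriceTool(config={})
response = price_tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
if response["status_code"] == 200 and "price" in response:
    print(f"BTC/USD: {response['price']}")
else:
    print(response["message"])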
tests/agents/test_duckduckgo_tool.py (new file, 145 lines)
@@ -0,0 +1,145 @@
|
||||
"""Tests for application/agents/tools/duckduckgo.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.duckduckgo import DuckDuckGoSearchTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return DuckDuckGoSearchTool(config={})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDuckDuckGoExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_web_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = [
|
||||
{"title": "Result 1", "href": "https://example.com", "body": "snippet"}
|
||||
]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="python")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
assert "successfully" in result["message"]
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_image_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.return_value = [{"image": "https://img.com/1.jpg"}]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_image_search", query="cats")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_news_search_success(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.news.return_value = [{"title": "News"}]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_news_search", query="tech")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_search_error_returns_500(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.side_effect = Exception("Network error")
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "failed" in result["message"].lower()
|
||||
assert result["results"] == []
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_max_results_capped_at_20(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
tool.execute_action("ddg_web_search", query="test", max_results=100)
|
||||
|
||||
call_kwargs = mock_client.text.call_args[1]
|
||||
assert call_kwargs["max_results"] == 20
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_image_max_results_capped_at_50(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.images.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
tool.execute_action("ddg_image_search", query="test", max_results=200)
|
||||
|
||||
call_kwargs = mock_client.images.call_args[1]
|
||||
assert call_kwargs["max_results"] == 50
|
||||
|
||||
@patch("application.agents.tools.duckduckgo.time.sleep")
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_rate_limit_retries(self, mock_client_factory, mock_sleep, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.side_effect = [
|
||||
Exception("RateLimit exceeded"),
|
||||
[{"title": "Result"}],
|
||||
]
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert len(result["results"]) == 1
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_empty_results(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = []
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="obscure query")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["results"] == []
|
||||
|
||||
@patch.object(DuckDuckGoSearchTool, "_get_ddgs_client")
|
||||
def test_none_results(self, mock_client_factory, tool):
|
||||
mock_client = MagicMock()
|
||||
mock_client.text.return_value = None
|
||||
mock_client_factory.return_value = mock_client
|
||||
|
||||
result = tool.execute_action("ddg_web_search", query="test")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["results"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDuckDuckGoMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 3
|
||||
names = {a["name"] for a in meta}
|
||||
assert "ddg_web_search" in names
|
||||
assert "ddg_image_search" in names
|
||||
assert "ddg_news_search" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
assert tool.get_config_requirements() == {}
|
||||
|
||||
def test_custom_timeout(self):
|
||||
tool = DuckDuckGoSearchTool(config={"timeout": 30})
|
||||
assert tool.timeout == 30
|
||||
tests/agents/test_mcp_tool.py (new file, 519 lines)
@@ -0,0 +1,519 @@
|
||||
"""Tests for application/agents/tools/mcp_tool.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# mcp_tool has a circular import at module level (mcp_tool -> tasks -> user -> mcp.py -> mcp_tool).
|
||||
# We must patch the dependencies BEFORE the module is first imported.
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_mcp_globals(monkeypatch):
|
||||
"""Patch module-level MongoDB and cache to avoid real connections."""
|
||||
import sys
|
||||
|
||||
# If the module is already loaded, just patch attributes directly
|
||||
if "application.agents.tools.mcp_tool" in sys.modules:
|
||||
mcp_mod = sys.modules["application.agents.tools.mcp_tool"]
|
||||
else:
|
||||
# Break the circular import by pre-populating the tasks import
|
||||
# with a mock before mcp_tool tries to import it
|
||||
mock_tasks = MagicMock()
|
||||
monkeypatch.setitem(sys.modules, "application.api.user.tasks", mock_tasks)
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
mock_mongo = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=MagicMock())
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_config():
|
||||
return {
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"transport_type": "http",
|
||||
"auth_type": "none",
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_config():
|
||||
return {
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"transport_type": "http",
|
||||
"auth_type": "bearer",
|
||||
"auth_credentials": {"bearer_token": "tok_123"},
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMCPToolInit:
|
||||
def test_basic_init(self, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(mcp_config)
|
||||
|
||||
assert tool.server_url == "https://mcp.example.com/api"
|
||||
assert tool.transport_type == "http"
|
||||
assert tool.auth_type == "none"
|
||||
assert tool.timeout == 10
|
||||
|
||||
def test_bearer_auth_credentials(self, bearer_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(bearer_config)
|
||||
assert tool.auth_credentials["bearer_token"] == "tok_123"
|
||||
|
||||
def test_no_server_url_skips_setup(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client") as mock_setup:
|
||||
MCPTool({"server_url": "", "auth_type": "none"})
|
||||
mock_setup.assert_not_called()
|
||||
|
||||
def test_oauth_skips_setup(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client") as mock_setup:
|
||||
MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "oauth",
|
||||
}
|
||||
)
|
||||
mock_setup.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateCacheKey:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_none_auth(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
assert "none" in tool._cache_key
|
||||
assert "mcp.example.com" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_bearer_auth(self, mock_setup, bearer_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(bearer_config)
|
||||
assert "bearer:" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_api_key_auth(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "api_key",
|
||||
"auth_credentials": {"api_key": "sk-test12345678"},
|
||||
}
|
||||
)
|
||||
assert "apikey:" in tool._cache_key
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_basic_auth(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"auth_type": "basic",
|
||||
"auth_credentials": {"username": "user1", "password": "pass"},
|
||||
}
|
||||
)
|
||||
assert "basic:user1" in tool._cache_key
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCreateTransport:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_http_transport(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_sse_transport(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"transport_type": "sse",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_auto_detects_sse(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"transport_type": "auto",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_stdio_transport_disabled(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "stdio",
|
||||
"auth_type": "none",
|
||||
}
|
||||
)
|
||||
with pytest.raises(ValueError, match="STDIO transport is disabled"):
|
||||
tool._create_transport()
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_api_key_header_injected(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "http",
|
||||
"auth_type": "api_key",
|
||||
"auth_credentials": {
|
||||
"api_key": "sk-test",
|
||||
"api_key_header": "X-Custom-Key",
|
||||
},
|
||||
}
|
||||
)
|
||||
# _create_transport will be called; verify it doesn't raise
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_basic_auth_header_injected(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{
|
||||
"server_url": "https://mcp.example.com",
|
||||
"transport_type": "http",
|
||||
"auth_type": "basic",
|
||||
"auth_credentials": {"username": "user", "password": "pass"},
|
||||
}
|
||||
)
|
||||
transport = tool._create_transport()
|
||||
assert transport is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFormatTools:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_list_of_dicts(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
result = tool._format_tools([{"name": "tool1", "description": "desc"}])
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "tool1"
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_tools_with_name_attribute(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "my_tool"
|
||||
mock_tool.description = "A tool"
|
||||
mock_tool.inputSchema = {"type": "object", "properties": {}}
|
||||
|
||||
result = tool._format_tools([mock_tool])
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "my_tool"
|
||||
assert result[0]["inputSchema"] == {"type": "object", "properties": {}}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_tools_response_object(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
resp = MagicMock()
|
||||
resp.tools = [{"name": "t1", "description": "d1"}]
|
||||
|
||||
result = tool._format_tools(resp)
|
||||
assert len(result) == 1
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_empty(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
assert tool._format_tools([]) == []
|
||||
assert tool._format_tools("unexpected") == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFormatResult:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_result_with_content(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
mock_result = MagicMock()
|
||||
text_item = MagicMock()
|
||||
text_item.text = "Hello"
|
||||
del text_item.data
|
||||
mock_result.content = [text_item]
|
||||
mock_result.isError = False
|
||||
|
||||
result = tool._format_result(mock_result)
|
||||
assert result["content"][0]["type"] == "text"
|
||||
assert result["content"][0]["text"] == "Hello"
|
||||
assert result["isError"] is False
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_format_raw_result(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
raw = {"key": "value"}
|
||||
assert tool._format_result(raw) == raw
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteAction:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_no_server_raises(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool({"server_url": "", "auth_type": "none"})
|
||||
with pytest.raises(Exception, match="No MCP server configured"):
|
||||
tool.execute_action("test_action")
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
|
||||
def test_successful_execute(self, mock_run, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool._client = MagicMock()
|
||||
mock_run.return_value = {"key": "value"}
|
||||
|
||||
result = tool.execute_action("test_action", param1="val1")
|
||||
|
||||
mock_run.assert_called_once_with("call_tool", "test_action", param1="val1")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._run_async_operation")
|
||||
def test_empty_kwargs_cleaned(self, mock_run, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool._client = MagicMock()
|
||||
mock_run.return_value = {}
|
||||
|
||||
tool.execute_action("test", param1="", param2=None, param3="real")
|
||||
|
||||
call_kwargs = mock_run.call_args[1]
|
||||
assert "param1" not in call_kwargs
|
||||
assert "param2" not in call_kwargs
|
||||
assert call_kwargs["param3"] == "real"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTestConnection:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_no_server_url(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool({"server_url": "", "auth_type": "none"})
|
||||
result = tool.test_connection()
|
||||
assert result["success"] is False
|
||||
assert "No server URL" in result["message"]
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_invalid_scheme(self, mock_setup):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(
|
||||
{"server_url": "ftp://bad.com", "auth_type": "none"}
|
||||
)
|
||||
result = tool.test_connection()
|
||||
assert result["success"] is False
|
||||
assert "Invalid URL scheme" in result["message"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMapError:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_timeout_error(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
import concurrent.futures
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", concurrent.futures.TimeoutError())
|
||||
assert "Timed out" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_connection_refused(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", ConnectionRefusedError())
|
||||
assert "Connection refused" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_403_forbidden(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", Exception("403 Forbidden"))
|
||||
assert "Access denied" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_ssl_error(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
err = tool._map_error("test", Exception("SSL certificate verify failed"))
|
||||
assert "SSL" in str(err)
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_unknown_error_passthrough(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
original = RuntimeError("something weird")
|
||||
err = tool._map_error("test", original)
|
||||
assert err is original
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetActionsMetadata:
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_empty_tools(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = []
|
||||
assert tool.get_actions_metadata() == []
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_tools_with_input_schema(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search things",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
]
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "search"
|
||||
assert "query" in meta[0]["parameters"]["properties"]
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_tools_without_schema(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
tool.available_tools = [{"name": "ping", "description": "Ping"}]
|
||||
meta = tool.get_actions_metadata()
|
||||
assert meta[0]["parameters"]["properties"] == {}
|
||||
|
||||
@patch("application.agents.tools.mcp_tool.MCPTool._setup_client")
|
||||
def test_config_requirements(self, mock_setup, mcp_config):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
tool = MCPTool(mcp_config)
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "server_url" in reqs
|
||||
assert "auth_type" in reqs
|
||||
assert reqs["server_url"]["required"] is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMCPOAuthManager:
|
||||
def test_handle_callback_success(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
|
||||
result = manager.handle_oauth_callback(state="abc123", code="auth_code")
|
||||
|
||||
assert result is True
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_handle_callback_no_redis(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(None)
|
||||
result = manager.handle_oauth_callback(state="abc", code="code")
|
||||
assert result is False
|
||||
|
||||
def test_handle_callback_error(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
|
||||
result = manager.handle_oauth_callback(
|
||||
state="abc", code="", error="access_denied"
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_get_oauth_status_no_task(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("")
|
||||
assert result["status"] == "not_started"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDBTokenStorage:
|
||||
def test_get_base_url(self):
|
||||
from application.agents.tools.mcp_tool import DBTokenStorage
|
||||
|
||||
assert (
|
||||
DBTokenStorage.get_base_url("https://mcp.example.com/api/v1")
|
||||
== "https://mcp.example.com"
|
||||
)
|
||||
|
||||
def test_get_db_key(self):
|
||||
from application.agents.tools.mcp_tool import DBTokenStorage
|
||||
|
||||
mock_db = MagicMock()
|
||||
storage = DBTokenStorage(
|
||||
server_url="https://mcp.example.com/api",
|
||||
user_id="user1",
|
||||
db_client=mock_db,
|
||||
)
|
||||
key = storage.get_db_key()
|
||||
assert key["server_url"] == "https://mcp.example.com"
|
||||
assert key["user_id"] == "user1"
|
||||
tests/agents/test_ntfy_tool.py (new file, 146 lines)
@@ -0,0 +1,146 @@
|
||||
"""Tests for application/agents/tools/ntfy.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.ntfy import NtfyTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return NtfyTool(config={"token": "test_token"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_no_token():
|
||||
return NtfyTool(config={})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNtfyExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("bad_action")
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_send_message_basic(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hello",
|
||||
topic="test",
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Message sent"
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[0][0] == "https://ntfy.sh/test"
|
||||
assert call_args[1]["data"] == b"Hello"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_send_with_title_and_priority(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Alert",
|
||||
topic="urgent",
|
||||
title="Warning",
|
||||
priority=5,
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert headers["X-Title"] == "Warning"
|
||||
assert headers["X-Priority"] == "5"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_auth_header_with_token(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Basic test_token"
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_no_auth_without_token(self, mock_post, tool_no_token):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool_no_token.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
)
|
||||
|
||||
headers = mock_post.call_args[1]["headers"]
|
||||
assert "Authorization" not in headers
|
||||
|
||||
def test_invalid_priority_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="between 1 and 5"):
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
priority=10,
|
||||
)
|
||||
|
||||
def test_non_numeric_priority_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="convertible to an integer"):
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh",
|
||||
message="Hi",
|
||||
topic="t",
|
||||
priority="abc",
|
||||
)
|
||||
|
||||
@patch("application.agents.tools.ntfy.requests.post")
|
||||
def test_trailing_slash_stripped(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
tool.execute_action(
|
||||
"ntfy_send_message",
|
||||
server_url="https://ntfy.sh/",
|
||||
message="Hi",
|
||||
topic="test",
|
||||
)
|
||||
|
||||
assert mock_post.call_args[0][0] == "https://ntfy.sh/test"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNtfyMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "ntfy_send_message"
|
||||
assert "server_url" in meta[0]["parameters"]["properties"]
|
||||
assert "message" in meta[0]["parameters"]["properties"]
|
||||
assert "topic" in meta[0]["parameters"]["properties"]
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
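# Usage sketch (an assumption mirroring the calls exercised above, not part of
# this diff): NtfyTool posts the message body to "<server_url>/<topic>"; title
# and priority (1-5) become X-Title / X-Priority headers, and a configured
# token is sent as a Basic Authorization header. The token value here is
# hypothetical.
from application.agents.tools.ntfy import NtfyTool

notifier = NtfyTool(config={"token": "ntfy_access_token"})
result = notifier.execute_action(
    "ntfy_send_message",
    server_url="https://ntfy.sh",
    topic="deployments",
    message="Deploy finished",
    title="CI",
    priority=4,
)
print(result["status_code"], result["message"])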
tests/agents/test_postgres_tool.py (new file, 146 lines)
@@ -0,0 +1,146 @@
|
||||
"""Tests for application/agents/tools/postgres.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.postgres import PostgresTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return PostgresTool(config={"token": "postgresql://user:pass@localhost/testdb"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPostgresExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_select_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.description = [("id",), ("name",)]
|
||||
mock_cur.fetchall.return_value = [(1, "Alice"), (2, "Bob")]
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT id, name FROM users"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["response_data"]["column_names"] == ["id", "name"]
|
||||
assert len(result["response_data"]["data"]) == 2
|
||||
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_insert_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.rowcount = 1
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql",
|
||||
sql_query="INSERT INTO users (name) VALUES ('Alice')",
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "1 rows affected" in result["response_data"]["message"]
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("connection refused")
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_get_schema(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = [
|
||||
("users", "id", "integer", "nextval(...)", "NO"),
|
||||
("users", "name", "varchar", None, "YES"),
|
||||
("posts", "id", "integer", "nextval(...)", "NO"),
|
||||
]
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert "users" in result["schema"]
|
||||
assert "posts" in result["schema"]
|
||||
assert len(result["schema"]["users"]) == 2
|
||||
assert result["schema"]["users"][0]["column_name"] == "id"
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_get_schema_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("auth failed")
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_connection_closed_on_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
tool.execute_action("postgres_execute_sql", sql_query="BAD SQL")
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
def test_select_with_no_description(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.description = None
|
||||
mock_cur.fetchall.return_value = []
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1 WHERE false"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["response_data"]["column_names"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPostgresMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 2
|
||||
names = {a["name"] for a in meta}
|
||||
assert "postgres_execute_sql" in names
|
||||
assert "postgres_get_schema" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
tests/agents/test_read_webpage_tool.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""Tests for application/agents/tools/read_webpage.py"""

from unittest.mock import MagicMock, patch

import pytest
import requests

from application.agents.tools.read_webpage import ReadWebpageTool


@pytest.fixture
def tool():
    return ReadWebpageTool()


@pytest.mark.unit
class TestReadWebpageExecuteAction:
    def test_unknown_action(self, tool):
        result = tool.execute_action("unknown_action")
        assert "Error" in result
        assert "Unknown action" in result

    def test_missing_url(self, tool):
        result = tool.execute_action("read_webpage")
        assert "Error" in result
        assert "URL parameter is missing" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_successful_fetch(self, mock_get, mock_validate, tool):
        mock_validate.return_value = "https://example.com"
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.text = "<html><body><h1>Title</h1><p>Content</p></body></html>"
        mock_get.return_value = mock_resp

        result = tool.execute_action("read_webpage", url="https://example.com")

        assert "Title" in result
        assert "Content" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_request_error(self, mock_get, mock_validate, tool):
        mock_validate.return_value = "https://example.com"
        mock_get.side_effect = requests.exceptions.ConnectionError("refused")

        result = tool.execute_action("read_webpage", url="https://example.com")

        assert "Error fetching URL" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    def test_ssrf_blocked(self, mock_validate, tool):
        from application.core.url_validation import SSRFError

        mock_validate.side_effect = SSRFError("blocked")

        result = tool.execute_action("read_webpage", url="http://169.254.169.254/")

        assert "Error" in result
        assert "validation failed" in result

    @patch("application.agents.tools.read_webpage.validate_url")
    @patch("application.agents.tools.read_webpage.requests.get")
    def test_http_error(self, mock_get, mock_validate, tool):
        mock_validate.return_value = "https://example.com/404"
        mock_resp = MagicMock()
        mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError("404")
        mock_get.return_value = mock_resp

        result = tool.execute_action("read_webpage", url="https://example.com/404")

        assert "Error fetching URL" in result


@pytest.mark.unit
class TestReadWebpageMetadata:
    def test_actions_metadata(self, tool):
        meta = tool.get_actions_metadata()
        assert len(meta) == 1
        assert meta[0]["name"] == "read_webpage"
        assert "url" in meta[0]["parameters"]["properties"]
        assert "url" in meta[0]["parameters"]["required"]

    def test_config_requirements(self, tool):
        assert tool.get_config_requirements() == {}

tests/agents/test_spec_parser.py (new file, 335 lines)
@@ -0,0 +1,335 @@
|
||||
"""Tests for application/agents/tools/spec_parser.py"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.spec_parser import (
|
||||
_extract_metadata,
|
||||
_generate_action_name,
|
||||
_get_base_url,
|
||||
_load_spec,
|
||||
_param_to_property,
|
||||
_resolve_ref,
|
||||
_validate_spec,
|
||||
parse_spec,
|
||||
)
|
||||
|
||||
|
||||
MINIMAL_OPENAPI = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Test API", "version": "1.0.0"},
|
||||
"servers": [{"url": "https://api.example.com"}],
|
||||
"paths": {
|
||||
"/users": {
|
||||
"get": {
|
||||
"operationId": "listUsers",
|
||||
"summary": "List all users",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"schema": {"type": "integer"},
|
||||
}
|
||||
],
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
MINIMAL_SWAGGER = json.dumps(
|
||||
{
|
||||
"swagger": "2.0",
|
||||
"info": {"title": "Swagger API", "version": "2.0.0"},
|
||||
"host": "api.example.com",
|
||||
"basePath": "/v1",
|
||||
"schemes": ["https"],
|
||||
"paths": {
|
||||
"/pets": {
|
||||
"get": {
|
||||
"operationId": "listPets",
|
||||
"summary": "List pets",
|
||||
"parameters": [],
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
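As a quick orientation (a sketch only, mirroring the assertions in TestParseSpec further down), feeding either fixture to parse_spec yields a metadata dict plus one action per path and method:
metadata, actions = parse_spec(MINIMAL_OPENAPI)
# metadata: {"title": "Test API", "version": "1.0.0", "base_url": "https://api.example.com", ...}
# actions[0]: name="listUsers", method="GET", url="https://api.example.com/users",
#             with "limit" exposed under query_params["properties"]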
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLoadSpec:
|
||||
def test_load_json(self):
|
||||
spec = _load_spec('{"openapi": "3.0.0"}')
|
||||
assert spec["openapi"] == "3.0.0"
|
||||
|
||||
def test_load_yaml(self):
|
||||
spec = _load_spec("openapi: '3.0.0'\ninfo:\n title: Test")
|
||||
assert spec["openapi"] == "3.0.0"
|
||||
|
||||
def test_empty_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty"):
|
||||
_load_spec("")
|
||||
|
||||
def test_whitespace_only_raises(self):
|
||||
with pytest.raises(ValueError, match="Empty"):
|
||||
_load_spec(" \n ")
|
||||
|
||||
def test_invalid_json_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid"):
|
||||
_load_spec("{invalid json}")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateSpec:
|
||||
def test_valid_openapi(self):
|
||||
spec = json.loads(MINIMAL_OPENAPI)
|
||||
_validate_spec(spec) # should not raise
|
||||
|
||||
def test_valid_swagger(self):
|
||||
spec = json.loads(MINIMAL_SWAGGER)
|
||||
_validate_spec(spec)
|
||||
|
||||
def test_missing_version_raises(self):
|
||||
with pytest.raises(ValueError, match="Unsupported"):
|
||||
_validate_spec({"paths": {"/a": {}}})
|
||||
|
||||
def test_no_paths_raises(self):
|
||||
with pytest.raises(ValueError, match="No API paths"):
|
||||
_validate_spec({"openapi": "3.0.0"})
|
||||
|
||||
def test_empty_paths_raises(self):
|
||||
with pytest.raises(ValueError, match="No API paths"):
|
||||
_validate_spec({"openapi": "3.0.0", "paths": {}})
|
||||
|
||||
def test_non_dict_raises(self):
|
||||
with pytest.raises(ValueError, match="valid object"):
|
||||
_validate_spec("not a dict")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractMetadata:
|
||||
def test_openapi_metadata(self):
|
||||
spec = json.loads(MINIMAL_OPENAPI)
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert meta["title"] == "Test API"
|
||||
assert meta["version"] == "1.0.0"
|
||||
assert meta["base_url"] == "https://api.example.com"
|
||||
|
||||
def test_swagger_metadata(self):
|
||||
spec = json.loads(MINIMAL_SWAGGER)
|
||||
meta = _extract_metadata(spec, is_swagger=True)
|
||||
assert meta["title"] == "Swagger API"
|
||||
assert meta["base_url"] == "https://api.example.com/v1"
|
||||
|
||||
def test_missing_info(self):
|
||||
spec = {"openapi": "3.0.0", "paths": {"/a": {}}}
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert meta["title"] == "Untitled API"
|
||||
|
||||
def test_description_truncated(self):
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "description": "x" * 1000},
|
||||
"paths": {"/a": {}},
|
||||
}
|
||||
meta = _extract_metadata(spec, is_swagger=False)
|
||||
assert len(meta["description"]) <= 500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetBaseUrl:
|
||||
def test_openapi_servers(self):
|
||||
spec = {"servers": [{"url": "https://api.example.com/v2/"}]}
|
||||
assert _get_base_url(spec, is_swagger=False) == "https://api.example.com/v2"
|
||||
|
||||
def test_openapi_no_servers(self):
|
||||
assert _get_base_url({}, is_swagger=False) == ""
|
||||
|
||||
def test_swagger_with_host(self):
|
||||
spec = {"host": "api.test.com", "basePath": "/v1", "schemes": ["https"]}
|
||||
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com/v1"
|
||||
|
||||
def test_swagger_no_host(self):
|
||||
assert _get_base_url({}, is_swagger=True) == ""
|
||||
|
||||
def test_swagger_default_scheme(self):
|
||||
spec = {"host": "api.test.com", "basePath": ""}
|
||||
assert _get_base_url(spec, is_swagger=True) == "https://api.test.com"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateActionName:
|
||||
def test_uses_operation_id(self):
|
||||
assert _generate_action_name({"operationId": "getUser"}, "get", "/users") == "getUser"
|
||||
|
||||
def test_fallback_to_method_path(self):
|
||||
name = _generate_action_name({}, "get", "/users/{id}")
|
||||
assert name.startswith("get_")
|
||||
assert "users" in name
|
||||
|
||||
def test_truncates_long_names(self):
|
||||
name = _generate_action_name({}, "get", "/" + "a" * 200)
|
||||
assert len(name) <= 64
|
||||
|
||||
def test_sanitizes_special_chars(self):
|
||||
name = _generate_action_name({"operationId": "get.user@v2"}, "get", "/")
|
||||
assert "." not in name
|
||||
assert "@" not in name
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParamToProperty:
|
||||
def test_string_param(self):
|
||||
param = {"name": "q", "in": "query", "schema": {"type": "string"}}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "string"
|
||||
assert prop["required"] is False
|
||||
|
||||
def test_integer_param(self):
|
||||
param = {
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": True,
|
||||
"schema": {"type": "integer"},
|
||||
}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
assert prop["required"] is True
|
||||
assert prop["filled_by_llm"] is True
|
||||
|
||||
def test_number_maps_to_integer(self):
|
||||
param = {"name": "score", "in": "query", "schema": {"type": "number"}}
|
||||
prop = _param_to_property(param)
|
||||
assert prop["type"] == "integer"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveRef:
|
||||
def test_no_ref(self):
|
||||
obj = {"type": "string"}
|
||||
assert _resolve_ref(obj, {}, {}) == obj
|
||||
|
||||
def test_components_ref(self):
|
||||
components = {"schemas": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}}
|
||||
obj = {"$ref": "#/components/schemas/User"}
|
||||
result = _resolve_ref(obj, components, {})
|
||||
assert result["type"] == "object"
|
||||
|
||||
def test_definitions_ref(self):
|
||||
definitions = {"Pet": {"type": "object"}}
|
||||
obj = {"$ref": "#/definitions/Pet"}
|
||||
result = _resolve_ref(obj, {}, definitions)
|
||||
assert result["type"] == "object"
|
||||
|
||||
def test_unsupported_ref(self):
|
||||
obj = {"$ref": "#/external/something"}
|
||||
assert _resolve_ref(obj, {}, {}) is None
|
||||
|
||||
def test_non_dict_returns_none(self):
|
||||
assert _resolve_ref("string", {}, {}) is None
|
||||
assert _resolve_ref(42, {}, {}) is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestParseSpec:
|
||||
def test_openapi_full_parse(self):
|
||||
metadata, actions = parse_spec(MINIMAL_OPENAPI)
|
||||
assert metadata["title"] == "Test API"
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "listUsers"
|
||||
assert actions[0]["method"] == "GET"
|
||||
assert actions[0]["url"] == "https://api.example.com/users"
|
||||
assert "limit" in actions[0]["query_params"]["properties"]
|
||||
|
||||
def test_swagger_full_parse(self):
|
||||
metadata, actions = parse_spec(MINIMAL_SWAGGER)
|
||||
assert metadata["title"] == "Swagger API"
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "listPets"
|
||||
assert actions[0]["method"] == "GET"
|
||||
|
||||
def test_multiple_methods(self):
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {"operationId": "listItems", "responses": {}},
|
||||
"post": {
|
||||
"operationId": "createItem",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"}
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
metadata, actions = parse_spec(spec)
|
||||
assert len(actions) == 2
|
||||
names = {a["name"] for a in actions}
|
||||
assert "listItems" in names
|
||||
assert "createItem" in names
|
||||
|
||||
create = next(a for a in actions if a["name"] == "createItem")
|
||||
assert "name" in create["body"]["properties"]
|
||||
|
||||
def test_header_params(self):
|
||||
spec = json.dumps(
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "T", "version": "1"},
|
||||
"paths": {
|
||||
"/data": {
|
||||
"get": {
|
||||
"operationId": "getData",
|
||||
"parameters": [
|
||||
{"name": "X-API-Key", "in": "header", "schema": {"type": "string"}}
|
||||
],
|
||||
"responses": {},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
_, actions = parse_spec(spec)
|
||||
assert "X-API-Key" in actions[0]["headers"]["properties"]
|
||||
|
||||
def test_invalid_spec_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_spec("")
|
||||
|
||||
def test_yaml_spec(self):
|
||||
yaml_spec = """
|
||||
openapi: "3.0.0"
|
||||
info:
|
||||
title: YAML API
|
||||
version: "1.0"
|
||||
paths:
|
||||
/health:
|
||||
get:
|
||||
operationId: healthCheck
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
"""
|
||||
metadata, actions = parse_spec(yaml_spec)
|
||||
assert metadata["title"] == "YAML API"
|
||||
assert actions[0]["name"] == "healthCheck"
|
||||
tests/agents/test_telegram_tool.py (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
"""Tests for application/agents/tools/telegram.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tools.telegram import TelegramTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return TelegramTool(config={"token": "bot123:ABC"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTelegramExecuteAction:
|
||||
def test_unknown_action_raises(self, tool):
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_send_message(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_message", text="Hello", chat_id="12345"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Message sent"
|
||||
call_args = mock_post.call_args
|
||||
assert "bot123:ABC/sendMessage" in call_args[0][0]
|
||||
assert call_args[1]["data"]["text"] == "Hello"
|
||||
assert call_args[1]["data"]["chat_id"] == "12345"
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_send_image(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_image", image_url="https://img.com/cat.jpg", chat_id="12345"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 200
|
||||
assert result["message"] == "Image sent"
|
||||
call_args = mock_post.call_args
|
||||
assert "bot123:ABC/sendPhoto" in call_args[0][0]
|
||||
assert call_args[1]["data"]["photo"] == "https://img.com/cat.jpg"
|
||||
|
||||
@patch("application.agents.tools.telegram.requests.post")
|
||||
def test_api_error_status(self, mock_post, tool):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 403
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
result = tool.execute_action(
|
||||
"telegram_send_message", text="Hi", chat_id="999"
|
||||
)
|
||||
|
||||
assert result["status_code"] == 403
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTelegramMetadata:
|
||||
def test_actions_metadata(self, tool):
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 2
|
||||
names = {a["name"] for a in meta}
|
||||
assert "telegram_send_message" in names
|
||||
assert "telegram_send_image" in names
|
||||
|
||||
def test_config_requirements(self, tool):
|
||||
reqs = tool.get_config_requirements()
|
||||
assert "token" in reqs
|
||||
assert reqs["token"]["secret"] is True
|
||||
tests/agents/test_workflow_agent_coverage.py (new file, 433 lines)
@@ -0,0 +1,433 @@
|
||||
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
|
||||
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
WorkflowGraph,
|
||||
)
|
||||
|
||||
|
||||
def _make_agent(**overrides):
|
||||
"""Create a WorkflowAgent with mocked base class dependencies."""
|
||||
defaults = {
|
||||
"endpoint": "https://api.example.com",
|
||||
"llm_name": "openai",
|
||||
"model_id": "gpt-4",
|
||||
"api_key": "test_key",
|
||||
"user_api_key": None,
|
||||
"prompt": "You are helpful.",
|
||||
"chat_history": [],
|
||||
"decoded_token": {"sub": "user1"},
|
||||
"attachments": [],
|
||||
"json_schema": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
|
||||
with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
agent = WorkflowAgent(**defaults)
|
||||
return agent
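A minimal usage sketch of this helper, assuming the same embedded-workflow shape exercised in TestParseEmbeddedWorkflow below:
# Sketch only; mirrors TestParseEmbeddedWorkflow. An embedded workflow dict is
# passed straight through to the agent and parsed into a typed graph object.
wf_data = {
    "name": "Demo",
    "nodes": [
        {"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
        {"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
    ],
    "edges": [
        {"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
    ],
}
agent = _make_agent(workflow=wf_data, workflow_id="wf1")
graph = agent._parse_embedded_workflow()
# graph.nodes / graph.edges hold the parsed models; graph.workflow.name == "Demo"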
|
||||
|
||||
|
||||
class TestWorkflowAgentInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_sets_attributes(self):
|
||||
agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
|
||||
assert agent.workflow_id == "wf1"
|
||||
assert agent.workflow_owner == "owner1"
|
||||
assert agent._engine is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_embedded_workflow(self):
|
||||
wf_data = {"nodes": [], "edges": [], "name": "Test"}
|
||||
agent = _make_agent(workflow=wf_data)
|
||||
assert agent._workflow_data == wf_data
|
||||
|
||||
|
||||
class TestParseEmbeddedWorkflow:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parses_valid_workflow(self):
|
||||
wf_data = {
|
||||
"name": "Test Workflow",
|
||||
"description": "A test",
|
||||
"nodes": [
|
||||
{"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
|
||||
{"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
|
||||
],
|
||||
}
|
||||
agent = _make_agent(workflow=wf_data, workflow_id="wf1")
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is not None
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
assert graph.workflow.name == "Test Workflow"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_edge_source_id_alias(self):
|
||||
wf_data = {
|
||||
"nodes": [{"id": "n1", "type": "start", "data": {}}],
|
||||
"edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
|
||||
}
|
||||
agent = _make_agent(workflow=wf_data)
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is not None
|
||||
assert graph.edges[0].source_id == "n1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_data_returns_none(self):
|
||||
agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
|
||||
graph = agent._parse_embedded_workflow()
|
||||
assert graph is None
|
||||
|
||||
|
||||
class TestLoadWorkflowGraph:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_embedded_when_available(self):
|
||||
agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
|
||||
agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
|
||||
result = agent._load_workflow_graph()
|
||||
assert result == "parsed_graph"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_database_when_workflow_id(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
agent._load_from_database = MagicMock(return_value="db_graph")
|
||||
result = agent._load_workflow_graph()
|
||||
assert result == "db_graph"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_none_when_nothing(self):
|
||||
agent = _make_agent()
|
||||
result = agent._load_workflow_graph()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadFromDatabase:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_workflow_id_returns_none(self):
|
||||
agent = _make_agent(workflow_id="invalid!")
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_owner_returns_none(self):
|
||||
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
|
||||
agent.workflow_owner = None
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_uses_decoded_token_sub(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
agent.workflow_owner = None
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = None
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is None # workflow_doc not found
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_load(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "Test WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.return_value = [
|
||||
{"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
|
||||
"title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
]
|
||||
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.return_value = []
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
|
||||
assert result is not None
|
||||
assert len(result.nodes) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_graph_version(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": "bad",
|
||||
}
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.return_value = []
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.return_value = []
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is not None # Defaults to version 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_fallback_nodes_without_version(self):
|
||||
"""When graph_version=1 finds no nodes, falls back to nodes without version field."""
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
|
||||
mock_wf_coll = MagicMock()
|
||||
mock_wf_coll.find_one.return_value = {
|
||||
"_id": "507f1f77bcf86cd799439011",
|
||||
"name": "WF",
|
||||
"user": "owner1",
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
call_count = [0]
|
||||
def nodes_find(query):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return [] # No versioned nodes
|
||||
return [{"id": "n1", "workflow_id": "wf", "type": "start",
|
||||
"title": "S", "position": {"x": 0, "y": 0}, "config": {}}]
|
||||
|
||||
mock_nodes_coll = MagicMock()
|
||||
mock_nodes_coll.find.side_effect = nodes_find
|
||||
|
||||
edge_call = [0]
|
||||
def edges_find(query):
|
||||
edge_call[0] += 1
|
||||
if edge_call[0] == 1:
|
||||
return []
|
||||
return []
|
||||
|
||||
mock_edges_coll = MagicMock()
|
||||
mock_edges_coll.find.side_effect = edges_find
|
||||
|
||||
def getitem(name):
|
||||
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(side_effect=getitem)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
result = agent._load_from_database()
|
||||
assert result is not None
|
||||
assert len(result.nodes) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_returns_none(self):
|
||||
agent = _make_agent(
|
||||
workflow_id="507f1f77bcf86cd799439011",
|
||||
workflow_owner="owner1",
|
||||
)
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
|
||||
MockMongo.get_client.side_effect = Exception("db error")
|
||||
result = agent._load_from_database()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenInner:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_graph_yields_error(self):
|
||||
agent = _make_agent()
|
||||
agent._load_workflow_graph = MagicMock(return_value=None)
|
||||
events = list(agent._gen_inner("query", None))
|
||||
assert any(e.get("type") == "error" for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_execution(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_graph = MagicMock(spec=WorkflowGraph)
|
||||
agent._load_workflow_graph = MagicMock(return_value=mock_graph)
|
||||
agent._save_workflow_run = MagicMock()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute.return_value = iter([{"answer": "result"}])
|
||||
|
||||
with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine):
|
||||
events = list(agent._gen_inner("query", None))
|
||||
assert len(events) == 1
|
||||
agent._save_workflow_run.assert_called_once_with("query")
|
||||
|
||||
|
||||
class TestSaveWorkflowRun:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_engine_returns_early(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = None
|
||||
agent._save_workflow_run("query") # Should not raise
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_saves_to_mongo(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.state = {"query": "test"}
|
||||
mock_engine.execution_log = []
|
||||
mock_engine.get_execution_summary.return_value = []
|
||||
agent._engine = mock_engine
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
agent._save_workflow_run("query")
|
||||
|
||||
mock_collection.insert_one.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_does_not_propagate(self):
|
||||
agent = _make_agent(workflow_id="wf1")
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.state = {}
|
||||
mock_engine.execution_log = []
|
||||
mock_engine.get_execution_summary.return_value = []
|
||||
agent._engine = mock_engine
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
|
||||
MockMongo.get_client.side_effect = Exception("db fail")
|
||||
agent._save_workflow_run("query") # Should not raise
|
||||
|
||||
|
||||
class TestDetermineRunStatus:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_engine_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = None
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_log_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = []
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_failed_log_returns_failed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = [
|
||||
{"status": "completed"},
|
||||
{"status": "failed"},
|
||||
]
|
||||
assert agent._determine_run_status() == ExecutionStatus.FAILED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_completed_returns_completed(self):
|
||||
agent = _make_agent()
|
||||
agent._engine = MagicMock()
|
||||
agent._engine.execution_log = [
|
||||
{"status": "completed"},
|
||||
{"status": "completed"},
|
||||
]
|
||||
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
|
||||
class TestSerializeState:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_primitives(self):
|
||||
agent = _make_agent()
|
||||
state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
|
||||
result = agent._serialize_state(state)
|
||||
assert result == state
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_nested_dict(self):
|
||||
agent = _make_agent()
|
||||
state = {"nested": {"key": "value"}}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["nested"]["key"] == "value"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_list(self):
|
||||
agent = _make_agent()
|
||||
state = {"items": [1, 2, "three"]}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["items"] == [1, 2, "three"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_tuple(self):
|
||||
agent = _make_agent()
|
||||
state = {"tup": (1, 2)}
|
||||
result = agent._serialize_state(state)
|
||||
assert result["tup"] == [1, 2]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_datetime(self):
|
||||
agent = _make_agent()
|
||||
dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
state = {"time": dt}
|
||||
result = agent._serialize_state(state)
|
||||
assert "2025-01-01" in result["time"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serializes_unknown_to_str(self):
|
||||
agent = _make_agent()
|
||||
state = {"obj": object()}
|
||||
result = agent._serialize_state(state)
|
||||
assert isinstance(result["obj"], str)
|
||||
tests/agents/test_workflow_engine_coverage.py (new file, 573 lines)
@@ -0,0 +1,573 @@
|
||||
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
|
||||
template context, source data, structured output parsing, get_execution_summary."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
NodeType,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
Workflow,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
def _make_graph(nodes, edges):
|
||||
wf = Workflow(name="Test", description="test workflow")
|
||||
return WorkflowGraph(workflow=wf, nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
def _make_node(id, type, title="Node", config=None, position=None):
|
||||
return WorkflowNode(
|
||||
id=id,
|
||||
workflow_id="wf1",
|
||||
type=type,
|
||||
title=title,
|
||||
position=position or {"x": 0, "y": 0},
|
||||
config=config or {},
|
||||
)
|
||||
|
||||
|
||||
def _make_edge(id, source, target, source_handle=None, target_handle=None):
|
||||
return WorkflowEdge(
|
||||
id=id,
|
||||
workflow_id="wf1",
|
||||
source=source,
|
||||
target=target,
|
||||
sourceHandle=source_handle,
|
||||
targetHandle=target_handle,
|
||||
)
|
||||
|
||||
|
||||
def _make_agent():
|
||||
agent = MagicMock()
|
||||
agent.chat_history = []
|
||||
agent.endpoint = "https://api.example.com"
|
||||
agent.llm_name = "openai"
|
||||
agent.model_id = "gpt-4"
|
||||
agent.api_key = "key"
|
||||
agent.decoded_token = {"sub": "user1"}
|
||||
agent.retrieved_docs = None
|
||||
return agent
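A minimal sketch of how these helpers compose, assuming the streaming WorkflowEngine interface exercised throughout the tests below:
# Sketch only; uses the helper functions above and the imports at the top of this file.
nodes = [
    _make_node("n1", NodeType.START, "Start"),
    _make_node("n2", NodeType.END, "End", config={"config": {"output_template": "Done: {{ query }}"}}),
]
edges = [_make_edge("e1", "n1", "n2")]
engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())

# execute() is a generator: it emits "workflow_step" events per node and, for an
# end node carrying an output_template, a final {"answer": ...} event.
events = list(engine.execute({}, "hello"))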
|
||||
|
||||
|
||||
class TestExecuteLoop:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_start_node_yields_error(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "query"))
|
||||
assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_start_to_end(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START, "Start"),
|
||||
_make_node("n2", NodeType.END, "End", config={"config": {}}),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "hello"))
|
||||
step_events = [e for e in events if e.get("type") == "workflow_step"]
|
||||
assert len(step_events) >= 2 # At least start + end
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_node_not_found_yields_error(self):
|
||||
nodes = [_make_node("n1", NodeType.START)]
|
||||
edges = [_make_edge("e1", "n1", "nonexistent")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
assert any("not found" in e.get("error", "") for e in events)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_node_execution_error_yields_error(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
failed_events = [e for e in events if e.get("status") == "failed"]
|
||||
assert len(failed_events) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_max_steps_limit(self):
|
||||
# Create a cycle: start -> note -> note (the note node loops back to itself)
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.NOTE, "Note"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.MAX_EXECUTION_STEPS = 5
|
||||
events = list(engine.execute({}, "q"))
|
||||
# Should not run forever
|
||||
assert len(events) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_branch_ends_without_end_node(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.START),
|
||||
_make_node("n2", NodeType.NOTE, "Note"),
|
||||
]
|
||||
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing edges
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine.execute({}, "q"))
|
||||
assert len(events) > 0
|
||||
|
||||
|
||||
class TestInitializeState:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_sets_query_and_history(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.chat_history = [{"prompt": "hi", "response": "hey"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine._initialize_state({"custom": "value"}, "test query")
|
||||
assert engine.state["query"] == "test query"
|
||||
assert "custom" in engine.state
|
||||
assert engine.state["chat_history"] is not None
|
||||
|
||||
|
||||
class TestGetNextNodeId:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_edges_returns_none(self):
|
||||
nodes = [_make_node("n1", NodeType.START)]
|
||||
graph = _make_graph(nodes, [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._get_next_node_id("n1") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_first_edge_target(self):
|
||||
nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
|
||||
edges = [_make_edge("e1", "n1", "n2")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._get_next_node_id("n1") == "n2"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_condition_uses_matched_handle(self):
|
||||
nodes = [
|
||||
_make_node("n1", NodeType.CONDITION),
|
||||
_make_node("n2", NodeType.END, "Yes End"),
|
||||
_make_node("n3", NodeType.END, "No End"),
|
||||
]
|
||||
edges = [
|
||||
_make_edge("e1", "n1", "n2", source_handle="yes"),
|
||||
_make_edge("e2", "n1", "n3", source_handle="no"),
|
||||
]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._condition_result = "no"
|
||||
assert engine._get_next_node_id("n1") == "n3"
|
||||
assert engine._condition_result is None # Cleared after use
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_condition_no_matching_handle_returns_none(self):
|
||||
nodes = [_make_node("n1", NodeType.CONDITION)]
|
||||
edges = [_make_edge("e1", "n1", "n2", source_handle="yes")]
|
||||
graph = _make_graph(nodes, edges)
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._condition_result = "nonexistent"
|
||||
assert engine._get_next_node_id("n1") is None
|
||||
|
||||
|
||||
class TestExecuteStateNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_evaluates_operations(self):
|
||||
node = _make_node("n1", NodeType.STATE, config={
|
||||
"config": {
|
||||
"operations": [
|
||||
{"expression": "x + 1", "target_variable": "result"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 5}
|
||||
list(engine._execute_state_node(node))
|
||||
assert engine.state["result"] == 6
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_empty_expression(self):
|
||||
node = _make_node("n1", NodeType.STATE, config={
|
||||
"config": {
|
||||
"operations": [
|
||||
{"expression": "", "target_variable": "result"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_state_node(node))
|
||||
assert "result" not in engine.state
|
||||
|
||||
|
||||
class TestExecuteConditionNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_matches_first_true_case(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "x > 10", "source_handle": "high"},
|
||||
{"expression": "x > 5", "source_handle": "medium"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 7}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "medium"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_falls_through_to_else(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "x > 100", "source_handle": "high"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"x": 1}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "else"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_empty_expression(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": " ", "source_handle": "a"},
|
||||
{"expression": "true", "source_handle": "b"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "b"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cel_error_continues(self):
|
||||
node = _make_node("n1", NodeType.CONDITION, config={
|
||||
"config": {
|
||||
"cases": [
|
||||
{"expression": "bad!!!", "source_handle": "a"},
|
||||
{"expression": "true", "source_handle": "b"},
|
||||
]
|
||||
}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {}
|
||||
list(engine._execute_condition_node(node))
|
||||
assert engine._condition_result == "b"
|
||||
|
||||
|
||||
class TestExecuteEndNode:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_output_template(self):
|
||||
node = _make_node("n1", NodeType.END, config={
|
||||
"config": {"output_template": "Result: {{ query }}"}
|
||||
})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "hello"}
|
||||
engine._format_template = MagicMock(return_value="Result: hello")
|
||||
events = list(engine._execute_end_node(node))
|
||||
assert len(events) == 1
|
||||
assert events[0]["answer"] == "Result: hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_without_output_template(self):
|
||||
node = _make_node("n1", NodeType.END, config={"config": {}})
|
||||
graph = _make_graph([node], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
events = list(engine._execute_end_node(node))
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
class TestParseStructuredOutput:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_json(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output('{"key": "value"}')
|
||||
assert success is True
|
||||
assert data == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_json(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output("not json")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_string(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output("")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_whitespace_only(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
success, data = engine._parse_structured_output(" ")
|
||||
assert success is False
|
||||
assert data is None
|
||||
|
||||
|
||||
class TestNormalizeNodeJsonSchema:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine._normalize_node_json_schema(None, "Node") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_schema(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
result = engine._normalize_node_json_schema(schema, "Node")
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_schema_raises(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm:
|
||||
from application.core.json_schema_utils import JsonSchemaValidationError
|
||||
mock_norm.side_effect = JsonSchemaValidationError("bad schema")
|
||||
with pytest.raises(ValueError, match="Invalid JSON schema"):
|
||||
engine._normalize_node_json_schema({"bad": True}, "TestNode")
|
||||
|
||||
|
||||
class TestValidateStructuredOutput:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_output_passes(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
engine._validate_structured_output(schema, {"name": "Alice"}) # Should not raise
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_output_raises(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
|
||||
with pytest.raises(ValueError, match="did not match schema"):
|
||||
engine._validate_structured_output(schema, {})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_jsonschema_module(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
with patch("application.agents.workflows.workflow_engine.jsonschema", None):
|
||||
engine._validate_structured_output({"type": "object"}, {}) # Should not raise
|
||||
|
||||
|
||||
class TestFormatTemplate:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_renders_template(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine.state = {"query": "hello"}
|
||||
engine._build_template_context = MagicMock(return_value={"query": "hello"})
|
||||
engine._template_engine = MagicMock()
|
||||
engine._template_engine.render.return_value = "hello world"
|
||||
result = engine._format_template("{{ query }} world")
|
||||
assert result == "hello world"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_render_error_returns_raw(self):
|
||||
from application.templates.template_engine import TemplateRenderError
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
engine._build_template_context = MagicMock(return_value={})
|
||||
engine._template_engine = MagicMock()
|
||||
engine._template_engine.render.side_effect = TemplateRenderError("fail")
|
||||
result = engine._format_template("{{ bad }}")
|
||||
assert result == "{{ bad }}"
|
||||
|
||||
|
||||
class TestBuildTemplateContext:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_includes_state_variables(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"query": "hello", "custom_var": "value"}
|
||||
context = engine._build_template_context()
|
||||
assert context["agent"]["query"] == "hello"
|
||||
assert "custom_var" in context
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_reserved_namespace_gets_prefixed(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"source": "my_source_val"}
|
||||
context = engine._build_template_context()
|
||||
assert context.get("agent_source") == "my_source_val"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_passthrough_data(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"passthrough": {"key": "val"}}
|
||||
context = engine._build_template_context()
|
||||
assert "passthrough" in context or "agent_passthrough" in context
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tools_data(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
engine.state = {"tools": {"tool1": "result"}}
|
||||
context = engine._build_template_context()
|
||||
assert "agent" in context
|
||||
|
||||
|
||||
class TestGetSourceTemplateData:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_docs_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = None
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is None
|
||||
assert together is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_docs_returns_none(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = []
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docs_with_filename(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert docs is not None
|
||||
assert "doc.txt" in together
|
||||
assert "content" in together
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docs_without_filename(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content only"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together == "content only"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_non_dict_docs(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = ["not a dict", {"text": "ok"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together == "ok"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_non_string_text(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": 123}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert together is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_doc_with_title_fallback(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "title": "doc_title"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert "doc_title" in together
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_doc_with_source_fallback(self):
|
||||
graph = _make_graph([], [])
|
||||
agent = _make_agent()
|
||||
agent.retrieved_docs = [{"text": "content", "source": "src"}]
|
||||
engine = WorkflowEngine(graph, agent)
|
||||
docs, together = engine._get_source_template_data()
|
||||
assert "src" in together
|
||||
|
||||
|
||||
class TestGetExecutionSummary:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_log_entries(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
now = datetime.now(timezone.utc)
|
||||
engine.execution_log = [
|
||||
{
|
||||
"node_id": "n1",
|
||||
"node_type": "start",
|
||||
"status": "completed",
|
||||
"started_at": now,
|
||||
"completed_at": now,
|
||||
"error": None,
|
||||
"state_snapshot": {},
|
||||
}
|
||||
]
|
||||
summary = engine.get_execution_summary()
|
||||
assert len(summary) == 1
|
||||
assert summary[0].node_id == "n1"
|
||||
assert summary[0].status == ExecutionStatus.COMPLETED
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_log(self):
|
||||
graph = _make_graph([], [])
|
||||
engine = WorkflowEngine(graph, _make_agent())
|
||||
assert engine.get_execution_summary() == []
|
||||
tests/api/answer/routes/test_answer.py (new file, 235 lines)
@@ -0,0 +1,235 @@
|
||||
"""Tests for application/api/answer/routes/answer.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_processor():
|
||||
"""Create a mock StreamProcessor."""
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.StreamProcessor"
|
||||
) as MockProcessor:
|
||||
processor = MagicMock()
|
||||
processor.decoded_token = {"sub": "test_user"}
|
||||
processor.conversation_id = str(ObjectId())
|
||||
processor.agent_config = {}
|
||||
processor.agent_id = str(ObjectId())
|
||||
processor.is_shared_usage = False
|
||||
processor.shared_token = None
|
||||
processor.model_id = "gpt-4"
|
||||
processor.build_agent.return_value = MagicMock()
|
||||
MockProcessor.return_value = processor
|
||||
yield processor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def answer_client(mock_mongo_db, flask_app):
|
||||
"""Create a test client with the answer route registered."""
|
||||
from flask_restx import Api
|
||||
|
||||
from application.api.answer.routes.answer import answer_ns
|
||||
|
||||
api = Api(flask_app)
|
||||
api.add_namespace(answer_ns)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app.test_client()
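Every test below issues the same request shape against this client; a minimal sketch, assuming the fixtures above are in scope:
resp = answer_client.post(
    "/api/answer",
    data=json.dumps({"question": "What is Python?"}),
    content_type="application/json",
)
payload = resp.get_json()  # on success carries "answer" and "conversation_id"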
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAnswerResourcePost:
|
||||
def test_missing_question_returns_400(self, answer_client, mock_stream_processor):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_successful_answer(self, answer_client, mock_stream_processor):
|
||||
conv_id = str(ObjectId())
|
||||
with patch.object(
|
||||
mock_stream_processor.build_agent.return_value,
|
||||
"gen",
|
||||
return_value=iter([]),
|
||||
):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter(
|
||||
[
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello"})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(conv_id, "Hello", [], [], "", None),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "What is Python?"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["answer"] == "Hello"
|
||||
assert data["conversation_id"] == conv_id
|
||||
|
||||
def test_unauthorized_returns_401(self, answer_client, mock_stream_processor):
|
||||
mock_stream_processor.decoded_token = None
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert resp.get_json()["error"] == "Unauthorized"
|
||||
|
||||
def test_usage_exceeded_returns_error(self, answer_client, mock_stream_processor):
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
) as mock_check:
|
||||
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
|
||||
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_stream_error_returns_400(self, answer_client, mock_stream_processor):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(None, None, None, None, None, "Stream error"),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.get_json()["error"] == "Stream error"
|
||||
|
||||
def test_exception_returns_500(self, answer_client, mock_stream_processor):
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
side_effect=RuntimeError("unexpected"),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
assert "error" in resp.get_json()
|
||||
|
||||
def test_structured_info_merged_into_result(
|
||||
self, answer_client, mock_stream_processor
|
||||
):
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
'{"key": "val"}',
|
||||
[],
|
||||
[],
|
||||
"",
|
||||
None,
|
||||
{"structured": True, "schema": {"type": "object"}},
|
||||
),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["structured"] is True
|
||||
assert data["schema"] == {"type": "object"}
|
||||
|
||||
def test_result_contains_all_expected_fields(
|
||||
self, answer_client, mock_stream_processor
|
||||
):
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.complete_stream",
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
"answer text",
|
||||
[{"title": "src"}],
|
||||
[{"tool": "t"}],
|
||||
"thinking...",
|
||||
None,
|
||||
),
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
data = resp.get_json()
|
||||
assert data["conversation_id"] == conv_id
|
||||
assert data["answer"] == "answer text"
|
||||
assert data["sources"] == [{"title": "src"}]
|
||||
assert data["tool_calls"] == [{"tool": "t"}]
|
||||
assert data["thought"] == "thinking..."
|
||||
|
||||
|
||||
def flask_app_context(client):
|
||||
"""Helper to get app context from test client."""
|
||||
return client.application.app_context()
|
||||
195
tests/api/answer/routes/test_stream.py
Normal file
195
tests/api/answer/routes/test_stream.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for application/api/answer/routes/stream.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_processor():
|
||||
"""Create a mock StreamProcessor for stream tests."""
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamProcessor"
|
||||
) as MockProcessor:
|
||||
processor = MagicMock()
|
||||
processor.decoded_token = {"sub": "test_user"}
|
||||
processor.conversation_id = str(ObjectId())
|
||||
processor.agent_config = {}
|
||||
processor.agent_id = str(ObjectId())
|
||||
processor.is_shared_usage = False
|
||||
processor.shared_token = None
|
||||
processor.model_id = "gpt-4"
|
||||
processor.build_agent.return_value = MagicMock()
|
||||
MockProcessor.return_value = processor
|
||||
yield processor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_client(mock_mongo_db, flask_app):
|
||||
"""Create a test client with the stream route registered."""
|
||||
from flask_restx import Api
|
||||
|
||||
from application.api.answer.routes.stream import answer_ns
|
||||
|
||||
api = Api(flask_app)
|
||||
api.add_namespace(answer_ns)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app.test_client()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamResourcePost:
|
||||
def test_missing_question_returns_400(self, stream_client, mock_stream_processor):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_successful_stream(self, stream_client, mock_stream_processor):
|
||||
def fake_stream(*args, **kwargs):
|
||||
yield f'data: {json.dumps({"type": "answer", "answer": "Hi"})}\n\n'
|
||||
yield f'data: {json.dumps({"type": "end"})}\n\n'
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.complete_stream",
|
||||
side_effect=fake_stream,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "What is Python?"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert '"type": "answer"' in data
|
||||
assert '"answer": "Hi"' in data
|
||||
|
||||
def test_unauthorized_returns_401_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.decoded_token = None
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Unauthorized" in data
|
||||
|
||||
def test_usage_exceeded_returns_error(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
) as mock_check:
|
||||
mock_check.return_value = ({"error": "Usage limit exceeded"}, 429)
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_value_error_returns_400_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.build_agent.side_effect = ValueError("bad data")
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Malformed request body" in data
|
||||
|
||||
def test_general_exception_returns_400_stream(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
mock_stream_processor.build_agent.side_effect = RuntimeError("crash")
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
):
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "text/event-stream" in resp.content_type
|
||||
data = resp.get_data(as_text=True)
|
||||
assert "Unknown error occurred" in data
|
||||
|
||||
def test_index_in_data_requires_conversation_id(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
"""When 'index' is present, validate_request is called with require_conversation_id=True."""
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps({"question": "test", "index": 0}),
|
||||
content_type="application/json",
|
||||
)
|
||||
# Should get 400 since conversation_id is missing
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_stream_passes_attachments_and_index(
|
||||
self, stream_client, mock_stream_processor
|
||||
):
|
||||
"""Verify attachments and index params are forwarded to complete_stream."""
|
||||
|
||||
def fake_stream(*args, **kwargs):
|
||||
yield f'data: {json.dumps({"type": "end"})}\n\n'
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
with patch(
|
||||
"application.api.answer.routes.stream.StreamResource.validate_request",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.check_usage",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"application.api.answer.routes.stream.StreamResource.complete_stream",
|
||||
side_effect=fake_stream,
|
||||
) as mock_complete:
|
||||
resp = stream_client.post(
|
||||
"/stream",
|
||||
data=json.dumps(
|
||||
{
|
||||
"question": "test",
|
||||
"conversation_id": conv_id,
|
||||
"index": 3,
|
||||
"attachments": ["att1", "att2"],
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
call_kwargs = mock_complete.call_args
|
||||
assert call_kwargs.kwargs.get("index") == 3
|
||||
assert call_kwargs.kwargs.get("attachment_ids") == ["att1", "att2"]
|
||||
0
tests/api/answer/services/compression/__init__.py
Normal file
0
tests/api/answer/services/compression/__init__.py
Normal file
303
tests/api/answer/services/compression/test_message_builder.py
Normal file
303
tests/api/answer/services/compression/test_message_builder.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Tests for application/api/answer/services/compression/message_builder.py"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.message_builder import MessageBuilder
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildFromCompressedContext:
|
||||
def test_no_compression_returns_system_only(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="You are helpful.",
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "system"
|
||||
assert messages[0]["content"] == "You are helpful."
|
||||
|
||||
def test_with_recent_queries_no_compression(self):
|
||||
queries = [
|
||||
{"prompt": "Hello", "response": "Hi there!"},
|
||||
{"prompt": "How are you?", "response": "I'm fine."},
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System prompt",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
)
|
||||
# system + 2 * (user + assistant) = 5
|
||||
assert len(messages) == 5
|
||||
assert messages[1] == {"role": "user", "content": "Hello"}
|
||||
assert messages[2] == {"role": "assistant", "content": "Hi there!"}
|
||||
assert messages[3] == {"role": "user", "content": "How are you?"}
|
||||
assert messages[4] == {"role": "assistant", "content": "I'm fine."}
|
||||
|
||||
def test_with_compressed_summary_appended_to_system(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="You are helpful.",
|
||||
compressed_summary="Previous: user asked about Python.",
|
||||
recent_queries=[{"prompt": "More?", "response": "Sure."}],
|
||||
)
|
||||
system_content = messages[0]["content"]
|
||||
assert "This session is being continued" in system_content
|
||||
assert "Previous: user asked about Python." in system_content
|
||||
|
||||
def test_mid_execution_context_type(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary="Summary here",
|
||||
recent_queries=[{"prompt": "q", "response": "r"}],
|
||||
context_type="mid_execution",
|
||||
)
|
||||
system_content = messages[0]["content"]
|
||||
assert "Context window limit reached" in system_content
|
||||
|
||||
def test_include_tool_calls(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "Search for X",
|
||||
"response": "Found X",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "call-1",
|
||||
"action_name": "search",
|
||||
"arguments": {"q": "X"},
|
||||
"result": "X found",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
# system + user + assistant + tool_call_assistant + tool_response = 5
|
||||
assert len(messages) == 5
|
||||
assert messages[3]["role"] == "assistant"
|
||||
assert "function_call" in messages[3]["content"][0]
|
||||
assert messages[4]["role"] == "tool"
|
||||
assert "function_response" in messages[4]["content"][0]
|
||||
|
||||
def test_tool_calls_not_included_by_default(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "Search",
|
||||
"response": "Found",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"action_name": "search",
|
||||
"arguments": {},
|
||||
"result": "ok",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
# system + user + assistant = 3 (no tool messages)
|
||||
assert len(messages) == 3
|
||||
|
||||
def test_tool_call_without_call_id_generates_uuid(self):
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"action_name": "act",
|
||||
"arguments": {},
|
||||
"result": "res",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="S",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
tool_msg = messages[3]["content"][0]
|
||||
call_id = tool_msg["function_call"]["call_id"]
|
||||
assert call_id is not None
|
||||
assert len(call_id) > 0
|
||||
|
||||
def test_continuation_message_when_no_recent_queries_but_has_summary(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary="Everything was compressed",
|
||||
recent_queries=[],
|
||||
)
|
||||
# system + continuation user message = 2
|
||||
assert len(messages) == 2
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "continue" in messages[1]["content"].lower()
|
||||
|
||||
def test_no_continuation_when_no_summary(self):
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="System",
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(messages) == 1
|
||||
|
||||
def test_queries_without_prompt_or_response_skipped(self):
|
||||
queries = [
|
||||
{"other_field": "value"},
|
||||
{"prompt": "real", "response": "answer"},
|
||||
]
|
||||
messages = MessageBuilder.build_from_compressed_context(
|
||||
system_prompt="S",
|
||||
compressed_summary=None,
|
||||
recent_queries=queries,
|
||||
)
|
||||
# system + 1 valid query (user + assistant) = 3
|
||||
assert len(messages) == 3
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAppendCompressionContext:
|
||||
def test_pre_request_context(self):
|
||||
result = MessageBuilder._append_compression_context(
|
||||
"Original prompt", "Summary text", "pre_request"
|
||||
)
|
||||
assert "This session is being continued" in result
|
||||
assert "Summary text" in result
|
||||
assert result.startswith("Original prompt")
|
||||
|
||||
def test_mid_execution_context(self):
|
||||
result = MessageBuilder._append_compression_context(
|
||||
"Original prompt", "Summary text", "mid_execution"
|
||||
)
|
||||
assert "Context window limit reached" in result
|
||||
assert "Summary text" in result
|
||||
|
||||
def test_removes_existing_compression_context(self):
|
||||
prompt_with_existing = (
|
||||
"Original prompt\n\n---\n\nThis session is being continued from old"
|
||||
)
|
||||
result = MessageBuilder._append_compression_context(
|
||||
prompt_with_existing, "New summary", "pre_request"
|
||||
)
|
||||
# Should not contain old context twice
|
||||
assert result.count("This session is being continued") == 1
|
||||
assert "New summary" in result
|
||||
|
||||
def test_removes_mid_execution_context(self):
|
||||
prompt_with_existing = (
|
||||
"Original\n\n---\n\nContext window limit reached during execution. Old."
|
||||
)
|
||||
result = MessageBuilder._append_compression_context(
|
||||
prompt_with_existing, "New", "mid_execution"
|
||||
)
|
||||
assert result.count("Context window limit reached") == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRebuildMessagesAfterCompression:
|
||||
def test_basic_rebuild(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "old message"},
|
||||
{"role": "assistant", "content": "old reply"},
|
||||
]
|
||||
recent = [{"prompt": "new q", "response": "new r"}]
|
||||
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="Everything was compressed.",
|
||||
recent_queries=recent,
|
||||
)
|
||||
assert result is not None
|
||||
# system + user + assistant = 3
|
||||
assert len(result) == 3
|
||||
assert "Context window limit reached" in result[0]["content"]
|
||||
assert result[1] == {"role": "user", "content": "new q"}
|
||||
assert result[2] == {"role": "assistant", "content": "new r"}
|
||||
|
||||
def test_returns_none_without_system_message(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="summary",
|
||||
recent_queries=[],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_no_summary_keeps_system_unchanged(self):
|
||||
messages = [{"role": "system", "content": "Be helpful."}]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary=None,
|
||||
recent_queries=[],
|
||||
)
|
||||
assert result is not None
|
||||
assert result[0]["content"] == "Be helpful."
|
||||
|
||||
def test_include_tool_calls_in_rebuild(self):
|
||||
messages = [{"role": "system", "content": "S"}]
|
||||
recent = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"action_name": "act",
|
||||
"arguments": {"a": 1},
|
||||
"result": "done",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="s",
|
||||
recent_queries=recent,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
# system + user + assistant + tool_call + tool_response = 5
|
||||
assert len(result) == 5
|
||||
|
||||
def test_continuation_added_when_no_recent_queries(self):
|
||||
messages = [{"role": "system", "content": "S"}]
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="All compressed",
|
||||
recent_queries=[],
|
||||
)
|
||||
assert len(result) == 2
|
||||
assert result[1]["role"] == "user"
|
||||
assert "continue" in result[1]["content"].lower()
|
||||
|
||||
def test_include_current_execution_preserves_extra_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "r1"},
|
||||
{"role": "user", "content": "current execution msg"},
|
||||
]
|
||||
recent = [{"prompt": "q1", "response": "r1"}]
|
||||
|
||||
result = MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary="summary",
|
||||
recent_queries=recent,
|
||||
include_current_execution=True,
|
||||
)
|
||||
assert result is not None
|
||||
# Should include the current execution message
|
||||
contents = [m.get("content") for m in result]
|
||||
assert "current execution msg" in contents
|
||||
447
tests/api/answer/services/compression/test_orchestrator.py
Normal file
447
tests/api/answer/services/compression/test_orchestrator.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Tests for application/api/answer/services/compression/orchestrator.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
CompressionResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_service():
|
||||
svc = MagicMock()
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_threshold_checker():
|
||||
checker = MagicMock()
|
||||
return checker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(mock_conversation_service, mock_threshold_checker):
|
||||
return CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"queries": [
|
||||
{"prompt": "q0", "response": "r0"},
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {},
|
||||
"agent_id": "agent-1",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def decoded_token():
|
||||
return {"sub": "user123"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressIfNeeded:
|
||||
def test_conversation_not_found_returns_failure(
|
||||
self, orchestrator, mock_conversation_service
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = None
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_no_compression_needed(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = False
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is False
|
||||
assert len(result.recent_queries) == 3
|
||||
|
||||
def test_compression_performed_successfully(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = True
|
||||
|
||||
mock_metadata = MagicMock(spec=CompressionMetadata)
|
||||
mock_metadata.compression_ratio = 5.0
|
||||
mock_metadata.original_token_count = 1000
|
||||
mock_metadata.compressed_token_count = 200
|
||||
mock_metadata.to_dict.return_value = {"query_index": 2}
|
||||
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_with_compression(
|
||||
"compressed summary",
|
||||
[{"prompt": "q2", "response": "r2"}],
|
||||
mock_metadata,
|
||||
)
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is True
|
||||
assert result.compressed_summary == "compressed summary"
|
||||
mock_perform.assert_called_once()
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
):
|
||||
mock_conversation_service.get_conversation.side_effect = RuntimeError("DB down")
|
||||
|
||||
result = orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "DB down" in result.error
|
||||
|
||||
def test_custom_query_tokens(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
mock_threshold_checker.should_compress.return_value = False
|
||||
|
||||
orchestrator.compress_if_needed(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token={"sub": "user1"},
|
||||
current_query_tokens=1000,
|
||||
)
|
||||
|
||||
mock_threshold_checker.should_compress.assert_called_once_with(
|
||||
sample_conversation, "gpt-4", 1000
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPerformCompression:
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_successful_compression(
|
||||
self,
|
||||
mock_settings,
|
||||
MockCompressionService,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
mock_metadata = MagicMock(spec=CompressionMetadata)
|
||||
mock_metadata.compression_ratio = 5.0
|
||||
mock_metadata.original_token_count = 500
|
||||
mock_metadata.compressed_token_count = 100
|
||||
|
||||
mock_svc_instance = MagicMock()
|
||||
mock_svc_instance.compress_and_save.return_value = mock_metadata
|
||||
mock_svc_instance.get_compressed_context.return_value = (
|
||||
"compressed text",
|
||||
[{"prompt": "q2", "response": "r2"}],
|
||||
)
|
||||
MockCompressionService.return_value = mock_svc_instance
|
||||
|
||||
# After compression, reload conversation
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
|
||||
result = orch._perform_compression(
|
||||
"conv1", sample_conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is True
|
||||
assert result.compressed_summary == "compressed text"
|
||||
mock_svc_instance.compress_and_save.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_uses_compression_model_override(
|
||||
self,
|
||||
mock_settings,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = "gpt-3.5-turbo"
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
conversation = {"queries": [{"prompt": "q", "response": "r"}], "agent_id": "a"}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.orchestrator.CompressionService"
|
||||
) as MockCS:
|
||||
mock_svc = MagicMock()
|
||||
mock_svc.compress_and_save.return_value = MagicMock(
|
||||
compression_ratio=3.0,
|
||||
original_token_count=300,
|
||||
compressed_token_count=100,
|
||||
)
|
||||
mock_svc.get_compressed_context.return_value = ("s", [])
|
||||
MockCS.return_value = mock_svc
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
|
||||
|
||||
# Verify the override model was used
|
||||
mock_get_provider.assert_called_with("gpt-3.5-turbo")
|
||||
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
|
||||
)
|
||||
@patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_api_key_for_provider"
|
||||
)
|
||||
@patch("application.api.answer.services.compression.orchestrator.LLMCreator")
|
||||
@patch("application.api.answer.services.compression.orchestrator.CompressionService")
|
||||
@patch("application.api.answer.services.compression.orchestrator.settings")
|
||||
def test_no_queries_returns_no_compression(
|
||||
self,
|
||||
mock_settings,
|
||||
MockCompressionService,
|
||||
MockLLMCreator,
|
||||
mock_get_api_key,
|
||||
mock_get_provider,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
mock_get_provider.return_value = "openai"
|
||||
mock_get_api_key.return_value = "sk-test"
|
||||
MockLLMCreator.create_llm.return_value = MagicMock()
|
||||
|
||||
conversation = {"queries": [], "agent_id": "a"}
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
result = orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compression_performed is False
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
mock_conversation_service,
|
||||
mock_threshold_checker,
|
||||
decoded_token,
|
||||
):
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q", "response": "r"}],
|
||||
"agent_id": "a",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.orchestrator.settings"
|
||||
) as mock_settings, patch(
|
||||
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id",
|
||||
side_effect=RuntimeError("provider error"),
|
||||
):
|
||||
mock_settings.COMPRESSION_MODEL_OVERRIDE = None
|
||||
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
result = orch._perform_compression(
|
||||
"c1", conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "provider error" in result.error
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressMidExecution:
|
||||
def test_with_provided_conversation(
|
||||
self,
|
||||
orchestrator,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_no_compression([])
|
||||
|
||||
orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
current_conversation=sample_conversation,
|
||||
)
|
||||
|
||||
mock_perform.assert_called_once_with(
|
||||
"conv1", sample_conversation, "gpt-4", decoded_token
|
||||
)
|
||||
|
||||
def test_loads_conversation_when_not_provided(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
sample_conversation,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = sample_conversation
|
||||
|
||||
with patch.object(
|
||||
orchestrator, "_perform_compression"
|
||||
) as mock_perform:
|
||||
mock_perform.return_value = CompressionResult.success_no_compression([])
|
||||
|
||||
orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
mock_conversation_service.get_conversation.assert_called_once_with(
|
||||
"conv1", "user1"
|
||||
)
|
||||
mock_perform.assert_called_once()
|
||||
|
||||
def test_conversation_not_found_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.return_value = None
|
||||
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_exception_returns_failure(
|
||||
self,
|
||||
orchestrator,
|
||||
mock_conversation_service,
|
||||
decoded_token,
|
||||
):
|
||||
mock_conversation_service.get_conversation.side_effect = RuntimeError("fail")
|
||||
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id="conv1",
|
||||
user_id="user1",
|
||||
model_id="gpt-4",
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "fail" in result.error
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOrchestratorInit:
|
||||
def test_default_threshold_checker(self, mock_conversation_service):
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service
|
||||
)
|
||||
assert orch.threshold_checker is not None
|
||||
assert orch.conversation_service is mock_conversation_service
|
||||
|
||||
def test_custom_threshold_checker(
|
||||
self, mock_conversation_service, mock_threshold_checker
|
||||
):
|
||||
orch = CompressionOrchestrator(
|
||||
conversation_service=mock_conversation_service,
|
||||
threshold_checker=mock_threshold_checker,
|
||||
)
|
||||
assert orch.threshold_checker is mock_threshold_checker
|
||||
423
tests/api/answer/services/compression/test_service.py
Normal file
423
tests/api/answer/services/compression/test_service.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Tests for application/api/answer/services/compression/service.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import CompressionMetadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
llm = MagicMock()
|
||||
llm.gen.return_value = "<summary>Compressed summary content</summary>"
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_service():
|
||||
svc = MagicMock()
|
||||
svc.update_compression_metadata = MagicMock()
|
||||
return svc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"queries": [
|
||||
{"prompt": "What is Python?", "response": "A programming language."},
|
||||
{"prompt": "Tell me more.", "response": "It's versatile and popular."},
|
||||
{
|
||||
"prompt": "What about tools?",
|
||||
"response": "Python has many tools.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web_search",
|
||||
"arguments": {"q": "python tools"},
|
||||
"result": "Found 10 results",
|
||||
"status": "success",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionServiceInit:
|
||||
@patch("application.api.answer.services.compression.service.settings")
|
||||
def test_default_prompt_builder(self, mock_settings, mock_llm):
|
||||
mock_settings.COMPRESSION_PROMPT_VERSION = "v1.0"
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.CompressionPromptBuilder"
|
||||
):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
assert svc.llm is mock_llm
|
||||
assert svc.model_id == "gpt-4"
|
||||
|
||||
def test_custom_prompt_builder(self, mock_llm):
|
||||
custom_builder = MagicMock()
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=custom_builder
|
||||
)
|
||||
assert svc.prompt_builder is custom_builder
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressConversation:
|
||||
def test_successful_compression(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "Compress"},
|
||||
{"role": "user", "content": "Conversation..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 1000
|
||||
MockTC.count_message_tokens.return_value = 100
|
||||
|
||||
result = svc.compress_conversation(sample_conversation, 2)
|
||||
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
assert result.query_index == 2
|
||||
assert result.compressed_summary == "Compressed summary content"
|
||||
assert result.original_token_count == 1000
|
||||
assert result.compressed_token_count == 100
|
||||
assert result.compression_ratio == 10.0
|
||||
assert result.model_used == "gpt-4"
|
||||
assert result.compression_prompt_version == "v1.0"
|
||||
|
||||
def test_invalid_index_negative(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
|
||||
svc.compress_conversation(sample_conversation, -1)
|
||||
|
||||
def test_invalid_index_too_large(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid compress_up_to_index"):
|
||||
svc.compress_conversation(sample_conversation, 10)
|
||||
|
||||
def test_with_existing_compressions(self, mock_llm):
|
||||
conversation = {
|
||||
"queries": [
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {
|
||||
"compression_points": [
|
||||
{
|
||||
"query_index": 0,
|
||||
"compressed_summary": "Previous summary",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "Compress"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 500
|
||||
MockTC.count_message_tokens.return_value = 50
|
||||
|
||||
result = svc.compress_conversation(conversation, 1)
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
# Verify existing compressions were passed to prompt builder
|
||||
call_args = mock_builder.build_prompt.call_args
|
||||
assert call_args[0][1] == [
|
||||
{"query_index": 0, "compressed_summary": "Previous summary"}
|
||||
]
|
||||
|
||||
def test_zero_compressed_tokens_ratio(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 1000
|
||||
MockTC.count_message_tokens.return_value = 0
|
||||
|
||||
result = svc.compress_conversation(sample_conversation, 2)
|
||||
assert result.compression_ratio == 0
|
||||
|
||||
def test_llm_error_propagates(self, sample_conversation):
|
||||
llm = MagicMock()
|
||||
llm.gen.side_effect = RuntimeError("LLM error")
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 100
|
||||
with pytest.raises(RuntimeError, match="LLM error"):
|
||||
svc.compress_conversation(sample_conversation, 2)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressAndSave:
|
||||
def test_saves_metadata_to_db(
|
||||
self, mock_llm, mock_conversation_service, sample_conversation
|
||||
):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_prompt.return_value = [
|
||||
{"role": "system", "content": "C"},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
mock_builder.version = "v1.0"
|
||||
|
||||
svc = CompressionService(
|
||||
llm=mock_llm,
|
||||
model_id="gpt-4",
|
||||
conversation_service=mock_conversation_service,
|
||||
prompt_builder=mock_builder,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.service.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_query_tokens.return_value = 500
|
||||
MockTC.count_message_tokens.return_value = 50
|
||||
|
||||
result = svc.compress_and_save("conv_123", sample_conversation, 2)
|
||||
|
||||
assert isinstance(result, CompressionMetadata)
|
||||
mock_conversation_service.update_compression_metadata.assert_called_once_with(
|
||||
"conv_123", result.to_dict()
|
||||
)
|
||||
|
||||
def test_raises_without_conversation_service(self, mock_llm, sample_conversation):
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.version = "v1.0"
|
||||
svc = CompressionService(
|
||||
llm=mock_llm, model_id="gpt-4", prompt_builder=mock_builder
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="conversation_service required"):
|
||||
svc.compress_and_save("conv_123", sample_conversation, 2)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetCompressedContext:
|
||||
def test_no_compression_returns_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q1", "response": "r1"}],
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
|
||||
assert summary is None
|
||||
assert queries == [{"prompt": "q1", "response": "r1"}]
|
||||
|
||||
def test_no_compression_points_returns_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q1", "response": "r1"}],
|
||||
"compression_metadata": {"is_compressed": True, "compression_points": []},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert len(queries) == 1
|
||||
|
||||
def test_with_compression_returns_summary_and_recent(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [
|
||||
{"prompt": "q0", "response": "r0"},
|
||||
{"prompt": "q1", "response": "r1"},
|
||||
{"prompt": "q2", "response": "r2"},
|
||||
],
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": [
|
||||
{
|
||||
"query_index": 1,
|
||||
"compressed_summary": "Summary of q0 and q1",
|
||||
"compressed_token_count": 50,
|
||||
"original_token_count": 500,
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
|
||||
assert summary == "Summary of q0 and q1"
|
||||
assert len(queries) == 1
|
||||
assert queries[0]["prompt"] == "q2"
|
||||
|
||||
def test_none_queries_returns_empty(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": None,
|
||||
"compression_metadata": {},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == []
|
||||
|
||||
def test_exception_falls_back_to_full_history(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
conversation = {
|
||||
"queries": [{"prompt": "q", "response": "r"}],
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": "invalid", # This will cause an error
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == [{"prompt": "q", "response": "r"}]
|
||||
|
||||
def test_exception_with_none_queries_returns_empty(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
# Force exception by making compression_points non-iterable
|
||||
conversation = {
|
||||
"queries": None,
|
||||
"compression_metadata": {
|
||||
"is_compressed": True,
|
||||
"compression_points": "bad",
|
||||
},
|
||||
}
|
||||
|
||||
summary, queries = svc.get_compressed_context(conversation)
|
||||
assert summary is None
|
||||
assert queries == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractSummary:
|
||||
def test_extracts_from_summary_tags(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<analysis>Some analysis</analysis><summary>The actual summary</summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "The actual summary"
|
||||
|
||||
def test_removes_analysis_tags_when_no_summary(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<analysis>analysis text</analysis>Raw summary text here"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Raw summary text here"
|
||||
|
||||
def test_returns_full_response_when_no_tags(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "Just a plain text response"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Just a plain text response"
|
||||
|
||||
def test_multiline_summary(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<summary>Line 1\nLine 2\nLine 3</summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert "Line 1" in result
|
||||
assert "Line 3" in result
|
||||
|
||||
def test_strips_whitespace(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
response = "<summary> Trimmed </summary>"
|
||||
result = svc._extract_summary(response)
|
||||
assert result == "Trimmed"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLogToolCallStats:
|
||||
def test_no_tool_calls(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [{"prompt": "q", "response": "r"}]
|
||||
# Should not raise
|
||||
svc._log_tool_call_stats(queries)
|
||||
|
||||
def test_with_tool_calls(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web",
|
||||
"result": "result text",
|
||||
},
|
||||
{
|
||||
"tool_name": "search",
|
||||
"action_name": "web",
|
||||
"result": "more text",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
# Should not raise - just logs
|
||||
svc._log_tool_call_stats(queries)
|
||||
|
||||
def test_empty_queries(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
svc._log_tool_call_stats([])
|
||||
|
||||
def test_tool_call_with_none_result(self, mock_llm):
|
||||
svc = CompressionService(llm=mock_llm, model_id="gpt-4")
|
||||
queries = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "r",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "t",
|
||||
"action_name": "a",
|
||||
"result": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
svc._log_tool_call_stats(queries)
|
||||
131
tests/api/answer/services/compression/test_types.py
Normal file
131
tests/api/answer/services/compression/test_types.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for application/api/answer/services/compression/types.py"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
CompressionResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionMetadata:
|
||||
def _make_metadata(self, **overrides):
|
||||
defaults = dict(
|
||||
timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
query_index=5,
|
||||
compressed_summary="Summary of conversation",
|
||||
original_token_count=5000,
|
||||
compressed_token_count=500,
|
||||
compression_ratio=10.0,
|
||||
model_used="gpt-4",
|
||||
compression_prompt_version="v1.0",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return CompressionMetadata(**defaults)
|
||||
|
||||
def test_to_dict_contains_all_fields(self):
|
||||
meta = self._make_metadata()
|
||||
d = meta.to_dict()
|
||||
|
||||
assert d["timestamp"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
assert d["query_index"] == 5
|
||||
assert d["compressed_summary"] == "Summary of conversation"
|
||||
assert d["original_token_count"] == 5000
|
||||
assert d["compressed_token_count"] == 500
|
||||
assert d["compression_ratio"] == 10.0
|
||||
assert d["model_used"] == "gpt-4"
|
||||
assert d["compression_prompt_version"] == "v1.0"
|
||||
|
||||
def test_to_dict_returns_dict_type(self):
|
||||
meta = self._make_metadata()
|
||||
assert isinstance(meta.to_dict(), dict)
|
||||
|
||||
def test_to_dict_field_count(self):
|
||||
meta = self._make_metadata()
|
||||
d = meta.to_dict()
|
||||
assert len(d) == 8
|
||||
|
||||
def test_attributes_accessible(self):
|
||||
meta = self._make_metadata(query_index=10, compression_ratio=5.5)
|
||||
assert meta.query_index == 10
|
||||
assert meta.compression_ratio == 5.5
|
||||
|
||||
def test_zero_compressed_tokens(self):
|
||||
meta = self._make_metadata(compressed_token_count=0, compression_ratio=0)
|
||||
d = meta.to_dict()
|
||||
assert d["compressed_token_count"] == 0
|
||||
assert d["compression_ratio"] == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompressionResult:
|
||||
def test_success_with_compression(self):
|
||||
meta = CompressionMetadata(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
query_index=3,
|
||||
compressed_summary="summary",
|
||||
original_token_count=1000,
|
||||
compressed_token_count=100,
|
||||
compression_ratio=10.0,
|
||||
model_used="gpt-4",
|
||||
compression_prompt_version="v1.0",
|
||||
)
|
||||
queries = [{"prompt": "q1", "response": "r1"}]
|
||||
result = CompressionResult.success_with_compression("summary", queries, meta)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compressed_summary == "summary"
|
||||
assert result.recent_queries == queries
|
||||
assert result.metadata is meta
|
||||
assert result.compression_performed is True
|
||||
assert result.error is None
|
||||
|
||||
def test_success_no_compression(self):
|
||||
queries = [{"prompt": "q1", "response": "r1"}]
|
||||
result = CompressionResult.success_no_compression(queries)
|
||||
|
||||
assert result.success is True
|
||||
assert result.compressed_summary is None
|
||||
assert result.recent_queries == queries
|
||||
assert result.metadata is None
|
||||
assert result.compression_performed is False
|
||||
assert result.error is None
|
||||
|
||||
def test_failure(self):
|
||||
result = CompressionResult.failure("something went wrong")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "something went wrong"
|
||||
assert result.compression_performed is False
|
||||
assert result.compressed_summary is None
|
||||
assert result.recent_queries == []
|
||||
assert result.metadata is None
|
||||
|
||||
def test_as_history_extracts_prompt_response(self):
|
||||
queries = [
|
||||
{"prompt": "Hello", "response": "Hi", "extra": "ignored"},
|
||||
{"prompt": "How?", "response": "Fine"},
|
||||
]
|
||||
result = CompressionResult.success_no_compression(queries)
|
||||
history = result.as_history()
|
||||
|
||||
assert len(history) == 2
|
||||
assert history[0] == {"prompt": "Hello", "response": "Hi"}
|
||||
assert history[1] == {"prompt": "How?", "response": "Fine"}
|
||||
|
||||
def test_as_history_empty_queries(self):
|
||||
result = CompressionResult.success_no_compression([])
|
||||
assert result.as_history() == []
|
||||
|
||||
def test_default_recent_queries_is_empty_list(self):
|
||||
result = CompressionResult(success=True)
|
||||
assert result.recent_queries == []
|
||||
assert result.as_history() == []
|
||||
|
||||
def test_success_no_compression_with_empty_list(self):
|
||||
result = CompressionResult.success_no_compression([])
|
||||
assert result.success is True
|
||||
assert result.recent_queries == []
|
||||
331
tests/api/answer/test_stream_processor.py
Normal file
331
tests/api/answer/test_stream_processor.py
Normal file
@@ -0,0 +1,331 @@
"""Tests for application/api/answer/services/stream_processor.py — get_prompt and helpers."""

from unittest.mock import MagicMock, patch

import pytest

from application.api.answer.services.stream_processor import get_prompt


class TestGetPrompt:

    @pytest.mark.unit
    def test_default_preset(self):
        prompt = get_prompt("default")
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    @pytest.mark.unit
    def test_creative_preset(self):
        prompt = get_prompt("creative")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_strict_preset(self):
        prompt = get_prompt("strict")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_reduce_preset(self):
        prompt = get_prompt("reduce")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_default_preset(self):
        prompt = get_prompt("agentic_default")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_creative_preset(self):
        prompt = get_prompt("agentic_creative")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_agentic_strict_preset(self):
        prompt = get_prompt("agentic_strict")
        assert isinstance(prompt, str)

    @pytest.mark.unit
    def test_mongo_prompt_by_id(self):
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = {"_id": "abc", "content": "Custom prompt"}
        prompt = get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)
        assert prompt == "Custom prompt"

    @pytest.mark.unit
    def test_mongo_prompt_not_found_raises(self):
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = None
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)

    @pytest.mark.unit
    def test_invalid_id_raises(self):
        mock_collection = MagicMock()
        mock_collection.find_one.side_effect = Exception("bad id")
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("not-an-objectid", prompts_collection=mock_collection)

    @pytest.mark.unit
    def test_mongo_fallback_when_no_collection(self):
        """When no collection passed, it reads from MongoDB."""
        mock_collection = MagicMock()
        mock_collection.find_one.return_value = {"content": "From DB"}
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            prompt = get_prompt("507f1f77bcf86cd799439011")
            assert prompt == "From DB"


class TestStreamProcessorInit:

    @pytest.mark.unit
    def test_init_sets_attributes(self):
        mock_db = MagicMock()
        mock_client = {"docsgpt": mock_db}

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = mock_client

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"conversation_id": "conv1", "agent_id": "a1"},
                decoded_token={"sub": "user1"},
            )
            assert sp.conversation_id == "conv1"
            assert sp.initial_user_id == "user1"
            assert sp.agent_id == "a1"
            assert sp.history == []
            assert sp.attachments == []

    @pytest.mark.unit
    def test_init_no_token(self):
        mock_db = MagicMock()
        mock_client = {"docsgpt": mock_db}

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = mock_client

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token=None)
            assert sp.initial_user_id is None


class TestGetAttachmentsContent:

    @pytest.mark.unit
    def test_empty_ids_returns_empty(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content([], "u")
            assert result == []

    @pytest.mark.unit
    def test_returns_matching_attachments(self):
        mock_db = MagicMock()
        mock_attachments = MagicMock()
        mock_attachments.find_one.return_value = {"_id": "att1", "content": "data"}
        mock_db.__getitem__ = MagicMock(return_value=mock_attachments)

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content(["507f1f77bcf86cd799439011"], "u")
            assert len(result) == 1

    @pytest.mark.unit
    def test_invalid_attachment_id_continues(self):
        mock_db = MagicMock()
        mock_attachments = MagicMock()
        mock_attachments.find_one.side_effect = Exception("bad id")
        mock_db.__getitem__ = MagicMock(return_value=mock_attachments)

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            result = sp._get_attachments_content(["bad"], "u")
            assert result == []


class TestResolveAgentId:

    @pytest.mark.unit
    def test_from_request_data(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"agent_id": "agent_123"},
                decoded_token={"sub": "u"},
            )
            assert sp._resolve_agent_id() == "agent_123"

    @pytest.mark.unit
    def test_no_agent_no_conversation(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            assert sp._resolve_agent_id() is None

    @pytest.mark.unit
    def test_from_conversation(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            sp.conversation_service = MagicMock()
            sp.conversation_service.get_conversation.return_value = {"agent_id": "from_conv"}
            assert sp._resolve_agent_id() == "from_conv"

    @pytest.mark.unit
    def test_conversation_not_found(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            sp.conversation_service = MagicMock()
            sp.conversation_service.get_conversation.return_value = None
            assert sp._resolve_agent_id() is None

    @pytest.mark.unit
    def test_conversation_lookup_exception(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            sp.conversation_service = MagicMock()
            sp.conversation_service.get_conversation.side_effect = Exception("db error")
            assert sp._resolve_agent_id() is None


class TestGetPromptContent:

    @pytest.mark.unit
    def test_caches_result(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            sp.agent_config = {"prompt_id": "default"}
            result1 = sp._get_prompt_content()
            result2 = sp._get_prompt_content()
            assert result1 == result2
            assert result1 is not None

    @pytest.mark.unit
    def test_no_prompt_id(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            sp.agent_config = {}
            assert sp._get_prompt_content() is None

    @pytest.mark.unit
    def test_invalid_prompt_id_returns_none(self):
        mock_db = MagicMock()
        mock_prompts = MagicMock()
        mock_prompts.find_one.side_effect = Exception("bad")
        mock_db.__getitem__ = MagicMock(return_value=mock_prompts)

        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            sp.agent_config = {"prompt_id": "bad_id"}
            assert sp._get_prompt_content() is None


class TestGetRequiredToolActions:

    @pytest.mark.unit
    def test_no_prompt_returns_none(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            sp.agent_config = {}
            assert sp._get_required_tool_actions() is None

    @pytest.mark.unit
    def test_no_template_syntax_returns_empty(self):
        mock_db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
             patch("application.api.answer.services.stream_processor.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "docsgpt"
            MockMongo.get_client.return_value = {"docsgpt": mock_db}

            from application.api.answer.services.stream_processor import StreamProcessor
            sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            sp.agent_config = {"prompt_id": "default"}
            sp._prompt_content = "No template syntax here"
            result = sp._get_required_tool_actions()
            assert result == {}
tests/api/test_connector_routes.py (new file, 333 lines)
@@ -0,0 +1,333 @@
"""Tests for application/api/connector/routes.py"""

import json
from unittest.mock import MagicMock, patch

import mongomock
import pytest


@pytest.fixture
def app():
    with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
        from application.app import app as flask_app
        flask_app.config["TESTING"] = True
        yield flask_app


@pytest.fixture
def client(app):
    return app.test_client()


@pytest.fixture
def mock_sessions(monkeypatch):
    mock_client = mongomock.MongoClient()
    mock_db = mock_client["docsgpt"]
    sessions = mock_db["connector_sessions"]
    sources = mock_db["sources"]
    monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions)
    monkeypatch.setattr("application.api.connector.routes.sources_collection", sources)
    return {"sessions": sessions, "sources": sources}


class TestConnectorAuth:

    @pytest.mark.unit
    def test_missing_provider(self, client):
        resp = client.get("/api/connectors/auth")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unsupported_provider(self, client):
        resp = client.get("/api/connectors/auth?provider=dropbox")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unauthorized(self, client, app):
        with patch("application.app.handle_auth", return_value=None):
            resp = client.get("/api/connectors/auth?provider=google_drive")
            data = json.loads(resp.data)
            # decoded_token is None -> 401
            assert resp.status_code == 401 or data.get("error") == "Unauthorized"

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.is_supported.return_value = True
            mock_auth = MagicMock()
            mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
            MockCC.create_auth.return_value = mock_auth

            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert "authorization_url" in data

    @pytest.mark.unit
    def test_exception_returns_500(self, client, mock_sessions):
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.is_supported.return_value = True
            MockCC.create_auth.side_effect = Exception("oauth fail")
            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 500


class TestConnectorFiles:

    @pytest.mark.unit
    def test_missing_params(self, client):
        resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        resp = client.post("/api/connectors/files", json={
            "provider": "google_drive",
            "session_token": "bad_token",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "valid_tok",
            "user": "test_user",
            "provider": "google_drive",
        })

        mock_doc = MagicMock()
        mock_doc.doc_id = "f1"
        mock_doc.extra_info = {
            "file_name": "test.pdf",
            "mime_type": "application/pdf",
            "size": 1024,
            "modified_time": "2025-01-01T12:00:00.000Z",
            "is_folder": False,
        }
        mock_loader = MagicMock()
        mock_loader.load_data.return_value = [mock_doc]
        mock_loader.next_page_token = None

        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.create_connector.return_value = mock_loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive",
                "session_token": "valid_tok",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert len(data["files"]) == 1

    @pytest.mark.unit
    def test_no_modified_time(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "tok2",
            "user": "test_user",
            "provider": "google_drive",
        })
        mock_doc = MagicMock()
        mock_doc.doc_id = "f1"
        mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"}
        mock_loader = MagicMock()
        mock_loader.load_data.return_value = [mock_doc]
        mock_loader.next_page_token = None

        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            MockCC.create_connector.return_value = mock_loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive", "session_token": "tok2",
            })
            assert resp.status_code == 200


class TestConnectorValidateSession:

    @pytest.mark.unit
    def test_missing_params(self, client):
        resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        resp = client.post("/api/connectors/validate-session", json={
            "provider": "google_drive", "session_token": "bad",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_valid_non_expired(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "valid",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
            "user_email": "user@example.com",
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = False
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "valid",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["success"] is True
            assert data["expired"] is False

    @pytest.mark.unit
    def test_expired_with_refresh(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "expired_tok",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = True
            mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "expired_tok",
            })
            assert resp.status_code == 200

    @pytest.mark.unit
    def test_expired_no_refresh(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "exp_no_ref",
            "user": "test_user",
            "token_info": {"access_token": "at", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
            mock_auth = MagicMock()
            mock_auth.is_token_expired.return_value = True
            MockCC.create_auth.return_value = mock_auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "exp_no_ref",
            })
            assert resp.status_code == 401


class TestConnectorDisconnect:

    @pytest.mark.unit
    def test_missing_provider(self, client):
        resp = client.post("/api/connectors/disconnect", json={})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_success_with_session(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"})
        resp = client.post("/api/connectors/disconnect", json={
            "provider": "google_drive", "session_token": "del_me",
        })
        assert resp.status_code == 200
        data = json.loads(resp.data)
        assert data["success"] is True

    @pytest.mark.unit
    def test_success_without_session(self, client, mock_sessions):
        resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
        assert resp.status_code == 200


class TestConnectorSync:

    @pytest.mark.unit
    def test_missing_params(self, client, mock_sessions):
        resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_source_not_found(self, client, mock_sessions):
        from bson.objectid import ObjectId
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(ObjectId()), "session_token": "tok",
        })
        assert resp.status_code == 404

    @pytest.mark.unit
    def test_unauthorized_source(self, client, mock_sessions):
        sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 403

    @pytest.mark.unit
    def test_missing_provider_in_remote_data(self, client, mock_sessions):
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user", "name": "src", "remote_data": json.dumps({}),
        }).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user",
            "name": "src",
            "remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
        }).inserted_id
        mock_task = MagicMock()
        mock_task.id = "task_123"
        with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
            mock_ingest.delay.return_value = mock_task
            resp = client.post("/api/connectors/sync", json={
                "source_id": str(sid), "session_token": "tok",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["task_id"] == "task_123"


class TestConnectorCallbackStatus:

    @pytest.mark.unit
    def test_success_status(self, client):
        resp = client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com")
        assert resp.status_code == 200
        assert b"success" in resp.data

    @pytest.mark.unit
    def test_error_status(self, client):
        resp = client.get("/api/connectors/callback-status?status=error&message=Failed")
        assert resp.status_code == 200
        assert b"error" in resp.data

    @pytest.mark.unit
    def test_cancelled_status(self, client):
        resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
        assert resp.status_code == 200
        assert b"cancelled" in resp.data

    @pytest.mark.unit
    def test_unknown_status_defaults_to_error(self, client):
        resp = client.get("/api/connectors/callback-status?status=badvalue")
        assert resp.status_code == 200
        assert b"error" in resp.data

    @pytest.mark.unit
    def test_html_escaping(self, client):
        resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
        assert resp.status_code == 200
        # The raw <script> tag should be escaped (not executable)
        assert b"<script>alert(1)</script>" not in resp.data


class TestBuildCallbackRedirect:

    @pytest.mark.unit
    def test_builds_url(self):
        from application.api.connector.routes import build_callback_redirect
        url = build_callback_redirect({"status": "success", "message": "OK"})
        assert url.startswith("/api/connectors/callback-status?")
        assert "status=success" in url
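The mock_sessions fixture above keeps these route tests off a real database by combining mongomock with monkeypatch; the core of that pattern, reduced to a sketch (the collection names are the ones patched above, the rest is illustrative):

# Illustration only: in-memory stand-in for a module-level pymongo collection.
import mongomock
fake_db = mongomock.MongoClient()["docsgpt"]
fake_db["connector_sessions"].insert_one({"session_token": "tok", "user": "test_user"})
assert fake_db["connector_sessions"].find_one({"session_token": "tok"})["user"] == "test_user"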
tests/core/test_model_settings.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""Tests for application/core/model_settings.py"""

from unittest.mock import MagicMock, patch

import pytest

from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
    ModelRegistry,
)


class TestModelProvider:

    @pytest.mark.unit
    def test_all_providers_exist(self):
        assert ModelProvider.OPENAI == "openai"
        assert ModelProvider.ANTHROPIC == "anthropic"
        assert ModelProvider.GOOGLE == "google"
        assert ModelProvider.GROQ == "groq"
        assert ModelProvider.DOCSGPT == "docsgpt"
        assert ModelProvider.HUGGINGFACE == "huggingface"
        assert ModelProvider.NOVITA == "novita"
        assert ModelProvider.OPENROUTER == "openrouter"
        assert ModelProvider.SAGEMAKER == "sagemaker"
        assert ModelProvider.PREMAI == "premai"
        assert ModelProvider.LLAMA_CPP == "llama.cpp"
        assert ModelProvider.AZURE_OPENAI == "azure_openai"


class TestModelCapabilities:

    @pytest.mark.unit
    def test_defaults(self):
        caps = ModelCapabilities()
        assert caps.supports_tools is False
        assert caps.supports_structured_output is False
        assert caps.supports_streaming is True
        assert caps.supported_attachment_types == []
        assert caps.context_window == 128000
        assert caps.input_cost_per_token is None
        assert caps.output_cost_per_token is None

    @pytest.mark.unit
    def test_custom_values(self):
        caps = ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            context_window=32000,
            input_cost_per_token=0.001,
        )
        assert caps.supports_tools is True
        assert caps.context_window == 32000


class TestAvailableModel:

    @pytest.mark.unit
    def test_to_dict_basic(self):
        model = AvailableModel(
            id="gpt-4",
            provider=ModelProvider.OPENAI,
            display_name="GPT-4",
            description="OpenAI GPT-4",
        )
        d = model.to_dict()
        assert d["id"] == "gpt-4"
        assert d["provider"] == "openai"
        assert d["display_name"] == "GPT-4"
        assert d["enabled"] is True
        assert "base_url" not in d

    @pytest.mark.unit
    def test_to_dict_with_base_url(self):
        model = AvailableModel(
            id="local-model",
            provider=ModelProvider.OPENAI,
            display_name="Local",
            base_url="http://localhost:11434",
        )
        d = model.to_dict()
        assert d["base_url"] == "http://localhost:11434"

    @pytest.mark.unit
    def test_to_dict_includes_capabilities(self):
        caps = ModelCapabilities(supports_tools=True, context_window=64000)
        model = AvailableModel(
            id="m1",
            provider=ModelProvider.ANTHROPIC,
            display_name="M1",
            capabilities=caps,
        )
        d = model.to_dict()
        assert d["supports_tools"] is True
        assert d["context_window"] == 64000


class TestModelRegistry:

    @pytest.fixture(autouse=True)
    def _reset_singleton(self):
        """Reset singleton between tests."""
        ModelRegistry._instance = None
        ModelRegistry._initialized = False
        yield
        ModelRegistry._instance = None
        ModelRegistry._initialized = False

    @pytest.mark.unit
    def test_singleton(self):
        with patch.object(ModelRegistry, "_load_models"):
            r1 = ModelRegistry()
            r2 = ModelRegistry()
            assert r1 is r2

    @pytest.mark.unit
    def test_get_instance(self):
        with patch.object(ModelRegistry, "_load_models"):
            r = ModelRegistry.get_instance()
            assert isinstance(r, ModelRegistry)

    @pytest.mark.unit
    def test_get_model(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
            reg.models["test"] = model
            assert reg.get_model("test") is model
            assert reg.get_model("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
            assert len(reg.get_all_models()) == 2

    @pytest.mark.unit
    def test_get_enabled_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
            enabled = reg.get_enabled_models()
            assert len(enabled) == 1
            assert enabled[0].id == "m1"

    @pytest.mark.unit
    def test_model_exists(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            assert reg.model_exists("m1") is True
            assert reg.model_exists("m2") is False

    @pytest.mark.unit
    def test_parse_model_names(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
            assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
            assert reg._parse_model_names("single") == ["single"]
            assert reg._parse_model_names("") == []
            assert reg._parse_model_names(None) == []

    @pytest.mark.unit
    def test_add_docsgpt_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_docsgpt_models(mock_settings)
            assert "docsgpt-local" in reg.models

    @pytest.mark.unit
    def test_add_huggingface_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_huggingface_models(mock_settings)
            assert "huggingface-local" in reg.models

    @pytest.mark.unit
    def test_load_models_with_openai_key(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_load_models_custom_openai_base_url(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = "llama3,gemma"
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert "llama3" in reg.models
            assert "gemma" in reg.models

    @pytest.mark.unit
    def test_default_model_selection_from_llm_name(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
            reg.default_model_id = "gpt-4"
            assert reg.default_model_id == "gpt-4"

    @pytest.mark.unit
    def test_add_anthropic_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_google_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = "google-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_google_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_groq_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GROQ_API_KEY = "groq-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_groq_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_openrouter_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = "or-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_novita_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.NOVITA_API_KEY = "novita-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_novita_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_azure_openai_models_specific(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.LLM_PROVIDER = "azure_openai"
            mock_settings.LLM_NAME = "nonexistent-model"
            reg._add_azure_openai_models(mock_settings)
            # Falls through to adding all azure models
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_anthropic_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = None
            mock_settings.LLM_PROVIDER = "anthropic"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_default_model_fallback_to_first(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = None
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            # Should have at least docsgpt-local
            assert reg.default_model_id is not None
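Read together, these registry tests imply the consumer-side call path; a condensed sketch (caller code is illustrative, only the methods exercised above are assumed):

# Illustration only: query the singleton registry for the models it reports as enabled.
registry = ModelRegistry.get_instance()
enabled_models = [m.to_dict() for m in registry.get_enabled_models()]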
tests/integration/test_workflows.py (new file, 577 lines)
@@ -0,0 +1,577 @@
#!/usr/bin/env python3
"""
Integration tests for DocsGPT workflow management endpoints.

Uses Flask test client with real MongoDB (must be running).

Endpoints tested:
- /api/workflows (POST) - Create workflow
- /api/workflows/<id> (GET) - Get workflow
- /api/workflows/<id> (PUT) - Update workflow
- /api/workflows/<id> (DELETE) - Delete workflow

Run:
    pytest tests/integration/test_workflows.py -v
"""

import time

import pytest
from jose import jwt


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def app():
    """Create the real Flask app (connects to real MongoDB)."""
    from application.app import app as flask_app
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture(scope="module")
def client(app):
    """Flask test client.

    When AUTH_TYPE is set to simple_jwt/session_jwt a Bearer token is
    injected; otherwise the backend already returns {"sub": "local"}
    for every request so no token is needed.
    """
    from application.core.settings import settings

    c = app.test_client()
    if settings.AUTH_TYPE in ("simple_jwt", "session_jwt"):
        secret = settings.JWT_SECRET_KEY
        if not secret:
            pytest.skip("JWT_SECRET_KEY not configured")
        payload = {"sub": f"test_workflow_integration_{int(time.time())}"}
        token = jwt.encode(payload, secret, algorithm="HS256")
        c.environ_base["HTTP_AUTHORIZATION"] = f"Bearer {token}"
    return c


@pytest.fixture(scope="module")
def created_ids():
    """Accumulator for workflow IDs to clean up after all tests."""
    return []


@pytest.fixture(autouse=True, scope="module")
def cleanup(client, created_ids):
    """Delete all test-created workflows after the module finishes."""
    yield
    for wf_id in created_ids:
        try:
            client.delete(f"/api/workflows/{wf_id}")
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Payload helpers
# ---------------------------------------------------------------------------

def simple_workflow(suffix=""):
    """Start -> End."""
    return {
        "name": f"Simple WF {int(time.time())}{suffix}",
        "description": "integration test",
        "nodes": [
            {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 0}, "data": {}},
            {"id": "end_1", "type": "end", "title": "End",
             "position": {"x": 400, "y": 0}, "data": {}},
        ],
        "edges": [
            {"id": "edge_1", "source": "start_1", "target": "end_1"},
        ],
    }


def linear_workflow(suffix=""):
    """Start -> Agent -> End."""
    return {
        "name": f"Linear WF {int(time.time())}{suffix}",
        "description": "integration test",
        "nodes": [
            {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 0}, "data": {}},
            {"id": "agent_1", "type": "agent", "title": "Agent",
             "position": {"x": 200, "y": 0}, "data": {
                 "agent_type": "classic",
                 "system_prompt": "You are helpful.",
                 "prompt_template": "",
                 "stream_to_user": False,
             }},
            {"id": "end_1", "type": "end", "title": "End",
             "position": {"x": 400, "y": 0}, "data": {}},
        ],
        "edges": [
            {"id": "edge_1", "source": "start_1", "target": "agent_1"},
            {"id": "edge_2", "source": "agent_1", "target": "end_1"},
        ],
    }


def multi_input_end_workflow(suffix=""):
    """Condition branches into two agents, both converging on one end node.

    Graph:
        start -> condition --(case_1)--> agent_a --\
                           --(else)----> agent_b ---+--> end
    """
    return {
        "name": f"Multi-Input End {int(time.time())}{suffix}",
        "description": "end node with multiple inputs",
        "nodes": [
            {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 100}, "data": {}},
            {"id": "cond_1", "type": "condition", "title": "Branch",
             "position": {"x": 200, "y": 100}, "data": {
                 "mode": "simple",
                 "cases": [
                     {"name": "Case 1", "expression": "true",
                      "sourceHandle": "case_1"},
                 ],
             }},
            {"id": "agent_a", "type": "agent", "title": "Agent A",
             "position": {"x": 400, "y": 0}, "data": {
                 "agent_type": "classic",
                 "system_prompt": "Branch A",
                 "prompt_template": "",
                 "stream_to_user": False,
             }},
            {"id": "agent_b", "type": "agent", "title": "Agent B",
             "position": {"x": 400, "y": 200}, "data": {
                 "agent_type": "classic",
                 "system_prompt": "Branch B",
                 "prompt_template": "",
                 "stream_to_user": False,
             }},
            {"id": "end_1", "type": "end", "title": "End",
             "position": {"x": 600, "y": 100}, "data": {}},
        ],
        "edges": [
            {"id": "e1", "source": "start_1", "target": "cond_1"},
            {"id": "e2", "source": "cond_1", "target": "agent_a",
             "sourceHandle": "case_1"},
            {"id": "e3", "source": "cond_1", "target": "agent_b",
             "sourceHandle": "else"},
            # Both agents feed into the SAME end node
            {"id": "e4", "source": "agent_a", "target": "end_1"},
            {"id": "e5", "source": "agent_b", "target": "end_1"},
        ],
    }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _extract_id(resp):
    """Pull workflow id from create/update response."""
    body = resp.get_json()
    data = body.get("data") or body
    return data.get("id")


def _get_graph(client, wf_id):
    """Fetch workflow and return (nodes, edges)."""
    resp = client.get(f"/api/workflows/{wf_id}")
    assert resp.status_code == 200, resp.get_data(as_text=True)
    body = resp.get_json()
    data = body.get("data") or body
    return data.get("nodes", []), data.get("edges", [])


# ===========================================================================
# CRUD tests
# ===========================================================================


class TestWorkflowCRUD:

    def test_create_simple_workflow(self, client, created_ids):
        resp = client.post("/api/workflows", json=simple_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_create_linear_workflow(self, client, created_ids):
        resp = client.post("/api/workflows", json=linear_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_get_workflow_returns_nodes_and_edges(self, client, created_ids):
        resp = client.post("/api/workflows", json=simple_workflow(" get"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 2
        assert len(edges) == 1

    def test_update_workflow(self, client, created_ids):
        resp = client.post("/api/workflows", json=simple_workflow(" upd"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        update_resp = client.put(
            f"/api/workflows/{wf_id}", json=linear_workflow(" updated")
        )
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)

        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3  # start, agent, end
        assert len(edges) == 2

    def test_delete_workflow(self, client):
        resp = client.post("/api/workflows", json=simple_workflow(" del"))
        wf_id = _extract_id(resp)

        del_resp = client.delete(f"/api/workflows/{wf_id}")
        assert del_resp.status_code == 200

        get_resp = client.get(f"/api/workflows/{wf_id}")
        assert get_resp.status_code in (400, 404)

    def test_reject_workflow_without_end_node(self, client):
        payload = {
            "name": "No End",
            "nodes": [
                {"id": "s", "type": "start", "title": "Start",
                 "position": {"x": 0, "y": 0}, "data": {}},
            ],
            "edges": [],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 400, resp.get_data(as_text=True)


# ===========================================================================
# Multi-input end node tests
# ===========================================================================


class TestMultiInputEndNode:
    """Verify that an end node can receive edges from multiple source nodes."""

    def test_create_multi_input_end_workflow_accepted(self, client, created_ids):
        """Backend must accept a workflow where two edges target the same end node."""
        resp = client.post("/api/workflows", json=multi_input_end_workflow())
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

    def test_multi_input_end_all_edges_persisted(self, client, created_ids):
        """After round-trip, both edges into the end node must still be present."""
        resp = client.post(
            "/api/workflows", json=multi_input_end_workflow(" persist")
        )
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        nodes, edges = _get_graph(client, wf_id)

        # Locate end node
        end_ids = {n["id"] for n in nodes if n["type"] == "end"}
        assert end_ids, "no end node in response"

        # Count edges targeting any end node
        edges_to_end = [e for e in edges if e["target"] in end_ids]
        assert len(edges_to_end) >= 2, (
            f"Expected >=2 edges to end, got {len(edges_to_end)}: {edges_to_end}"
        )

    def test_multi_input_end_total_edge_count(self, client, created_ids):
        """All 5 edges of the multi-input graph must survive persistence."""
        resp = client.post(
            "/api/workflows", json=multi_input_end_workflow(" count")
        )
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        _, edges = _get_graph(client, wf_id)
        assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"

    def test_update_to_multi_input_end_preserves_edges(self, client, created_ids):
        """Updating a simple workflow to multi-input end keeps all edges."""
        # Create simple
        resp = client.post("/api/workflows", json=simple_workflow(" pre"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        # Update to multi-input end
        update_resp = client.put(
            f"/api/workflows/{wf_id}",
            json=multi_input_end_workflow(" post"),
        )
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)

        nodes, edges = _get_graph(client, wf_id)
        end_ids = {n["id"] for n in nodes if n["type"] == "end"}
        edges_to_end = [e for e in edges if e["target"] in end_ids]
        assert len(edges_to_end) >= 2, (
            f"Expected >=2 edges to end after update, got {len(edges_to_end)}"
        )


# ---------------------------------------------------------------------------
# Source-aware payload helpers
# ---------------------------------------------------------------------------

def workflow_with_sources(sources, suffix=""):
    """Start -> Agent (with sources) -> End."""
    return {
        "name": f"Source WF {int(time.time())}{suffix}",
        "description": "integration test with sources",
        "nodes": [
            {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 0}, "data": {}},
            {"id": "agent_1", "type": "agent", "title": "Agent",
             "position": {"x": 200, "y": 0}, "data": {
                 "agent_type": "classic",
                 "system_prompt": "You are helpful.",
                 "prompt_template": "",
                 "stream_to_user": False,
                 "sources": sources,
                 "tools": [],
             }},
            {"id": "end_1", "type": "end", "title": "End",
             "position": {"x": 400, "y": 0}, "data": {}},
        ],
        "edges": [
            {"id": "edge_1", "source": "start_1", "target": "agent_1"},
            {"id": "edge_2", "source": "agent_1", "target": "end_1"},
        ],
    }


def workflow_multi_agent_sources(suffix=""):
    """Start -> Agent A (sources A) -> Agent B (sources B) -> End."""
    return {
        "name": f"Multi-Agent Sources {int(time.time())}{suffix}",
        "description": "two agents with different sources",
        "nodes": [
            {"id": "start_1", "type": "start", "title": "Start",
             "position": {"x": 0, "y": 0}, "data": {}},
            {"id": "agent_a", "type": "agent", "title": "Agent A",
             "position": {"x": 200, "y": 0}, "data": {
                 "agent_type": "agentic",
                 "system_prompt": "Agent A prompt",
                 "prompt_template": "",
                 "stream_to_user": False,
                 "sources": ["src_alpha", "src_beta"],
                 "tools": [],
             }},
            {"id": "agent_b", "type": "agent", "title": "Agent B",
             "position": {"x": 400, "y": 0}, "data": {
                 "agent_type": "classic",
                 "system_prompt": "Agent B prompt",
                 "prompt_template": "",
                 "stream_to_user": True,
                 "sources": ["src_gamma"],
                 "tools": [],
             }},
            {"id": "end_1", "type": "end", "title": "End",
             "position": {"x": 600, "y": 0}, "data": {}},
        ],
        "edges": [
            {"id": "e1", "source": "start_1", "target": "agent_a"},
            {"id": "e2", "source": "agent_a", "target": "agent_b"},
            {"id": "e3", "source": "agent_b", "target": "end_1"},
        ],
    }


def _find_agent_node(nodes, node_id):
    """Find a specific node by id."""
    return next((n for n in nodes if n["id"] == node_id), None)


# ===========================================================================
# Workflow integration tests
# ===========================================================================


class TestWorkflowIntegration:
    """Verify end-to-end workflow create → get → update → get round-trips."""

    def test_linear_workflow_round_trip(self, client, created_ids):
        """Create a linear workflow and verify all nodes/edges survive the round-trip."""
        payload = linear_workflow(" round-trip")
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3
        assert len(edges) == 2

        # Verify node types
        types = {n["id"]: n["type"] for n in nodes}
        assert types["start_1"] == "start"
        assert types["agent_1"] == "agent"
        assert types["end_1"] == "end"

    def test_agent_config_persisted(self, client, created_ids):
        """Agent node config (type, prompts, stream_to_user) round-trips correctly."""
        payload = linear_workflow(" config")
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        assert agent["data"]["agent_type"] == "classic"
        assert agent["data"]["system_prompt"] == "You are helpful."
        assert agent["data"]["stream_to_user"] is False

    def test_update_workflow_replaces_graph(self, client, created_ids):
        """Updating a workflow fully replaces nodes and edges."""
        resp = client.post("/api/workflows", json=simple_workflow(" replace"))
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 2

        # Update to linear
        update_resp = client.put(
            f"/api/workflows/{wf_id}", json=linear_workflow(" replaced")
        )
        assert update_resp.status_code == 200

        nodes, edges = _get_graph(client, wf_id)
        assert len(nodes) == 3
        assert len(edges) == 2


# ===========================================================================
# Source-specific integration tests
# ===========================================================================


class TestWorkflowSources:
    """Verify that agent node sources are persisted and retrieved correctly."""

    def test_create_workflow_with_single_source(self, client, created_ids):
        """A workflow with one source on an agent node persists it."""
        payload = workflow_with_sources(["default"])
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None, "Agent node not found"
        assert agent["data"].get("sources") == ["default"], (
            f"Expected sources=['default'], got {agent['data'].get('sources')}"
        )

    def test_create_workflow_with_multiple_sources(self, client, created_ids):
        """Multiple sources on an agent node are all persisted."""
        sources = ["src_1", "src_2", "src_3"]
        payload = workflow_with_sources(sources)
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        assert agent["data"].get("sources") == sources

    def test_create_workflow_with_empty_sources(self, client, created_ids):
        """An agent node with empty sources list is accepted and persisted."""
        payload = workflow_with_sources([])
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code in (200, 201), resp.get_data(as_text=True)
        wf_id = _extract_id(resp)
        assert wf_id
        created_ids.append(wf_id)

        nodes, _ = _get_graph(client, wf_id)
        agent = _find_agent_node(nodes, "agent_1")
        assert agent is not None
        assert agent["data"].get("sources") == []

    def test_update_workflow_sources(self, client, created_ids):
        """Updating a workflow replaces agent sources."""
        # Create with original sources
        payload = workflow_with_sources(["old_src"])
        resp = client.post("/api/workflows", json=payload)
        wf_id = _extract_id(resp)
        created_ids.append(wf_id)

        # Update with new sources
        updated_payload = workflow_with_sources(["new_src_1", "new_src_2"], " upd")
        update_resp = client.put(f"/api/workflows/{wf_id}", json=updated_payload)
        assert update_resp.status_code == 200, update_resp.get_data(as_text=True)
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent is not None
|
||||
assert agent["data"].get("sources") == ["new_src_1", "new_src_2"]
|
||||
|
||||
def test_multi_agent_independent_sources(self, client, created_ids):
|
||||
"""Each agent node keeps its own distinct sources list."""
|
||||
payload = workflow_multi_agent_sources()
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
assert resp.status_code in (200, 201), resp.get_data(as_text=True)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent_a = _find_agent_node(nodes, "agent_a")
|
||||
agent_b = _find_agent_node(nodes, "agent_b")
|
||||
|
||||
assert agent_a is not None, "Agent A not found"
|
||||
assert agent_b is not None, "Agent B not found"
|
||||
assert agent_a["data"].get("sources") == ["src_alpha", "src_beta"]
|
||||
assert agent_b["data"].get("sources") == ["src_gamma"]
|
||||
|
||||
def test_sources_survive_workflow_update(self, client, created_ids):
|
||||
"""Sources survive when a workflow is updated without changing sources."""
|
||||
payload = workflow_with_sources(["persistent_src"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update keeping same sources
|
||||
update_resp = client.put(f"/api/workflows/{wf_id}", json=payload)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent["data"].get("sources") == ["persistent_src"]
|
||||
|
||||
def test_remove_sources_on_update(self, client, created_ids):
|
||||
"""Clearing sources on update results in empty list."""
|
||||
payload = workflow_with_sources(["will_be_removed"])
|
||||
resp = client.post("/api/workflows", json=payload)
|
||||
wf_id = _extract_id(resp)
|
||||
created_ids.append(wf_id)
|
||||
|
||||
# Update with no sources
|
||||
cleared_payload = workflow_with_sources([], " cleared")
|
||||
update_resp = client.put(f"/api/workflows/{wf_id}", json=cleared_payload)
|
||||
assert update_resp.status_code == 200
|
||||
|
||||
nodes, _ = _get_graph(client, wf_id)
|
||||
agent = _find_agent_node(nodes, "agent_1")
|
||||
assert agent["data"].get("sources") == []
|
||||
@@ -71,6 +71,7 @@ class TestLLMHandlerCreator:
|
||||
expected_handlers = {
|
||||
"openai": OpenAILLMHandler,
|
||||
"google": GoogleLLMHandler,
|
||||
"novita": OpenAILLMHandler,
|
||||
"default": OpenAILLMHandler,
|
||||
}
|
||||
|
||||
|
||||
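# --- Illustrative sketch (not part of the diff above) -----------------------
# The expected_handlers mapping asserted in TestLLMHandlerCreator implies a
# registry-style factory: provider name -> handler class, with unknown names
# falling back to the OpenAI-compatible handler. A minimal, hypothetical
# version of that pattern (names here are placeholders, not the real
# application classes):
class _OpenAICompatibleHandler:
    """Stand-in for OpenAILLMHandler in this sketch."""


class _GoogleHandler:
    """Stand-in for GoogleLLMHandler in this sketch."""


_HANDLER_REGISTRY = {
    "openai": _OpenAICompatibleHandler,
    "google": _GoogleHandler,
    "novita": _OpenAICompatibleHandler,  # Novita reuses the OpenAI-compatible handler
}


def create_handler(provider: str):
    # Unknown providers fall back to the OpenAI-compatible handler, mirroring
    # the "default" entry asserted in the hunk above.
    return _HANDLER_REGISTRY.get(provider, _OpenAICompatibleHandler)()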
165
tests/llm/test_novita_llm.py
Normal file
@@ -0,0 +1,165 @@
"""Tests for the Novita LLM provider.

Novita uses an OpenAI-compatible API, so NovitaLLM extends OpenAILLM.
These tests verify the Novita-specific configuration is applied correctly.
"""

import types
from unittest.mock import patch

import pytest

from application.llm.novita import NOVITA_BASE_URL, NovitaLLM


class FakeChatCompletions:
    """Fake OpenAI chat completions for testing."""

    def __init__(self):
        self.last_kwargs = None

    class _Msg:
        def __init__(self, content=None):
            self.content = content

    class _Delta:
        def __init__(self, content=None):
            self.content = content

    class _Choice:
        def __init__(self, content=None, delta=None):
            self.message = FakeChatCompletions._Msg(content=content)
            self.delta = FakeChatCompletions._Delta(content=delta)

    class _StreamChunk:
        def __init__(self, choice):
            self.choices = [choice]

    class _Response:
        def __init__(self, choices=None, lines=None):
            self._choices = choices or []
            self._lines = lines or []

        @property
        def choices(self):
            return self._choices

        def __iter__(self):
            for line in self._lines:
                yield line

    def create(self, **kwargs):
        self.last_kwargs = kwargs
        if not kwargs.get("stream"):
            return FakeChatCompletions._Response(choices=[FakeChatCompletions._Choice(content="novita response")])
        return FakeChatCompletions._Response(
            lines=[
                FakeChatCompletions._StreamChunk(FakeChatCompletions._Choice(delta="part1")),
                FakeChatCompletions._StreamChunk(FakeChatCompletions._Choice(delta="part2")),
            ]
        )


class FakeClient:
    """Fake OpenAI client for testing."""

    def __init__(self):
        self.chat = types.SimpleNamespace(completions=FakeChatCompletions())


@pytest.mark.unit
def test_novita_base_url_constant():
    """Verify the Novita base URL is correctly defined."""
    assert NOVITA_BASE_URL == "https://api.novita.ai/openai"


@pytest.mark.unit
def test_novita_llm_uses_novita_base_url():
    """Verify NovitaLLM uses the Novita API endpoint."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    # The client should be configured with Novita's base URL
    assert str(llm.client.base_url) == NOVITA_BASE_URL + "/"


@pytest.mark.unit
def test_novita_llm_uses_novita_api_key():
    """Verify NovitaLLM prioritizes NOVITA_API_KEY from settings."""
    with patch("application.llm.novita.settings") as mock_settings:
        mock_settings.NOVITA_API_KEY = "novita-test-key"
        mock_settings.API_KEY = "fallback-key"
        mock_settings.OPENAI_BASE_URL = None

        llm = NovitaLLM(api_key=None, user_api_key=None)
        assert llm.api_key == "novita-test-key"


@pytest.mark.unit
def test_novita_llm_falls_back_to_api_key():
    """Verify NovitaLLM falls back to API_KEY when NOVITA_API_KEY is not set."""
    with patch("application.llm.novita.settings") as mock_settings:
        mock_settings.NOVITA_API_KEY = None
        mock_settings.API_KEY = "fallback-key"
        mock_settings.OPENAI_BASE_URL = None

        llm = NovitaLLM(api_key=None, user_api_key=None)
        assert llm.api_key == "fallback-key"


@pytest.mark.unit
def test_novita_llm_explicit_api_key_takes_precedence():
    """Verify explicitly passed API key takes precedence over settings."""
    with patch("application.llm.novita.settings") as mock_settings:
        mock_settings.NOVITA_API_KEY = "settings-key"
        mock_settings.API_KEY = "fallback-key"
        mock_settings.OPENAI_BASE_URL = None

        llm = NovitaLLM(api_key="explicit-key", user_api_key=None)
        assert llm.api_key == "explicit-key"


@pytest.mark.unit
def test_novita_llm_custom_base_url():
    """Verify custom base_url can override the default Novita URL."""
    custom_url = "https://custom.novita.endpoint/v1"
    llm = NovitaLLM(api_key="test-key", user_api_key=None, base_url=custom_url)
    assert str(llm.client.base_url) == custom_url + "/"


@pytest.mark.unit
def test_novita_llm_supports_tools():
    """Verify NovitaLLM supports function calling/tools."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    assert llm.supports_tools() is True


@pytest.mark.unit
def test_novita_llm_supports_structured_output():
    """Verify NovitaLLM supports structured output."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    assert llm.supports_structured_output() is True


@pytest.mark.unit
def test_novita_llm_gen_calls_client(monkeypatch):
    """Verify NovitaLLM.gen calls the OpenAI-compatible client correctly."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    llm.client = FakeClient()

    msgs = [{"role": "user", "content": "hello"}]
    result = llm._raw_gen(llm, model="moonshotai/kimi-k2.5", messages=msgs, stream=False)

    assert result == "novita response"
    assert llm.client.chat.completions.last_kwargs["model"] == "moonshotai/kimi-k2.5"


@pytest.mark.unit
def test_novita_llm_gen_stream_yields_chunks(monkeypatch):
    """Verify NovitaLLM streaming yields chunks correctly."""
    llm = NovitaLLM(api_key="test-key", user_api_key=None)
    llm.client = FakeClient()

    msgs = [{"role": "user", "content": "hi"}]
    gen = llm._raw_gen_stream(llm, model="moonshotai/kimi-k2.5", messages=msgs, stream=True)
    chunks = list(gen)

    assert "part1" in "".join(chunks)
    assert "part2" in "".join(chunks)
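# --- Illustrative sketch (not part of the diff above) -----------------------
# The tests above treat NovitaLLM as a thin OpenAILLM subclass pointed at
# Novita's OpenAI-compatible endpoint. A hedged sketch of that pattern using
# the openai client directly; the real application.llm.novita module may
# resolve keys and base URLs differently.
from openai import OpenAI

NOVITA_BASE_URL_SKETCH = "https://api.novita.ai/openai"


class NovitaLLMSketch:
    def __init__(self, api_key, base_url=NOVITA_BASE_URL_SKETCH):
        # An explicit key wins; falling back to settings is what the
        # precedence tests above assert for the real class.
        self.api_key = api_key
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def supports_tools(self):
        return True

    def supports_structured_output(self):
        return True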
0
tests/parser/__init__.py
Normal file
0
tests/parser/connectors/__init__.py
Normal file
110
tests/parser/connectors/test_base.py
Normal file
@@ -0,0 +1,110 @@
"""Tests for connector base classes."""


import pytest

from application.parser.connectors.base import BaseConnectorAuth, BaseConnectorLoader
from application.parser.schema.base import Document


class ConcreteAuth(BaseConnectorAuth):
    """Minimal concrete implementation for testing the ABC."""

    def get_authorization_url(self, state=None):
        return f"https://example.com/auth?state={state}"

    def exchange_code_for_tokens(self, authorization_code):
        return {"access_token": "tok", "code": authorization_code}

    def refresh_access_token(self, refresh_token):
        return {"access_token": "new_tok", "refresh_token": refresh_token}

    def is_token_expired(self, token_info):
        return token_info.get("expired", False)


class ConcreteLoader(BaseConnectorLoader):
    """Minimal concrete implementation for testing the ABC."""

    def __init__(self, session_token):
        self.session_token = session_token

    def load_data(self, inputs):
        return [Document(text="test", doc_id="1", extra_info={})]

    def download_to_directory(self, local_dir, source_config=None):
        return {"files_downloaded": 0, "directory_path": local_dir}


class TestBaseConnectorAuth:

    @pytest.mark.unit
    def test_sanitize_token_info_extracts_standard_fields(self):
        auth = ConcreteAuth()
        token_info = {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 12345,
            "extra_field": "should_not_appear",
        }
        result = auth.sanitize_token_info(token_info)
        assert result == {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 12345,
        }

    @pytest.mark.unit
    def test_sanitize_token_info_with_extra_kwargs(self):
        auth = ConcreteAuth()
        token_info = {
            "access_token": "at",
            "refresh_token": "rt",
            "token_uri": "https://token.uri",
            "expiry": 100,
        }
        result = auth.sanitize_token_info(token_info, custom_field="custom_val")
        assert result["custom_field"] == "custom_val"
        assert result["access_token"] == "at"

    @pytest.mark.unit
    def test_sanitize_token_info_missing_fields_returns_none(self):
        auth = ConcreteAuth()
        result = auth.sanitize_token_info({})
        assert result["access_token"] is None
        assert result["refresh_token"] is None
        assert result["token_uri"] is None
        assert result["expiry"] is None

    @pytest.mark.unit
    def test_abstract_methods_invocable_on_concrete(self):
        auth = ConcreteAuth()
        assert "example.com" in auth.get_authorization_url("s1")
        assert auth.exchange_code_for_tokens("code1")["access_token"] == "tok"
        assert auth.refresh_access_token("rt")["access_token"] == "new_tok"
        assert auth.is_token_expired({"expired": True}) is True
        assert auth.is_token_expired({"expired": False}) is False


class TestBaseConnectorLoader:

    @pytest.mark.unit
    def test_concrete_loader_init(self):
        loader = ConcreteLoader("session123")
        assert loader.session_token == "session123"

    @pytest.mark.unit
    def test_concrete_loader_load_data(self):
        loader = ConcreteLoader("s")
        docs = loader.load_data({})
        assert len(docs) == 1
        assert docs[0].text == "test"

    @pytest.mark.unit
    def test_concrete_loader_download_to_directory(self):
        loader = ConcreteLoader("s")
        result = loader.download_to_directory("/tmp/test")
        assert result["directory_path"] == "/tmp/test"
        assert result["files_downloaded"] == 0
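# --- Illustrative sketch (not part of the diff above) -----------------------
# The sanitize_token_info tests above imply roughly this behaviour: keep only
# the four standard OAuth fields (missing ones become None) and merge any
# extra keyword arguments. This is an inference from the assertions, not the
# actual BaseConnectorAuth implementation.
def sanitize_token_info_sketch(token_info, **extra):
    sanitized = {
        "access_token": token_info.get("access_token"),
        "refresh_token": token_info.get("refresh_token"),
        "token_uri": token_info.get("token_uri"),
        "expiry": token_info.get("expiry"),
    }
    sanitized.update(extra)  # e.g. custom_field="custom_val" in the test
    return sanitized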
102
tests/parser/connectors/test_connector_creator.py
Normal file
@@ -0,0 +1,102 @@
"""Tests for ConnectorCreator factory class."""

from unittest.mock import patch, MagicMock

import pytest


class TestConnectorCreator:

    @pytest.fixture(autouse=True)
    def _patch_settings(self):
        """Patch settings so connector imports don't fail on missing credentials."""
        mock_settings = MagicMock()
        mock_settings.GOOGLE_CLIENT_ID = "gid"
        mock_settings.GOOGLE_CLIENT_SECRET = "gsecret"
        mock_settings.CONNECTOR_REDIRECT_BASE_URI = "https://redirect"
        mock_settings.MICROSOFT_CLIENT_ID = "mid"
        mock_settings.MICROSOFT_CLIENT_SECRET = "msecret"
        mock_settings.MICROSOFT_TENANT_ID = "tid"
        mock_settings.MONGO_DB_NAME = "test_db"

        with patch("application.core.settings.settings", mock_settings), \
                patch("application.parser.connectors.share_point.auth.settings", mock_settings), \
                patch("application.parser.connectors.google_drive.auth.settings", mock_settings), \
                patch("application.parser.connectors.share_point.auth.ConfidentialClientApplication"):
            from application.parser.connectors.connector_creator import ConnectorCreator
            self.ConnectorCreator = ConnectorCreator
            yield

    @pytest.mark.unit
    def test_get_supported_connectors(self):
        supported = self.ConnectorCreator.get_supported_connectors()
        assert "google_drive" in supported
        assert "share_point" in supported

    @pytest.mark.unit
    def test_is_supported_valid(self):
        assert self.ConnectorCreator.is_supported("google_drive") is True
        assert self.ConnectorCreator.is_supported("share_point") is True

    @pytest.mark.unit
    def test_is_supported_case_insensitive(self):
        assert self.ConnectorCreator.is_supported("Google_Drive") is True
        assert self.ConnectorCreator.is_supported("SHARE_POINT") is True

    @pytest.mark.unit
    def test_is_supported_invalid(self):
        assert self.ConnectorCreator.is_supported("dropbox") is False
        assert self.ConnectorCreator.is_supported("") is False

    @pytest.mark.unit
    def test_create_auth_google_drive(self):
        auth = self.ConnectorCreator.create_auth("google_drive")
        from application.parser.connectors.google_drive.auth import GoogleDriveAuth
        assert isinstance(auth, GoogleDriveAuth)

    @pytest.mark.unit
    def test_create_auth_share_point(self):
        auth = self.ConnectorCreator.create_auth("share_point")
        from application.parser.connectors.share_point.auth import SharePointAuth
        assert isinstance(auth, SharePointAuth)

    @pytest.mark.unit
    def test_create_auth_invalid_raises(self):
        with pytest.raises(ValueError, match="No auth class found"):
            self.ConnectorCreator.create_auth("invalid")

    @pytest.mark.unit
    def test_create_connector_invalid_raises(self):
        with pytest.raises(ValueError, match="No connector class found"):
            self.ConnectorCreator.create_connector("invalid", "session_tok")

    @pytest.mark.unit
    def test_create_connector_google_drive(self):
        with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as MockAuth:
            mock_auth_instance = MagicMock()
            mock_auth_instance.get_token_info_from_session.return_value = {
                "access_token": "at", "refresh_token": "rt"
            }
            mock_creds = MagicMock()
            mock_creds.token = "at"
            mock_creds.expired = False
            mock_auth_instance.create_credentials_from_token_info.return_value = mock_creds
            mock_auth_instance.build_drive_service.return_value = MagicMock()
            MockAuth.return_value = mock_auth_instance

            loader = self.ConnectorCreator.create_connector("google_drive", "session_tok")
            from application.parser.connectors.google_drive.loader import GoogleDriveLoader
            assert isinstance(loader, GoogleDriveLoader)

    @pytest.mark.unit
    def test_create_connector_share_point(self):
        with patch("application.parser.connectors.share_point.loader.SharePointAuth") as MockAuth:
            mock_auth_instance = MagicMock()
            mock_auth_instance.get_token_info_from_session.return_value = {
                "access_token": "at", "refresh_token": "rt"
            }
            MockAuth.return_value = mock_auth_instance

            loader = self.ConnectorCreator.create_connector("share_point", "session_tok")
            from application.parser.connectors.share_point.loader import SharePointLoader
            assert isinstance(loader, SharePointLoader)
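# --- Illustrative usage sketch (not part of the diff above) -----------------
# How the factory under test is typically driven, based only on the calls the
# tests above make. The session token value and the shape of load_data's
# inputs dict are assumptions taken from the loader tests later in this diff.
def example_connector_usage():
    from application.parser.connectors.connector_creator import ConnectorCreator

    if ConnectorCreator.is_supported("google_drive"):
        auth = ConnectorCreator.create_auth("google_drive")
        loader = ConnectorCreator.create_connector("google_drive", "session_tok")
        return loader.load_data({"folder_id": "root", "limit": 10})
    return []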
441
tests/parser/connectors/test_google_drive_auth.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""Tests for GoogleDriveAuth."""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
s = MagicMock()
|
||||
s.GOOGLE_CLIENT_ID = "test-client-id"
|
||||
s.GOOGLE_CLIENT_SECRET = "test-client-secret"
|
||||
s.CONNECTOR_REDIRECT_BASE_URI = "https://redirect.example.com/callback"
|
||||
s.MONGO_DB_NAME = "test_db"
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth(mock_settings):
|
||||
with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
return GoogleDriveAuth()
|
||||
|
||||
|
||||
class TestGoogleDriveAuthInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_sets_credentials(self, auth, mock_settings):
|
||||
assert auth.client_id == "test-client-id"
|
||||
assert auth.client_secret == "test-client-secret"
|
||||
assert auth.redirect_uri == "https://redirect.example.com/callback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_missing_client_id_raises(self, mock_settings):
|
||||
mock_settings.GOOGLE_CLIENT_ID = None
|
||||
with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
with pytest.raises(ValueError, match="Google OAuth credentials not configured"):
|
||||
GoogleDriveAuth()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_missing_client_secret_raises(self, mock_settings):
|
||||
mock_settings.GOOGLE_CLIENT_SECRET = None
|
||||
with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
with pytest.raises(ValueError, match="Google OAuth credentials not configured"):
|
||||
GoogleDriveAuth()
|
||||
|
||||
|
||||
class TestGetAuthorizationUrl:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_authorization_url(self, auth):
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.authorization_url.return_value = ("https://accounts.google.com/auth?state=s1", "s1")
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.return_value = mock_flow
|
||||
url = auth.get_authorization_url(state="s1")
|
||||
|
||||
assert url == "https://accounts.google.com/auth?state=s1"
|
||||
mock_flow.authorization_url.assert_called_once_with(
|
||||
access_type='offline',
|
||||
prompt='consent',
|
||||
include_granted_scopes='false',
|
||||
state="s1"
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_raises_on_flow_error(self, auth):
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.side_effect = Exception("flow error")
|
||||
with pytest.raises(Exception, match="flow error"):
|
||||
auth.get_authorization_url()
|
||||
|
||||
|
||||
class TestExchangeCodeForTokens:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_exchange(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "access_tok"
|
||||
mock_creds.refresh_token = "refresh_tok"
|
||||
mock_creds.token_uri = "https://oauth2.googleapis.com/token"
|
||||
mock_creds.client_id = "test-client-id"
|
||||
mock_creds.client_secret = "test-client-secret"
|
||||
mock_creds.scopes = ["https://www.googleapis.com/auth/drive.file"]
|
||||
mock_creds.expiry = datetime.datetime(2025, 1, 1, 12, 0, 0)
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.credentials = mock_creds
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.return_value = mock_flow
|
||||
result = auth.exchange_code_for_tokens("auth_code_123")
|
||||
|
||||
assert result["access_token"] == "access_tok"
|
||||
assert result["refresh_token"] == "refresh_tok"
|
||||
assert result["token_uri"] == "https://oauth2.googleapis.com/token"
|
||||
assert result["client_id"] == "test-client-id"
|
||||
assert result["client_secret"] == "test-client-secret"
|
||||
assert result["expiry"] == "2025-01-01T12:00:00"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_code_raises(self, auth):
|
||||
with pytest.raises(ValueError, match="Authorization code is required"):
|
||||
auth.exchange_code_for_tokens("")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_access_token_raises(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = None
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.credentials = mock_creds
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.return_value = mock_flow
|
||||
with pytest.raises(ValueError, match="did not return an access token"):
|
||||
auth.exchange_code_for_tokens("code")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_refresh_token_raises(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = None
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.credentials = mock_creds
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.return_value = mock_flow
|
||||
with pytest.raises(ValueError, match="No refresh token received"):
|
||||
auth.exchange_code_for_tokens("code")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_fills_in_missing_token_uri(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.token_uri = None
|
||||
mock_creds.client_id = None
|
||||
mock_creds.client_secret = None
|
||||
mock_creds.scopes = []
|
||||
mock_creds.expiry = None
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.credentials = mock_creds
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.Flow") as MockFlow:
|
||||
MockFlow.from_client_config.return_value = mock_flow
|
||||
result = auth.exchange_code_for_tokens("code")
|
||||
|
||||
assert result["token_uri"] == "https://oauth2.googleapis.com/token"
|
||||
assert result["client_id"] == "test-client-id"
|
||||
assert result["client_secret"] == "test-client-secret"
|
||||
|
||||
|
||||
class TestRefreshAccessToken:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_refresh(self, auth):
|
||||
mock_request_cls = MagicMock()
|
||||
with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
|
||||
patch("google.auth.transport.requests.Request", mock_request_cls):
|
||||
mock_cred_instance = MagicMock()
|
||||
mock_cred_instance.token = "new_access"
|
||||
mock_cred_instance.token_uri = "https://oauth2.googleapis.com/token"
|
||||
mock_cred_instance.client_id = "cid"
|
||||
mock_cred_instance.client_secret = "cs"
|
||||
mock_cred_instance.scopes = []
|
||||
mock_cred_instance.expiry = datetime.datetime(2025, 6, 1, 0, 0, 0)
|
||||
MockCreds.return_value = mock_cred_instance
|
||||
|
||||
result = auth.refresh_access_token("old_refresh")
|
||||
|
||||
assert result["access_token"] == "new_access"
|
||||
assert result["refresh_token"] == "old_refresh"
|
||||
mock_cred_instance.refresh.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_refresh_token_raises(self, auth):
|
||||
with pytest.raises(ValueError, match="Refresh token is required"):
|
||||
auth.refresh_access_token("")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_failure_raises(self, auth):
|
||||
with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
|
||||
patch("google.auth.transport.requests.Request"):
|
||||
mock_cred_instance = MagicMock()
|
||||
mock_cred_instance.refresh.side_effect = Exception("refresh failed")
|
||||
MockCreds.return_value = mock_cred_instance
|
||||
|
||||
with pytest.raises(Exception, match="refresh failed"):
|
||||
auth.refresh_access_token("rt")
|
||||
|
||||
|
||||
class TestCreateCredentialsFromTokenInfo:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_creates_credentials(self, auth, mock_settings):
|
||||
with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
|
||||
patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
mock_cred = MagicMock()
|
||||
mock_cred.token = "at"
|
||||
MockCreds.return_value = mock_cred
|
||||
|
||||
creds = auth.create_credentials_from_token_info({
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"scopes": ["scope1"],
|
||||
})
|
||||
assert creds.token == "at"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_access_token_raises(self, auth, mock_settings):
|
||||
with patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="No access token found"):
|
||||
auth.create_credentials_from_token_info({})
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_credentials_without_valid_token_raises(self, auth, mock_settings):
|
||||
with patch("application.parser.connectors.google_drive.auth.Credentials") as MockCreds, \
|
||||
patch("application.parser.connectors.google_drive.auth.settings", mock_settings):
|
||||
mock_cred = MagicMock()
|
||||
mock_cred.token = None
|
||||
MockCreds.return_value = mock_cred
|
||||
|
||||
with pytest.raises(ValueError, match="Credentials created without valid access token"):
|
||||
auth.create_credentials_from_token_info({"access_token": "at"})
|
||||
|
||||
|
||||
class TestBuildDriveService:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_builds_service(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.expired = False
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.build") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
service = auth.build_drive_service(mock_creds)
|
||||
mock_build.assert_called_once_with('drive', 'v3', credentials=mock_creds)
|
||||
assert service is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_credentials_raises(self, auth):
|
||||
with pytest.raises(ValueError, match="No credentials provided"):
|
||||
auth.build_drive_service(None)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_token_no_refresh_raises(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = None
|
||||
mock_creds.refresh_token = None
|
||||
with pytest.raises(ValueError, match="No access token or refresh token"):
|
||||
auth.build_drive_service(mock_creds)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_token_refreshes(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.expired = True
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.build") as mock_build, \
|
||||
patch("google.auth.transport.requests.Request"):
|
||||
mock_build.return_value = MagicMock()
|
||||
auth.build_drive_service(mock_creds)
|
||||
mock_creds.refresh.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_no_refresh_token_raises(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = None
|
||||
mock_creds.expired = True
|
||||
with pytest.raises(ValueError, match="No access token or refresh token"):
|
||||
auth.build_drive_service(mock_creds)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_failure_raises(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.expired = True
|
||||
|
||||
with patch("google.auth.transport.requests.Request"):
|
||||
mock_creds.refresh.side_effect = Exception("Cannot refresh")
|
||||
with pytest.raises(ValueError, match="Failed to refresh credentials"):
|
||||
auth.build_drive_service(mock_creds)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_error_raises(self, auth):
|
||||
from googleapiclient.errors import HttpError
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.expired = False
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 500
|
||||
|
||||
with patch("application.parser.connectors.google_drive.auth.build") as mock_build:
|
||||
mock_build.side_effect = HttpError(mock_resp, b"error")
|
||||
with pytest.raises(ValueError, match="HTTP 500"):
|
||||
auth.build_drive_service(mock_creds)
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_token(self, auth):
|
||||
past = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
|
||||
assert auth.is_token_expired({"expiry": past}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_token(self, auth):
|
||||
future = (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)).isoformat()
|
||||
assert auth.is_token_expired({"expiry": future}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_token_within_buffer(self, auth):
|
||||
almost_expired = (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=30)).isoformat()
|
||||
assert auth.is_token_expired({"expiry": almost_expired}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_expiry_with_access_token(self, auth):
|
||||
assert auth.is_token_expired({"access_token": "at"}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_expiry_no_access_token(self, auth):
|
||||
assert auth.is_token_expired({}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_expiry_format_returns_true(self, auth):
|
||||
assert auth.is_token_expired({"expiry": "not-a-date"}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_expiry_with_access_token(self, auth):
|
||||
assert auth.is_token_expired({"expiry": None, "access_token": "at"}) is False
|
||||
|
||||
|
||||
class TestGetTokenInfoFromSession:
|
||||
|
||||
def _mock_mongo(self, mock_settings, find_one_return):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = find_one_return
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
return {mock_settings.MONGO_DB_NAME: mock_db}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_session(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {
|
||||
"session_token": "st",
|
||||
"token_info": {"access_token": "at", "refresh_token": "rt"},
|
||||
})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
result = auth.get_token_info_from_session("st")
|
||||
assert result["access_token"] == "at"
|
||||
assert result["token_uri"] == "https://oauth2.googleapis.com/token"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_session_not_found_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, None)
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
|
||||
auth.get_token_info_from_session("bad_token")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_session_missing_token_info_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {"session_token": "st"})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_required_fields_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {
|
||||
"session_token": "st",
|
||||
"token_info": {"access_token": "at"},
|
||||
})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_token_info_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {
|
||||
"session_token": "st",
|
||||
"token_info": None,
|
||||
})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve Google Drive token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
|
||||
|
||||
class TestValidateCredentials:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_credentials(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_creds.expired = False
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.about.return_value.get.return_value.execute.return_value = {"user": {}}
|
||||
|
||||
with patch.object(auth, 'build_drive_service', return_value=mock_service):
|
||||
assert auth.validate_credentials(mock_creds) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_error_returns_false(self, auth):
|
||||
from googleapiclient.errors import HttpError
|
||||
mock_creds = MagicMock()
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 401
|
||||
mock_service = MagicMock()
|
||||
mock_service.about.return_value.get.return_value.execute.side_effect = HttpError(mock_resp, b"unauth")
|
||||
|
||||
with patch.object(auth, 'build_drive_service', return_value=mock_service):
|
||||
assert auth.validate_credentials(mock_creds) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_general_error_returns_false(self, auth):
|
||||
mock_creds = MagicMock()
|
||||
with patch.object(auth, 'build_drive_service', side_effect=Exception("fail")):
|
||||
assert auth.validate_credentials(mock_creds) is False
|
||||
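# --- Illustrative sketch (not part of the diff above) -----------------------
# The TestIsTokenExpired cases above suggest expiry handling along these
# lines: no expiry but an access token -> not expired; unparsable or missing
# data -> treat as expired; otherwise compare against now plus a small safety
# buffer. The buffer size and parsing details are assumptions, not the real
# GoogleDriveAuth implementation.
import datetime


def is_token_expired_sketch(token_info, buffer_seconds=60):
    expiry = token_info.get("expiry")
    if expiry is None:
        return "access_token" not in token_info
    try:
        expiry_dt = datetime.datetime.fromisoformat(expiry)
    except (TypeError, ValueError):
        return True
    now = datetime.datetime.now(datetime.timezone.utc)
    return expiry_dt <= now + datetime.timedelta(seconds=buffer_seconds)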
851
tests/parser/connectors/test_google_drive_loader.py
Normal file
@@ -0,0 +1,851 @@
|
||||
"""Tests for GoogleDriveLoader."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
def _make_loader(service=None):
|
||||
"""Create a GoogleDriveLoader with mocked dependencies."""
|
||||
with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as MockAuth:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.get_token_info_from_session.return_value = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
}
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_creds.expired = False
|
||||
mock_creds.refresh_token = "rt"
|
||||
mock_auth.create_credentials_from_token_info.return_value = mock_creds
|
||||
mock_auth.build_drive_service.return_value = service or MagicMock()
|
||||
MockAuth.return_value = mock_auth
|
||||
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
loader = GoogleDriveLoader("session_tok")
|
||||
return loader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader(mock_service):
|
||||
return _make_loader(mock_service)
|
||||
|
||||
|
||||
class TestGoogleDriveLoaderInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_sets_attributes(self, loader):
|
||||
assert loader.session_token == "session_tok"
|
||||
assert loader.credentials is not None
|
||||
assert loader.service is not None
|
||||
assert loader.next_page_token is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_service_failure_sets_none(self):
|
||||
with patch("application.parser.connectors.google_drive.loader.GoogleDriveAuth") as MockAuth:
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.get_token_info_from_session.return_value = {
|
||||
"access_token": "at", "refresh_token": "rt"
|
||||
}
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.token = "at"
|
||||
mock_auth.create_credentials_from_token_info.return_value = mock_creds
|
||||
mock_auth.build_drive_service.side_effect = Exception("service fail")
|
||||
MockAuth.return_value = mock_auth
|
||||
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
loader = GoogleDriveLoader("st")
|
||||
assert loader.service is None
|
||||
|
||||
|
||||
class TestProcessFile:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supported_mime_type_with_content(self, loader):
|
||||
loader._download_file_content = MagicMock(return_value="file content")
|
||||
metadata = {
|
||||
"id": "f1",
|
||||
"name": "test.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 1024,
|
||||
"createdTime": "2025-01-01T00:00:00Z",
|
||||
"modifiedTime": "2025-01-02T00:00:00Z",
|
||||
"parents": ["root"],
|
||||
}
|
||||
doc = loader._process_file(metadata, load_content=True)
|
||||
assert doc is not None
|
||||
assert doc.text == "file content"
|
||||
assert doc.doc_id == "f1"
|
||||
assert doc.extra_info["file_name"] == "test.pdf"
|
||||
assert doc.extra_info["source"] == "google_drive"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supported_mime_type_no_content(self, loader):
|
||||
metadata = {
|
||||
"id": "f1",
|
||||
"name": "test.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
}
|
||||
doc = loader._process_file(metadata, load_content=False)
|
||||
assert doc is not None
|
||||
assert doc.text == ""
|
||||
assert doc.doc_id == "f1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unsupported_mime_type_returns_none(self, loader):
|
||||
metadata = {
|
||||
"id": "f1",
|
||||
"name": "test.zip",
|
||||
"mimeType": "application/zip",
|
||||
}
|
||||
doc = loader._process_file(metadata)
|
||||
assert doc is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_download_failure_returns_none(self, loader):
|
||||
loader._download_file_content = MagicMock(return_value=None)
|
||||
metadata = {
|
||||
"id": "f1",
|
||||
"name": "test.txt",
|
||||
"mimeType": "text/plain",
|
||||
}
|
||||
doc = loader._process_file(metadata, load_content=True)
|
||||
assert doc is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_returns_none(self, loader):
|
||||
loader._download_file_content = MagicMock(side_effect=Exception("fail"))
|
||||
metadata = {
|
||||
"id": "f1",
|
||||
"name": "test.txt",
|
||||
"mimeType": "text/plain",
|
||||
}
|
||||
doc = loader._process_file(metadata, load_content=True)
|
||||
assert doc is None
|
||||
|
||||
|
||||
class TestLoadData:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_specific_files(self, loader):
|
||||
doc = Document(text="content", doc_id="f1", extra_info={"file_name": "test.pdf"})
|
||||
loader._load_file_by_id = MagicMock(return_value=doc)
|
||||
|
||||
result = loader.load_data({"file_ids": ["f1"]})
|
||||
assert len(result) == 1
|
||||
assert result[0].doc_id == "f1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_files_with_search_filter(self, loader):
|
||||
doc = Document(text="c", doc_id="f1", extra_info={"file_name": "report.pdf"})
|
||||
loader._load_file_by_id = MagicMock(return_value=doc)
|
||||
|
||||
result = loader.load_data({"file_ids": ["f1"], "search_query": "report"})
|
||||
assert len(result) == 1
|
||||
|
||||
result = loader.load_data({"file_ids": ["f1"], "search_query": "other"})
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_files_error_continues(self, loader):
|
||||
loader._load_file_by_id = MagicMock(side_effect=Exception("fail"))
|
||||
result = loader.load_data({"file_ids": ["f1", "f2"]})
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_browse_mode_uses_list_items(self, loader):
|
||||
docs = [Document(text="", doc_id="f1", extra_info={})]
|
||||
loader._list_items_in_parent = MagicMock(return_value=docs)
|
||||
|
||||
result = loader.load_data({"folder_id": "folder1", "limit": 50})
|
||||
loader._list_items_in_parent.assert_called_once_with(
|
||||
"folder1", limit=50, load_content=True, page_token=None, search_query=None
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_browse_mode_defaults_to_root(self, loader):
|
||||
loader._list_items_in_parent = MagicMock(return_value=[])
|
||||
loader.load_data({})
|
||||
loader._list_items_in_parent.assert_called_once_with(
|
||||
"root", limit=100, load_content=True, page_token=None, search_query=None
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_session_token_mismatch_logs_warning(self, loader):
|
||||
loader._list_items_in_parent = MagicMock(return_value=[])
|
||||
loader.load_data({"session_token": "different_token"})
|
||||
# Should not raise, just logs
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_data_with_list_only(self, loader):
|
||||
loader._list_items_in_parent = MagicMock(return_value=[])
|
||||
loader.load_data({"list_only": True})
|
||||
loader._list_items_in_parent.assert_called_once_with(
|
||||
"root", limit=100, load_content=False, page_token=None, search_query=None
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_data_with_page_token(self, loader):
|
||||
loader._list_items_in_parent = MagicMock(return_value=[])
|
||||
loader.load_data({"page_token": "next_page"})
|
||||
loader._list_items_in_parent.assert_called_once_with(
|
||||
"root", limit=100, load_content=True, page_token="next_page", search_query=None
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_credential_refresh_retry(self, loader):
|
||||
"""When _load_file_by_id returns None and _credential_refreshed is set, retry."""
|
||||
loader._credential_refreshed = True
|
||||
call_count = [0]
|
||||
def side_effect(fid, load_content=True):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return None
|
||||
return Document(text="c", doc_id=fid, extra_info={"file_name": "test.pdf"})
|
||||
|
||||
loader._load_file_by_id = MagicMock(side_effect=side_effect)
|
||||
result = loader.load_data({"file_ids": ["f1"]})
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_outer_exception_raises(self, loader):
|
||||
"""The outer try/except in load_data re-raises unexpected errors."""
|
||||
loader._list_items_in_parent = MagicMock(side_effect=RuntimeError("unexpected"))
|
||||
with pytest.raises(RuntimeError, match="unexpected"):
|
||||
loader.load_data({})
|
||||
|
||||
|
||||
class TestLoadFileById:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_loads_file_metadata_and_processes(self, loader, mock_service):
|
||||
mock_service.files.return_value.get.return_value.execute.return_value = {
|
||||
"id": "f1", "name": "test.txt", "mimeType": "text/plain"
|
||||
}
|
||||
loader._process_file = MagicMock(return_value=Document(text="t", doc_id="f1", extra_info={}))
|
||||
doc = loader._load_file_by_id("f1")
|
||||
assert doc is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_401_refreshes_credentials(self, loader, mock_service):
|
||||
from googleapiclient.errors import HttpError
|
||||
resp = MagicMock()
|
||||
resp.status = 401
|
||||
mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
|
||||
|
||||
with patch("google.auth.transport.requests.Request"):
|
||||
result = loader._load_file_by_id("f1")
|
||||
assert result is None
|
||||
loader.credentials.refresh.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_401_no_refresh_token_raises(self, loader, mock_service):
|
||||
from googleapiclient.errors import HttpError
|
||||
resp = MagicMock()
|
||||
resp.status = 401
|
||||
mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
|
||||
loader.credentials.refresh_token = None
|
||||
|
||||
with pytest.raises(ValueError, match="missing refresh_token"):
|
||||
loader._load_file_by_id("f1")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_500_returns_none(self, loader, mock_service):
|
||||
from googleapiclient.errors import HttpError
|
||||
resp = MagicMock()
|
||||
resp.status = 500
|
||||
mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"server error")
|
||||
result = loader._load_file_by_id("f1")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_general_exception_returns_none(self, loader, mock_service):
|
||||
mock_service.files.return_value.get.return_value.execute.side_effect = Exception("fail")
|
||||
result = loader._load_file_by_id("f1")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ensure_service_called(self, loader):
|
||||
loader.service = None
|
||||
loader.auth.build_drive_service.return_value = MagicMock()
|
||||
loader.auth.build_drive_service.return_value.files.return_value.get.return_value.execute.return_value = {
|
||||
"id": "f1", "name": "t.txt", "mimeType": "text/plain"
|
||||
}
|
||||
loader._process_file = MagicMock(return_value=None)
|
||||
loader._load_file_by_id("f1")
|
||||
loader.auth.build_drive_service.assert_called()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_401_refresh_failure_raises(self, loader, mock_service):
|
||||
from googleapiclient.errors import HttpError
|
||||
resp = MagicMock()
|
||||
resp.status = 401
|
||||
mock_service.files.return_value.get.return_value.execute.side_effect = HttpError(resp, b"unauth")
|
||||
loader.credentials.refresh_token = "rt"
|
||||
|
||||
with patch("google.auth.transport.requests.Request"):
|
||||
loader.credentials.refresh.side_effect = Exception("refresh broke")
|
||||
with pytest.raises(ValueError, match="could not be refreshed"):
|
||||
loader._load_file_by_id("f1")
|
||||
|
||||
|
||||
class TestListItemsInParent:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_lists_files_and_folders(self, loader, mock_service):
|
||||
mock_service.files.return_value.list.return_value.execute.return_value = {
|
||||
"files": [
|
||||
{"id": "folder1", "name": "Docs", "mimeType": "application/vnd.google-apps.folder"},
|
||||
{"id": "file1", "name": "test.txt", "mimeType": "text/plain"},
|
||||
],
|
||||
"nextPageToken": None,
|
||||
}
|
||||
loader._process_file = MagicMock(return_value=Document(text="", doc_id="file1", extra_info={}))
|
||||
docs = loader._list_items_in_parent("root", limit=100)
|
||||
assert len(docs) == 2
|
||||
assert docs[0].extra_info.get("is_folder") is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_search_query_modifies_drive_query(self, loader, mock_service):
|
||||
mock_service.files.return_value.list.return_value.execute.return_value = {
|
||||
"files": [], "nextPageToken": None
|
||||
}
|
||||
loader._list_items_in_parent("root", search_query="report")
|
||||
call_args = mock_service.files.return_value.list.call_args
|
||||
assert "name contains 'report'" in call_args.kwargs.get("q", "")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_limit_stops_early(self, loader, mock_service):
|
||||
mock_service.files.return_value.list.return_value.execute.return_value = {
|
||||
"files": [
|
||||
{"id": f"f{i}", "name": f"file{i}.txt", "mimeType": "text/plain"} for i in range(10)
|
||||
],
|
||||
"nextPageToken": "next",
|
||||
}
|
||||
loader._process_file = MagicMock(side_effect=lambda m, **kw: Document(text="", doc_id=m["id"], extra_info={}))
|
||||
        docs = loader._list_items_in_parent("root", limit=3)
        assert len(docs) == 3

    @pytest.mark.unit
    def test_pagination_stores_next_page_token(self, loader, mock_service):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [{"id": "f1", "name": "f.txt", "mimeType": "text/plain"}],
            "nextPageToken": "page2",
        }
        loader._process_file = MagicMock(return_value=Document(text="", doc_id="f1", extra_info={}))
        loader._list_items_in_parent("root", limit=1)
        assert loader.next_page_token == "page2"

    @pytest.mark.unit
    def test_limit_breaks_loop_when_remaining_zero(self, loader, mock_service):
        """When limit is exactly met after first page, the while loop breaks via remaining==0."""
        call_count = [0]

        def list_side_effect(**kw):
            call_count[0] += 1
            mock = MagicMock()
            if call_count[0] == 1:
                mock.execute.return_value = {
                    "files": [
                        {"id": "f1", "name": "f1.txt", "mimeType": "text/plain"},
                        {"id": "f2", "name": "f2.txt", "mimeType": "text/plain"},
                    ],
                    "nextPageToken": "page2",
                }
            else:
                # Should not reach here if break works correctly
                mock.execute.return_value = {"files": [], "nextPageToken": None}
            return mock

        mock_service.files.return_value.list.side_effect = list_side_effect
        loader._process_file = MagicMock(side_effect=lambda m, **kw: Document(text="", doc_id=m["id"], extra_info={}))
        # limit=2, first page returns exactly 2 files with nextPageToken
        # Since items don't hit the inner limit check (it checks after each item),
        # it should loop back and break at remaining==0
        docs = loader._list_items_in_parent("root", limit=2)
        assert len(docs) == 2

    @pytest.mark.unit
    def test_exception_returns_partial_results(self, loader, mock_service):
        mock_service.files.return_value.list.return_value.execute.side_effect = Exception("api error")
        docs = loader._list_items_in_parent("root")
        assert docs == []


class TestDownloadFileContent:

    @pytest.mark.unit
    def test_download_regular_file(self, loader, mock_service):
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.side_effect = [(None, False), (None, True)]
            MockDownload.return_value = mock_dl

            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = b"file content"
                MockBytesIO.return_value = mock_bio

                content = loader._download_file_content("f1", "text/plain")

        assert content == "file content"

    @pytest.mark.unit
    def test_download_google_workspace_file_uses_export(self, loader, mock_service):
        mock_request = MagicMock()
        mock_service.files.return_value.export_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDownload.return_value = mock_dl

            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = b"exported"
                MockBytesIO.return_value = mock_bio

                content = loader._download_file_content("f1", "application/vnd.google-apps.document")

        mock_service.files.return_value.export_media.assert_called_once()
        assert content == "exported"

    @pytest.mark.unit
    def test_unicode_decode_error_returns_none(self, loader, mock_service):
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDownload.return_value = mock_dl

            with patch("io.BytesIO") as MockBytesIO:
                mock_bio = MagicMock()
                mock_bio.getvalue.return_value = MagicMock()
                mock_bio.getvalue.return_value.decode.side_effect = UnicodeDecodeError("utf-8", b"", 0, 1, "bad")
                MockBytesIO.return_value = mock_bio

                result = loader._download_file_content("f1", "application/pdf")
        assert result is None

    @pytest.mark.unit
    def test_no_access_token_with_refresh(self, loader):
        loader.credentials.token = None
        loader.credentials.refresh_token = "rt"
        loader.credentials.expired = False

        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()

            # After refresh, set token
            def set_token(req):
                loader.credentials.token = "new_at"

            loader.credentials.refresh.side_effect = set_token
            loader._ensure_service = MagicMock()
            loader.service.files.return_value.get_media.return_value = MagicMock()

            with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
                mock_dl = MagicMock()
                mock_dl.next_chunk.return_value = (None, True)
                MockDownload.return_value = mock_dl
                with patch("io.BytesIO") as MockBytesIO:
                    mock_bio = MagicMock()
                    mock_bio.getvalue.return_value = b"data"
                    MockBytesIO.return_value = mock_bio
                    content = loader._download_file_content("f1", "text/plain")

        assert content == "data"

    @pytest.mark.unit
    def test_no_access_token_no_refresh_raises(self, loader):
        loader.credentials.token = None
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_no_token_refresh_fails_raises(self, loader):
        loader.credentials.token = None
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("fail")
            with pytest.raises(ValueError, match="missing or invalid refresh_token"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_expired_credentials_refresh(self, loader):
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = "rt"

        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()

            def fix_expired(req):
                loader.credentials.expired = False

            loader.credentials.refresh.side_effect = fix_expired
            loader._ensure_service = MagicMock()
            loader.service.files.return_value.get_media.return_value = MagicMock()

            with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
                mock_dl = MagicMock()
                mock_dl.next_chunk.return_value = (None, True)
                MockDownload.return_value = mock_dl
                with patch("io.BytesIO") as MockBytesIO:
                    mock_bio = MagicMock()
                    mock_bio.getvalue.return_value = b"ok"
                    MockBytesIO.return_value = mock_bio
                    content = loader._download_file_content("f1", "text/plain")

        assert content == "ok"

    @pytest.mark.unit
    def test_expired_no_refresh_token_raises(self, loader):
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = None
        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_expired_refresh_fails_raises(self, loader):
        loader.credentials.token = "at"
        loader.credentials.expired = True
        loader.credentials.refresh_token = "rt"
        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("fail")
            with pytest.raises(ValueError, match="expired credentials"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_error_during_download_chunk_returns_none(self, loader, mock_service):
        from googleapiclient.errors import HttpError
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request

        resp = MagicMock()
        resp.status = 500

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.side_effect = HttpError(resp, b"server error")
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                MockBytesIO.return_value = MagicMock()
                result = loader._download_file_content("f1", "text/plain")
        assert result is None

    @pytest.mark.unit
    def test_general_error_during_download_chunk_returns_none(self, loader, mock_service):
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDownload:
            mock_dl = MagicMock()
            mock_dl.next_chunk.side_effect = RuntimeError("chunk fail")
            MockDownload.return_value = mock_dl
            with patch("io.BytesIO") as MockBytesIO:
                MockBytesIO.return_value = MagicMock()
                result = loader._download_file_content("f1", "text/plain")
        assert result is None

    @pytest.mark.unit
    def test_http_401_during_download_refreshes_and_returns_none(self, loader, mock_service):
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401

        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = "rt"

        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh = MagicMock()
            loader._ensure_service = MagicMock()
            result = loader._download_file_content("f1", "text/plain")

        assert result is None
        assert loader._credential_refreshed is True

    @pytest.mark.unit
    def test_http_401_no_refresh_token_raises(self, loader, mock_service):
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = None

        with pytest.raises(ValueError, match="missing refresh_token"):
            loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_401_refresh_fails_raises(self, loader, mock_service):
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 401
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"unauth")
        loader.credentials.refresh_token = "rt"

        with patch("google.auth.transport.requests.Request"):
            loader.credentials.refresh.side_effect = Exception("refresh fail")
            with pytest.raises(ValueError, match="could not be refreshed"):
                loader._download_file_content("f1", "text/plain")

    @pytest.mark.unit
    def test_http_500_returns_none(self, loader, mock_service):
        from googleapiclient.errors import HttpError
        resp = MagicMock()
        resp.status = 500
        mock_service.files.return_value.get_media.side_effect = HttpError(resp, b"error")
        result = loader._download_file_content("f1", "text/plain")
        assert result is None

    @pytest.mark.unit
    def test_general_exception_returns_none(self, loader, mock_service):
        mock_service.files.return_value.get_media.side_effect = RuntimeError("fail")
        result = loader._download_file_content("f1", "text/plain")
        assert result is None


class TestDownloadToDirectory:

    @pytest.mark.unit
    def test_download_files(self, loader, tmp_path):
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": ["f1", "f2"]})
        assert result["files_downloaded"] == 2
        assert result["source_type"] == "google_drive"
        assert result["empty_result"] is False

    @pytest.mark.unit
    def test_download_files_string_id(self, loader, tmp_path):
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": "single_id"})
        assert result["files_downloaded"] == 1

    @pytest.mark.unit
    def test_download_folders(self, loader, tmp_path, mock_service):
        mock_service.files.return_value.get.return_value.execute.return_value = {"name": "MyFolder"}
        loader._download_folder_recursive = MagicMock(return_value=3)
        result = loader.download_to_directory(str(tmp_path), {"folder_ids": ["folder1"]})
        assert result["files_downloaded"] == 3

    @pytest.mark.unit
    def test_download_folders_string_id(self, loader, tmp_path, mock_service):
        mock_service.files.return_value.get.return_value.execute.return_value = {"name": "F"}
        loader._download_folder_recursive = MagicMock(return_value=1)
        result = loader.download_to_directory(str(tmp_path), {"folder_ids": "single_folder"})
        assert result["files_downloaded"] == 1

    @pytest.mark.unit
    def test_no_ids_returns_error(self, loader, tmp_path):
        result = loader.download_to_directory(str(tmp_path), {})
        assert "error" in result
        assert result["empty_result"] is True

    @pytest.mark.unit
    def test_none_config_uses_empty(self, loader, tmp_path):
        result = loader.download_to_directory(str(tmp_path))
        assert "error" in result

    @pytest.mark.unit
    def test_folder_error_continues(self, loader, tmp_path, mock_service):
        mock_service.files.return_value.get.return_value.execute.side_effect = Exception("fail")
        loader._download_file_to_directory = MagicMock(return_value=True)
        result = loader.download_to_directory(str(tmp_path), {"file_ids": ["f1"], "folder_ids": ["bad_folder"]})
        assert result["files_downloaded"] == 1


class TestDownloadSingleFile:

    @pytest.mark.unit
    def test_downloads_supported_file(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "test.txt", "mimeType": "text/plain"
        }
        mock_request = MagicMock()
        mock_service.files.return_value.get_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDl:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDl.return_value = mock_dl

            result = loader._download_single_file("f1", str(tmp_path))
        assert result is True

    @pytest.mark.unit
    def test_unsupported_mime_returns_false(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "test.zip", "mimeType": "application/zip"
        }
        result = loader._download_single_file("f1", str(tmp_path))
        assert result is False

    @pytest.mark.unit
    def test_google_workspace_file_export(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.get.return_value.execute.return_value = {
            "name": "doc", "mimeType": "application/vnd.google-apps.document"
        }
        mock_request = MagicMock()
        mock_service.files.return_value.export_media.return_value = mock_request

        with patch("application.parser.connectors.google_drive.loader.MediaIoBaseDownload") as MockDl:
            mock_dl = MagicMock()
            mock_dl.next_chunk.return_value = (None, True)
            MockDl.return_value = mock_dl
            result = loader._download_single_file("f1", str(tmp_path))

        assert result is True
        mock_service.files.return_value.export_media.assert_called_once()


class TestDownloadFolderRecursive:

    @pytest.mark.unit
    def test_downloads_files_in_folder(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": "f1", "name": "file1.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(return_value=True)
        count = loader._download_folder_recursive("folder1", str(tmp_path))
        assert count == 1

    @pytest.mark.unit
    def test_recurses_into_subfolders(self, loader, mock_service, tmp_path):
        # First call: folder with subfolder and file
        # Second call: subfolder contents
        call_count = [0]

        def list_side_effect():
            mock = MagicMock()
            if call_count[0] == 0:
                call_count[0] += 1
                mock.execute.return_value = {
                    "files": [
                        {"id": "sub1", "name": "subfolder", "mimeType": "application/vnd.google-apps.folder"},
                        {"id": "f1", "name": "file1.txt", "mimeType": "text/plain"},
                    ],
                    "nextPageToken": None,
                }
            else:
                mock.execute.return_value = {
                    "files": [
                        {"id": "f2", "name": "file2.txt", "mimeType": "text/plain"},
                    ],
                    "nextPageToken": None,
                }
            return mock

        mock_service.files.return_value.list.side_effect = lambda **kw: list_side_effect()
        loader._download_single_file = MagicMock(return_value=True)
        count = loader._download_folder_recursive("folder1", str(tmp_path), recursive=True)
        assert count == 2

    @pytest.mark.unit
    def test_non_recursive_skips_subfolders(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": "sub1", "name": "subfolder", "mimeType": "application/vnd.google-apps.folder"},
                {"id": "f1", "name": "file1.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(return_value=True)
        count = loader._download_folder_recursive("folder1", str(tmp_path), recursive=False)
        assert count == 1

    @pytest.mark.unit
    def test_download_failure_continues(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.return_value = {
            "files": [
                {"id": "f1", "name": "fail.txt", "mimeType": "text/plain"},
                {"id": "f2", "name": "ok.txt", "mimeType": "text/plain"},
            ],
            "nextPageToken": None,
        }
        loader._download_single_file = MagicMock(side_effect=[False, True])
        count = loader._download_folder_recursive("folder1", str(tmp_path))
        assert count == 1

    @pytest.mark.unit
    def test_exception_returns_partial_count(self, loader, mock_service, tmp_path):
        mock_service.files.return_value.list.return_value.execute.side_effect = Exception("fail")
        count = loader._download_folder_recursive("folder1", str(tmp_path))
        assert count == 0


class TestDownloadFileToDirectory:

    @pytest.mark.unit
    def test_delegates_to_download_single(self, loader, tmp_path):
        loader._download_single_file = MagicMock(return_value=True)
        assert loader._download_file_to_directory("f1", str(tmp_path)) is True

    @pytest.mark.unit
    def test_exception_returns_false(self, loader, tmp_path):
        loader._download_single_file = MagicMock(side_effect=Exception("fail"))
        assert loader._download_file_to_directory("f1", str(tmp_path)) is False


class TestDownloadFolderContents:

    @pytest.mark.unit
    def test_delegates_to_recursive(self, loader, tmp_path):
        loader._download_folder_recursive = MagicMock(return_value=5)
        count = loader._download_folder_contents("folder1", str(tmp_path))
        assert count == 5

    @pytest.mark.unit
    def test_exception_returns_zero(self, loader, tmp_path):
        loader.service = None
        loader.auth.build_drive_service.side_effect = Exception("fail")
        count = loader._download_folder_contents("folder1", str(tmp_path))
        assert count == 0


class TestEnsureService:

    @pytest.mark.unit
    def test_builds_service_when_none(self, loader):
        loader.service = None
        mock_svc = MagicMock()
        loader.auth.build_drive_service.return_value = mock_svc
        loader._ensure_service()
        assert loader.service == mock_svc

    @pytest.mark.unit
    def test_noop_when_service_exists(self, loader, mock_service):
        loader._ensure_service()
        assert loader.service == mock_service

    @pytest.mark.unit
    def test_build_failure_raises(self, loader):
        loader.service = None
        loader.auth.build_drive_service.side_effect = Exception("fail")
        with pytest.raises(ValueError, match="Cannot access Google Drive"):
            loader._ensure_service()


class TestGetExtensionForMimeType:

    @pytest.mark.unit
    def test_known_mime_types(self, loader):
        assert loader._get_extension_for_mime_type("application/pdf") == ".pdf"
        assert loader._get_extension_for_mime_type("text/plain") == ".txt"
        assert loader._get_extension_for_mime_type("text/html") == ".html"
        assert loader._get_extension_for_mime_type("text/markdown") == ".md"

    @pytest.mark.unit
    def test_unknown_returns_bin(self, loader):
        assert loader._get_extension_for_mime_type("application/unknown") == ".bin"
326  tests/parser/connectors/test_share_point_auth.py  Normal file
@@ -0,0 +1,326 @@
"""Tests for SharePointAuth."""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
s = MagicMock()
|
||||
s.MICROSOFT_CLIENT_ID = "ms-client-id"
|
||||
s.MICROSOFT_CLIENT_SECRET = "ms-client-secret"
|
||||
s.MICROSOFT_TENANT_ID = "tenant-id-123"
|
||||
s.CONNECTOR_REDIRECT_BASE_URI = "https://redirect.example.com/callback"
|
||||
s.MONGO_DB_NAME = "test_db"
|
||||
# Delete MICROSOFT_AUTHORITY so getattr falls back to default
|
||||
del s.MICROSOFT_AUTHORITY
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_msal():
|
||||
with patch("application.parser.connectors.share_point.auth.ConfidentialClientApplication") as MockMSAL:
|
||||
mock_app = MagicMock()
|
||||
MockMSAL.return_value = mock_app
|
||||
yield mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth(mock_settings, mock_msal):
|
||||
with patch("application.parser.connectors.share_point.auth.settings", mock_settings):
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
return SharePointAuth()
|
||||
|
||||
|
||||
class TestSharePointAuthInit:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_sets_attributes(self, auth, mock_settings):
|
||||
assert auth.client_id == "ms-client-id"
|
||||
assert auth.client_secret == "ms-client-secret"
|
||||
assert auth.redirect_uri == "https://redirect.example.com/callback"
|
||||
assert auth.tenant_id == "tenant-id-123"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_client_id_raises(self, mock_settings):
|
||||
mock_settings.MICROSOFT_CLIENT_ID = None
|
||||
with patch("application.parser.connectors.share_point.auth.settings", mock_settings), \
|
||||
patch("application.parser.connectors.share_point.auth.ConfidentialClientApplication"):
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
with pytest.raises(ValueError, match="MICROSOFT_CLIENT_ID"):
|
||||
SharePointAuth()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_client_secret_raises(self, mock_settings):
|
||||
mock_settings.MICROSOFT_CLIENT_SECRET = None
|
||||
with patch("application.parser.connectors.share_point.auth.settings", mock_settings), \
|
||||
patch("application.parser.connectors.share_point.auth.ConfidentialClientApplication"):
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
with pytest.raises(ValueError, match="MICROSOFT_CLIENT_SECRET"):
|
||||
SharePointAuth()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_authority(self, auth):
|
||||
assert "login.microsoftonline.com" in auth.authority
|
||||
assert "tenant-id-123" in auth.authority
|
||||
|
||||
|
||||
class TestGetAuthorizationUrl:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_url(self, auth, mock_msal):
|
||||
mock_msal.get_authorization_request_url.return_value = "https://login.microsoftonline.com/auth?state=s1"
|
||||
url = auth.get_authorization_url(state="s1")
|
||||
assert url == "https://login.microsoftonline.com/auth?state=s1"
|
||||
mock_msal.get_authorization_request_url.assert_called_once()
|
||||
|
||||
|
||||
class TestExchangeCodeForTokens:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_exchange(self, auth, mock_msal):
|
||||
mock_msal.acquire_token_by_authorization_code.return_value = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"scope": ["Files.Read"],
|
||||
"id_token_claims": {
|
||||
"iss": "https://login.microsoftonline.com/tid/v2.0",
|
||||
"exp": 1700000000,
|
||||
"tid": "work-tenant-id",
|
||||
"name": "Test User",
|
||||
"preferred_username": "test@example.com",
|
||||
},
|
||||
}
|
||||
result = auth.exchange_code_for_tokens("auth_code")
|
||||
assert result["access_token"] == "at"
|
||||
assert result["refresh_token"] == "rt"
|
||||
assert result["user_info"]["name"] == "Test User"
|
||||
assert result["user_info"]["email"] == "test@example.com"
|
||||
assert result["allows_shared_content"] is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_error_in_response_raises(self, auth, mock_msal):
|
||||
mock_msal.acquire_token_by_authorization_code.return_value = {
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Code expired",
|
||||
}
|
||||
with pytest.raises(ValueError, match="Code expired"):
|
||||
auth.exchange_code_for_tokens("bad_code")
|
||||
|
||||
|
||||
class TestRefreshAccessToken:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_successful_refresh(self, auth, mock_msal):
|
||||
mock_msal.acquire_token_by_refresh_token.return_value = {
|
||||
"access_token": "new_at",
|
||||
"refresh_token": "new_rt",
|
||||
"scope": ["Files.Read"],
|
||||
"id_token_claims": {
|
||||
"iss": "https://issuer",
|
||||
"exp": 1700001000,
|
||||
"tid": "work-tid",
|
||||
"name": "User",
|
||||
"preferred_username": "u@example.com",
|
||||
},
|
||||
}
|
||||
result = auth.refresh_access_token("old_rt")
|
||||
assert result["access_token"] == "new_at"
|
||||
assert result["refresh_token"] == "new_rt"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_error_in_response_raises(self, auth, mock_msal):
|
||||
mock_msal.acquire_token_by_refresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Token revoked",
|
||||
}
|
||||
with pytest.raises(ValueError, match="Token revoked"):
|
||||
auth.refresh_access_token("bad_rt")
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_expired_token(self, auth):
|
||||
past = int((datetime.datetime.now() - datetime.timedelta(hours=1)).timestamp())
|
||||
assert auth.is_token_expired({"expiry": past}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_token(self, auth):
|
||||
future = int((datetime.datetime.now() + datetime.timedelta(hours=1)).timestamp())
|
||||
assert auth.is_token_expired({"expiry": future}) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_within_buffer(self, auth):
|
||||
almost_expired = int((datetime.datetime.now() + datetime.timedelta(seconds=30)).timestamp())
|
||||
assert auth.is_token_expired({"expiry": almost_expired}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_token_info(self, auth):
|
||||
assert auth.is_token_expired(None) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_expiry(self, auth):
|
||||
assert auth.is_token_expired({}) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_expiry(self, auth):
|
||||
assert auth.is_token_expired({"expiry": None}) is True
|
||||
|
||||
|
||||
class TestSanitizeTokenInfo:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_includes_allows_shared_content(self, auth):
|
||||
token_info = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"token_uri": "https://uri",
|
||||
"expiry": 123,
|
||||
"allows_shared_content": True,
|
||||
}
|
||||
result = auth.sanitize_token_info(token_info)
|
||||
assert result["allows_shared_content"] is True
|
||||
assert result["access_token"] == "at"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_defaults_allows_shared_content_to_false(self, auth):
|
||||
token_info = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"token_uri": "https://uri",
|
||||
"expiry": 123,
|
||||
}
|
||||
result = auth.sanitize_token_info(token_info)
|
||||
assert result["allows_shared_content"] is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_extra_fields(self, auth):
|
||||
token_info = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"token_uri": "https://uri",
|
||||
"expiry": 123,
|
||||
"allows_shared_content": True,
|
||||
}
|
||||
result = auth.sanitize_token_info(token_info, custom="val")
|
||||
assert result["custom"] == "val"
|
||||
|
||||
|
||||
class TestAllowsSharedContent:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_work_account_returns_true(self, auth):
|
||||
claims = {"tid": "some-work-tenant-id"}
|
||||
assert auth._allows_shared_content(claims) is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_personal_account_returns_false(self, auth):
|
||||
claims = {"tid": "9188040d-6c67-4c5b-b112-36a304b66dad"}
|
||||
assert auth._allows_shared_content(claims) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_tid_returns_false(self, auth):
|
||||
assert auth._allows_shared_content({"tid": ""}) is False
|
||||
assert auth._allows_shared_content({}) is False
|
||||
|
||||
|
||||
class TestMapTokenResponse:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_maps_all_fields(self, auth):
|
||||
result = {
|
||||
"access_token": "at",
|
||||
"refresh_token": "rt",
|
||||
"scope": ["Files.Read"],
|
||||
"id_token_claims": {
|
||||
"iss": "https://issuer",
|
||||
"exp": 1700000000,
|
||||
"tid": "work-tid",
|
||||
"name": "User Name",
|
||||
"preferred_username": "user@example.com",
|
||||
},
|
||||
}
|
||||
mapped = auth.map_token_response(result)
|
||||
assert mapped["access_token"] == "at"
|
||||
assert mapped["refresh_token"] == "rt"
|
||||
assert mapped["token_uri"] == "https://issuer"
|
||||
assert mapped["scopes"] == ["Files.Read"]
|
||||
assert mapped["expiry"] == 1700000000
|
||||
assert mapped["user_info"]["name"] == "User Name"
|
||||
assert mapped["user_info"]["email"] == "user@example.com"
|
||||
assert mapped["allows_shared_content"] is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_claims_uses_defaults(self, auth):
|
||||
result = {"access_token": "at", "refresh_token": "rt"}
|
||||
mapped = auth.map_token_response(result)
|
||||
assert mapped["token_uri"] is None
|
||||
assert mapped["expiry"] is None
|
||||
assert mapped["user_info"]["name"] is None
|
||||
assert mapped["allows_shared_content"] is False
|
||||
|
||||
|
||||
class TestGetTokenInfoFromSession:
|
||||
|
||||
def _mock_mongo(self, mock_settings, find_one_return):
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.find_one.return_value = find_one_return
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
mock_client = {mock_settings.MONGO_DB_NAME: mock_db}
|
||||
return mock_client
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_session(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {
|
||||
"session_token": "st",
|
||||
"token_info": {"access_token": "at", "refresh_token": "rt"},
|
||||
})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
result = auth.get_token_info_from_session("st")
|
||||
assert result["access_token"] == "at"
|
||||
assert "token_uri" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_session_not_found_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, None)
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve SharePoint token"):
|
||||
auth.get_token_info_from_session("bad")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_token_info_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {"session_token": "st"})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve SharePoint token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_token_info_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {"session_token": "st", "token_info": None})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve SharePoint token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_required_fields_raises(self, auth, mock_settings):
|
||||
mock_client = self._mock_mongo(mock_settings, {
|
||||
"session_token": "st",
|
||||
"token_info": {"access_token": "at"},
|
||||
})
|
||||
|
||||
with patch("application.core.mongo_db.MongoDB.get_client", return_value=mock_client), \
|
||||
patch("application.core.settings.settings", mock_settings):
|
||||
with pytest.raises(ValueError, match="Failed to retrieve SharePoint token"):
|
||||
auth.get_token_info_from_session("st")
|
||||
1036  tests/parser/connectors/test_share_point_loader.py  Normal file
File diff suppressed because it is too large
115  tests/test_app_routes.py  Normal file
@@ -0,0 +1,115 @@
"""Tests for application/app.py route handlers."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Import the Flask app with auth mocked to avoid JWT setup issues."""
|
||||
with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
|
||||
from application.app import app as flask_app
|
||||
flask_app.config["TESTING"] = True
|
||||
yield flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
class TestHomeRoute:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_root_returns_200(self, client):
|
||||
"""Root serves Swagger UI via Flask-RESTX."""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestHealthRoute:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_ok(self, client):
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data["status"] == "ok"
|
||||
|
||||
|
||||
class TestConfigRoute:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_auth_config(self, client):
|
||||
response = client.get("/api/config")
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert "auth_type" in data
|
||||
assert "requires_auth" in data
|
||||
|
||||
|
||||
class TestGenerateTokenRoute:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_session_jwt_generates_token(self, client, app):
|
||||
with patch("application.app.settings") as mock_settings:
|
||||
mock_settings.AUTH_TYPE = "session_jwt"
|
||||
mock_settings.JWT_SECRET_KEY = "test_secret"
|
||||
response = client.get("/api/generate_token")
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert "token" in data
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_non_session_jwt_returns_error(self, client, app):
|
||||
with patch("application.app.settings") as mock_settings:
|
||||
mock_settings.AUTH_TYPE = "none"
|
||||
response = client.get("/api/generate_token")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestSttRequestSizeLimits:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_non_stt_request_passes(self, client):
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_oversized_stt_request_rejected(self, client):
|
||||
with patch("application.app.should_reject_stt_request", return_value=True), \
|
||||
patch("application.app.build_stt_file_size_limit_message", return_value="Too large"):
|
||||
response = client.post("/api/stt/upload", data=b"x" * 100)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
class TestAuthenticateRequest:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_options_returns_200(self, client):
|
||||
response = client.options("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_auth_error_returns_401(self, client, app):
|
||||
with patch("application.app.handle_auth", return_value={"error": "Invalid token"}):
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_no_token_sets_none(self, client, app):
|
||||
with patch("application.app.handle_auth", return_value=None):
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAfterRequest:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cors_headers(self, client):
|
||||
response = client.get("/api/health")
|
||||
assert response.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "")
|
||||
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")
|
||||
279  tests/test_namespaces.py  Normal file
@@ -0,0 +1,279 @@
from datetime import datetime, timezone
from unittest.mock import patch

import pytest

from application.templates.namespaces import (
    NamespaceBuilder,
    NamespaceManager,
    PassthroughNamespace,
    SourceNamespace,
    SystemNamespace,
    ToolsNamespace,
)


# ── SystemNamespace ────────────────────────────────────────────────────────────


@pytest.mark.unit
class TestSystemNamespace:
    def test_namespace_name(self):
        ns = SystemNamespace()
        assert ns.namespace_name == "system"

    def test_build_returns_expected_keys(self):
        ns = SystemNamespace()
        result = ns.build()
        assert "date" in result
        assert "time" in result
        assert "timestamp" in result
        assert "request_id" in result
        assert "user_id" in result

    def test_build_with_request_id(self):
        ns = SystemNamespace()
        result = ns.build(request_id="req-123")
        assert result["request_id"] == "req-123"

    def test_build_with_user_id(self):
        ns = SystemNamespace()
        result = ns.build(user_id="user-456")
        assert result["user_id"] == "user-456"

    def test_build_generates_uuid_when_no_request_id(self):
        ns = SystemNamespace()
        result = ns.build()
        assert len(result["request_id"]) == 36  # UUID format

    def test_user_id_defaults_to_none(self):
        ns = SystemNamespace()
        result = ns.build()
        assert result["user_id"] is None

    def test_date_format(self):
        ns = SystemNamespace()
        fixed = datetime(2026, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
        with patch("application.templates.namespaces.datetime") as mock_dt:
            mock_dt.now.return_value = fixed
            mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw)
            result = ns.build()
            assert result["date"] == "2026-01-15"
            assert result["time"] == "10:30:45"

    def test_extra_kwargs_ignored(self):
        ns = SystemNamespace()
        result = ns.build(unknown_param="whatever")
        assert "date" in result


# ── PassthroughNamespace ───────────────────────────────────────────────────────


@pytest.mark.unit
class TestPassthroughNamespace:
    def test_namespace_name(self):
        ns = PassthroughNamespace()
        assert ns.namespace_name == "passthrough"

    def test_none_data_returns_empty(self):
        ns = PassthroughNamespace()
        assert ns.build(passthrough_data=None) == {}

    def test_no_data_returns_empty(self):
        ns = PassthroughNamespace()
        assert ns.build() == {}

    def test_safe_types_pass_through(self):
        ns = PassthroughNamespace()
        data = {"s": "string", "i": 42, "f": 3.14, "b": True, "n": None}
        result = ns.build(passthrough_data=data)
        assert result == data

    def test_non_serializable_types_filtered(self):
        ns = PassthroughNamespace()
        data = {"good": "ok", "bad": object()}
        result = ns.build(passthrough_data=data)
        assert result == {"good": "ok"}

    def test_list_value_filtered(self):
        ns = PassthroughNamespace()
        data = {"list_val": [1, 2, 3]}
        result = ns.build(passthrough_data=data)
        assert result == {}

    def test_dict_value_filtered(self):
        ns = PassthroughNamespace()
        data = {"nested": {"key": "val"}}
        result = ns.build(passthrough_data=data)
        assert result == {}


# ── SourceNamespace ────────────────────────────────────────────────────────────


@pytest.mark.unit
class TestSourceNamespace:
    def test_namespace_name(self):
        ns = SourceNamespace()
        assert ns.namespace_name == "source"

    def test_no_docs_returns_empty(self):
        ns = SourceNamespace()
        assert ns.build() == {}

    def test_with_docs(self):
        ns = SourceNamespace()
        docs = [{"text": "doc1"}, {"text": "doc2"}]
        result = ns.build(docs=docs)
        assert result["documents"] == docs
        assert result["count"] == 2

    def test_with_docs_together(self):
        ns = SourceNamespace()
        result = ns.build(docs_together="all content together")
        assert result["content"] == "all content together"
        assert result["docs_together"] == "all content together"
        assert result["summaries"] == "all content together"

    def test_with_both_docs_and_docs_together(self):
        ns = SourceNamespace()
        docs = [{"text": "doc1"}]
        result = ns.build(docs=docs, docs_together="combined")
        assert result["documents"] == docs
        assert result["count"] == 1
        assert result["content"] == "combined"

    def test_empty_docs_list_returns_empty(self):
        ns = SourceNamespace()
        result = ns.build(docs=[])
        assert "documents" not in result


# ── ToolsNamespace ─────────────────────────────────────────────────────────────


@pytest.mark.unit
class TestToolsNamespace:
    def test_namespace_name(self):
        ns = ToolsNamespace()
        assert ns.namespace_name == "tools"

    def test_none_data_returns_empty(self):
        ns = ToolsNamespace()
        assert ns.build(tools_data=None) == {}

    def test_no_data_returns_empty(self):
        ns = ToolsNamespace()
        assert ns.build() == {}

    def test_safe_types_pass_through(self):
        ns = ToolsNamespace()
        data = {
            "str_tool": "result",
            "dict_tool": {"key": "val"},
            "list_tool": [1, 2],
            "int_tool": 42,
            "float_tool": 3.14,
            "bool_tool": True,
            "none_tool": None,
        }
        result = ns.build(tools_data=data)
        assert result == data

    def test_non_serializable_filtered(self):
        ns = ToolsNamespace()
        data = {"good": "ok", "bad": object()}
        result = ns.build(tools_data=data)
        assert result == {"good": "ok"}


# ── NamespaceBuilder ABC ──────────────────────────────────────────────────────


@pytest.mark.unit
class TestNamespaceBuilderABC:
    def test_cannot_instantiate(self):
        with pytest.raises(TypeError):
            NamespaceBuilder()

    def test_subclass_must_implement_both(self):
        class Incomplete(NamespaceBuilder):
            pass

        with pytest.raises(TypeError):
            Incomplete()

    def test_concrete_subclass_works(self):
        class Complete(NamespaceBuilder):
            @property
            def namespace_name(self):
                return "test"

            def build(self, **kwargs):
                return {"ok": True}

        inst = Complete()
        assert inst.namespace_name == "test"
        assert inst.build() == {"ok": True}


# ── NamespaceManager ──────────────────────────────────────────────────────────


@pytest.mark.unit
class TestNamespaceManager:
    def test_build_context_contains_all_namespaces(self):
        mgr = NamespaceManager()
        ctx = mgr.build_context()
        assert "system" in ctx
        assert "passthrough" in ctx
        assert "source" in ctx
        assert "tools" in ctx

    def test_system_namespace_populated(self):
        mgr = NamespaceManager()
        ctx = mgr.build_context(request_id="r1", user_id="u1")
        assert ctx["system"]["request_id"] == "r1"
        assert ctx["system"]["user_id"] == "u1"

    def test_passthrough_namespace_populated(self):
        mgr = NamespaceManager()
        ctx = mgr.build_context(passthrough_data={"key": "val"})
        assert ctx["passthrough"] == {"key": "val"}

    def test_source_namespace_populated(self):
        mgr = NamespaceManager()
        docs = [{"text": "doc"}]
        ctx = mgr.build_context(docs=docs)
        assert ctx["source"]["count"] == 1

    def test_tools_namespace_populated(self):
        mgr = NamespaceManager()
        ctx = mgr.build_context(tools_data={"search": "results"})
        assert ctx["tools"] == {"search": "results"}

    def test_empty_kwargs_all_namespaces_present(self):
        mgr = NamespaceManager()
        ctx = mgr.build_context()
        for ns in ["system", "passthrough", "source", "tools"]:
            assert ns in ctx
            assert isinstance(ctx[ns], dict)

    def test_builder_exception_returns_empty_namespace(self):
        mgr = NamespaceManager()
        with patch.object(
            mgr._builders["system"], "build", side_effect=RuntimeError("boom")
        ):
            ctx = mgr.build_context()
            assert ctx["system"] == {}
            assert "passthrough" in ctx

    def test_get_builder_existing(self):
        mgr = NamespaceManager()
        builder = mgr.get_builder("system")
        assert isinstance(builder, SystemNamespace)

    def test_get_builder_nonexistent(self):
        mgr = NamespaceManager()
        assert mgr.get_builder("nonexistent") is None
353  tests/test_retriever.py  Normal file
@@ -0,0 +1,353 @@
from unittest.mock import MagicMock, Mock, patch

import pytest

from application.retriever.base import BaseRetriever
from application.retriever.retriever_creator import RetrieverCreator


# ── BaseRetriever ──────────────────────────────────────────────────────────────


@pytest.mark.unit
class TestBaseRetriever:
    def test_cannot_instantiate_directly(self):
        with pytest.raises(TypeError):
            BaseRetriever()

    def test_subclass_must_implement_search(self):
        class Incomplete(BaseRetriever):
            pass

        with pytest.raises(TypeError):
            Incomplete()

    def test_concrete_subclass_works(self):
        class Concrete(BaseRetriever):
            def search(self, *args, **kwargs):
                return "ok"

        instance = Concrete()
        assert instance.search() == "ok"


# ── RetrieverCreator ───────────────────────────────────────────────────────────


@pytest.mark.unit
class TestRetrieverCreator:
    def test_create_classic(self):
        mock_cls = Mock(return_value="rag_instance")
        original = RetrieverCreator.retrievers.copy()
        RetrieverCreator.retrievers["classic"] = mock_cls
        try:
            result = RetrieverCreator.create_retriever("classic", "arg1", key="val")
            mock_cls.assert_called_once_with("arg1", key="val")
            assert result == "rag_instance"
        finally:
            RetrieverCreator.retrievers.update(original)

    def test_create_default(self):
        mock_cls = Mock(return_value="rag_instance")
        original = RetrieverCreator.retrievers.copy()
        RetrieverCreator.retrievers["default"] = mock_cls
        try:
            result = RetrieverCreator.create_retriever("default")
            mock_cls.assert_called_once_with()
            assert result == "rag_instance"
        finally:
            RetrieverCreator.retrievers.update(original)

    def test_create_none_type_uses_default(self):
        mock_cls = Mock(return_value="rag_instance")
        original = RetrieverCreator.retrievers.copy()
        RetrieverCreator.retrievers["default"] = mock_cls
        try:
            result = RetrieverCreator.create_retriever(None)
            mock_cls.assert_called_once()
            assert result == "rag_instance"
        finally:
            RetrieverCreator.retrievers.update(original)

    def test_case_insensitive(self):
        mock_cls = Mock(return_value="rag_instance")
        original = RetrieverCreator.retrievers.copy()
        RetrieverCreator.retrievers["classic"] = mock_cls
        try:
            RetrieverCreator.create_retriever("CLASSIC")
            mock_cls.assert_called_once()
        finally:
            RetrieverCreator.retrievers.update(original)

    def test_invalid_type_raises(self):
        with pytest.raises(ValueError, match="No retievers class found"):
            RetrieverCreator.create_retriever("nonexistent")


# ── ClassicRAG ─────────────────────────────────────────────────────────────────


@pytest.fixture
def _patch_llm_creator(mock_llm, monkeypatch):
    """Patch LLMCreator.create_llm to return the shared mock_llm fixture."""
    monkeypatch.setattr(
        "application.retriever.classic_rag.LLMCreator.create_llm",
        Mock(return_value=mock_llm),
    )
    return mock_llm


def _make_rag(source=None, _patch_llm_creator=None, **overrides):
    """Helper – builds a ClassicRAG with sensible defaults."""
    from application.retriever.classic_rag import ClassicRAG

    defaults = dict(
        source=source or {"question": "hello"},
        chat_history=None,
        prompt="",
        chunks=2,
        doc_token_limit=50000,
        model_id="test-model",
        user_api_key=None,
        agent_id=None,
        llm_name="openai",
        api_key="fake",
        decoded_token={"sub": "user1"},
    )
    defaults.update(overrides)
    return ClassicRAG(**defaults)


@pytest.mark.unit
class TestClassicRAGInit:
    def test_basic_init(self, _patch_llm_creator):
        rag = _make_rag()
        assert rag.original_question == "hello"
        assert rag.chunks == 2
        assert rag.vectorstores == []

    def test_active_docs_as_list(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": ["a", "b"]})
        assert rag.vectorstores == ["a", "b"]

    def test_active_docs_as_string(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": "single"})
        assert rag.vectorstores == ["single"]

    def test_active_docs_none(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": None})
        assert rag.vectorstores == []

    def test_chunks_string_converted(self, _patch_llm_creator):
        rag = _make_rag(chunks="5")
        assert rag.chunks == 5

    def test_chunks_invalid_string_defaults(self, _patch_llm_creator):
        rag = _make_rag(chunks="abc")
        assert rag.chunks == 2

    def test_decoded_token_none(self, _patch_llm_creator):
        rag = _make_rag(decoded_token=None)
        assert rag.decoded_token is None


@pytest.mark.unit
class TestClassicRAGValidateVectorstore:
    def test_removes_empty_ids(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q", "active_docs": ["ok", "", " ", "good"]})
        assert rag.vectorstores == ["ok", "good"]

    def test_empty_vectorstores_no_error(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q"})
        assert rag.vectorstores == []


@pytest.mark.unit
class TestClassicRAGRephraseQuery:
    def test_no_history_returns_original(self, _patch_llm_creator):
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[],
        )
        assert rag.question == "original"

    def test_no_vectorstores_returns_original(self, _patch_llm_creator):
        rag = _make_rag(
            source={"question": "original"},
            chat_history=[{"prompt": "hi", "response": "hello"}],
        )
        assert rag.question == "original"

    def test_chunks_zero_returns_original(self, _patch_llm_creator):
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[{"prompt": "hi", "response": "hello"}],
            chunks=0,
        )
        assert rag.question == "original"

    def test_rephrase_called_with_history(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(return_value="rephrased question")
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[{"prompt": "hi", "response": "hello"}],
        )
        assert rag.question == "rephrased question"
        mock_llm.gen.assert_called_once()

    def test_rephrase_llm_returns_empty_falls_back(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(return_value="")
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[{"prompt": "hi", "response": "hello"}],
        )
        assert rag.question == "original"

    def test_rephrase_llm_exception_falls_back(self, _patch_llm_creator, mock_llm):
        mock_llm.gen = Mock(side_effect=RuntimeError("boom"))
        rag = _make_rag(
            source={"question": "original", "active_docs": ["vs1"]},
            chat_history=[{"prompt": "hi", "response": "hello"}],
        )
        assert rag.question == "original"


@pytest.mark.unit
class TestClassicRAGGetData:
    def test_chunks_zero_returns_empty(self, _patch_llm_creator):
        rag = _make_rag(chunks=0)
        assert rag._get_data() == []

    def test_no_vectorstores_returns_empty(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q"})
        assert rag._get_data() == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_returns_docs_with_metadata(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content here"
        mock_doc.metadata = {
            "title": "path/to/Title",
            "filename": "/docs/file.txt",
            "source": "http://example.com",
        }
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()

        assert len(docs) == 1
        assert docs[0]["text"] == "content here"
        assert docs[0]["title"] == "Title"
        assert docs[0]["filename"] == "file.txt"
        assert docs[0]["source"] == "http://example.com"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_dict_style_docs(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_docsearch.search.return_value = [
            {"text": "dict content", "metadata": {"title": "Dict Title"}}
        ]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()

        assert len(docs) == 1
        assert docs[0]["text"] == "dict content"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=100000)
    def test_token_budget_respected(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "big content"
        mock_doc.metadata = {"title": "t"}
        mock_docsearch.search.return_value = [mock_doc, mock_doc, mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(
            source={"question": "q", "active_docs": ["vs1"]},
            doc_token_limit=100,
        )
        docs = rag._get_data()
        # tokens (100000) exceed budget (90), so no docs should be added
        assert len(docs) == 0

    @patch("application.retriever.classic_rag.VectorCreator")
    def test_vectorstore_error_continues(self, mock_vc, _patch_llm_creator):
        mock_vc.create_vectorstore.side_effect = RuntimeError("connection failed")

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs == []

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_multiple_vectorstores(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": "t", "source": "s"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1", "vs2"]})
        docs = rag._get_data()
        assert len(docs) == 2

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_doc_missing_filename_uses_title(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": "MyTitle"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs[0]["filename"] == "MyTitle"

    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_non_string_title_converted(self, mock_tokens, mock_vc, _patch_llm_creator):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "content"
        mock_doc.metadata = {"title": 42}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch

        rag = _make_rag(source={"question": "q", "active_docs": ["vs1"]})
        docs = rag._get_data()
        assert docs[0]["title"] == "42"


@pytest.mark.unit
class TestClassicRAGSearch:
    @patch("application.retriever.classic_rag.VectorCreator")
    @patch("application.retriever.classic_rag.num_tokens_from_string", return_value=10)
    def test_search_with_query_override(self, mock_tokens, mock_vc, _patch_llm_creator, mock_llm):
        mock_docsearch = MagicMock()
        mock_doc = MagicMock()
        mock_doc.page_content = "result"
        mock_doc.metadata = {"title": "t"}
        mock_docsearch.search.return_value = [mock_doc]
        mock_vc.create_vectorstore.return_value = mock_docsearch
        mock_llm.gen = Mock(return_value="")

        rag = _make_rag(source={"question": "original", "active_docs": ["vs1"]})
        docs = rag.search(query="override query")
        assert rag.original_question == "override query"
        assert len(docs) == 1

    def test_search_without_query_uses_default(self, _patch_llm_creator):
        rag = _make_rag(source={"question": "q"})
        docs = rag.search()
        assert docs == []
tests/test_seeder.py (new file, 492 lines)
@@ -0,0 +1,492 @@
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
from application.seed.seeder import DatabaseSeeder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
client = mongomock.MongoClient()
|
||||
return client["test_docsgpt"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seeder(mock_db):
|
||||
return DatabaseSeeder(mock_db)
|
||||
|
||||
|
||||
# ── __init__ ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDatabaseSeederInit:
|
||||
def test_collections_set(self, seeder, mock_db):
|
||||
assert seeder.db is mock_db
|
||||
assert seeder.tools_collection == mock_db["user_tools"]
|
||||
assert seeder.sources_collection == mock_db["sources"]
|
||||
assert seeder.agents_collection == mock_db["agents"]
|
||||
assert seeder.prompts_collection == mock_db["prompts"]
|
||||
assert seeder.system_user_id == "system"
|
||||
|
||||
|
||||
# ── _is_already_seeded ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestIsAlreadySeeded:
|
||||
def test_not_seeded(self, seeder):
|
||||
assert seeder._is_already_seeded() is False
|
||||
|
||||
def test_already_seeded(self, seeder, mock_db):
|
||||
mock_db["agents"].insert_one({"user": "system", "name": "test"})
|
||||
assert seeder._is_already_seeded() is True
|
||||
|
||||
|
||||
# ── _process_config ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessConfig:
|
||||
def test_env_var_substitution(self, seeder, monkeypatch):
|
||||
monkeypatch.setenv("MY_SECRET", "secret_value")
|
||||
result = seeder._process_config({"key": "${MY_SECRET}"})
|
||||
assert result["key"] == "secret_value"
|
||||
|
||||
def test_missing_env_var_defaults_empty(self, seeder, monkeypatch):
|
||||
monkeypatch.delenv("NONEXISTENT_VAR", raising=False)
|
||||
result = seeder._process_config({"key": "${NONEXISTENT_VAR}"})
|
||||
assert result["key"] == ""
|
||||
|
||||
def test_non_env_value_unchanged(self, seeder):
|
||||
result = seeder._process_config({"key": "plain_value", "num": 42})
|
||||
assert result == {"key": "plain_value", "num": 42}
|
||||
|
||||
def test_partial_env_syntax_unchanged(self, seeder):
|
||||
result = seeder._process_config({"key": "${INCOMPLETE"})
|
||||
assert result["key"] == "${INCOMPLETE"
|
||||
|
||||
def test_empty_config(self, seeder):
|
||||
assert seeder._process_config({}) == {}
|
||||
|
||||
|
||||
# ── _handle_prompt ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandlePrompt:
|
||||
def test_no_prompt_returns_none(self, seeder):
|
||||
assert seeder._handle_prompt({"name": "agent1"}) is None
|
||||
|
||||
def test_empty_content_returns_none(self, seeder):
|
||||
config = {"name": "agent1", "prompt": {"name": "p", "content": ""}}
|
||||
assert seeder._handle_prompt(config) is None
|
||||
|
||||
def test_creates_prompt(self, seeder, mock_db):
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"prompt": {"name": "My Prompt", "content": "You are helpful."},
|
||||
}
|
||||
result = seeder._handle_prompt(config)
|
||||
assert result is not None
|
||||
doc = mock_db["prompts"].find_one({"name": "My Prompt"})
|
||||
assert doc is not None
|
||||
assert doc["content"] == "You are helpful."
|
||||
assert doc["user"] == "system"
|
||||
|
||||
def test_duplicate_prompt_returns_existing(self, seeder, mock_db):
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"prompt": {"name": "Dup Prompt", "content": "content"},
|
||||
}
|
||||
id1 = seeder._handle_prompt(config)
|
||||
id2 = seeder._handle_prompt(config)
|
||||
assert id1 == id2
|
||||
assert mock_db["prompts"].count_documents({"name": "Dup Prompt"}) == 1
|
||||
|
||||
def test_default_prompt_name(self, seeder, mock_db):
|
||||
config = {"name": "agent1", "prompt": {"content": "hello"}}
|
||||
seeder._handle_prompt(config)
|
||||
doc = mock_db["prompts"].find_one({"name": "agent1 Prompt"})
|
||||
assert doc is not None
|
||||
|
||||
def test_exception_returns_none(self, seeder):
|
||||
with patch.object(
|
||||
seeder.prompts_collection, "find_one", side_effect=RuntimeError("db error")
|
||||
):
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"prompt": {"name": "p", "content": "c"},
|
||||
}
|
||||
assert seeder._handle_prompt(config) is None
|
||||
|
||||
|
||||
# ── _handle_tools ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleTools:
|
||||
def test_no_tools_returns_empty(self, seeder):
|
||||
assert seeder._handle_tools({"name": "agent1"}) == []
|
||||
|
||||
@patch("application.seed.seeder.tool_manager")
|
||||
def test_creates_tool(self, mock_tm, seeder, mock_db):
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.get_actions_metadata.return_value = [{"name": "act1"}]
|
||||
mock_tm.tools = {"my_tool": mock_tool}
|
||||
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"tools": [{"name": "my_tool", "description": "desc"}],
|
||||
}
|
||||
ids = seeder._handle_tools(config)
|
||||
assert len(ids) == 1
|
||||
doc = mock_db["user_tools"].find_one({"name": "my_tool"})
|
||||
assert doc is not None
|
||||
assert doc["user"] == "system"
|
||||
|
||||
@patch("application.seed.seeder.tool_manager")
|
||||
def test_duplicate_tool_returns_existing(self, mock_tm, seeder, mock_db):
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.get_actions_metadata.return_value = []
|
||||
mock_tm.tools = {"my_tool": mock_tool}
|
||||
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"tools": [{"name": "my_tool"}],
|
||||
}
|
||||
ids1 = seeder._handle_tools(config)
|
||||
ids2 = seeder._handle_tools(config)
|
||||
assert ids1 == ids2
|
||||
assert mock_db["user_tools"].count_documents({"name": "my_tool"}) == 1
|
||||
|
||||
def test_tool_exception_continues(self, seeder):
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"tools": [{"name": "broken_tool"}],
|
||||
}
|
||||
# tool_manager.tools will KeyError on "broken_tool"
|
||||
ids = seeder._handle_tools(config)
|
||||
assert ids == []
|
||||
|
||||
@patch("application.seed.seeder.tool_manager")
|
||||
def test_tool_config_env_expansion(self, mock_tm, seeder, monkeypatch):
|
||||
monkeypatch.setenv("TOOL_KEY", "expanded_val")
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.get_actions_metadata.return_value = []
|
||||
mock_tm.tools = {"my_tool": mock_tool}
|
||||
|
||||
config = {
|
||||
"name": "agent1",
|
||||
"tools": [{"name": "my_tool", "config": {"api_key": "${TOOL_KEY}"}}],
|
||||
}
|
||||
seeder._handle_tools(config)
|
||||
doc = seeder.tools_collection.find_one({"name": "my_tool"})
|
||||
assert doc["config"]["api_key"] == "expanded_val"
|
||||
|
||||
|
||||
# ── _handle_source ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleSource:
|
||||
def test_no_source_returns_none(self, seeder):
|
||||
assert seeder._handle_source({"name": "a"}) is None
|
||||
|
||||
def test_existing_source_returns_id(self, seeder, mock_db):
|
||||
inserted = mock_db["sources"].insert_one(
|
||||
{"user": "system", "remote_data": "http://example.com"}
|
||||
)
|
||||
config = {
|
||||
"name": "a",
|
||||
"source": {"url": "http://example.com", "name": "src"},
|
||||
}
|
||||
result = seeder._handle_source(config)
|
||||
assert result == inserted.inserted_id
|
||||
|
||||
@patch("application.seed.seeder.ingest_remote")
|
||||
def test_new_source_ingestion(self, mock_ingest, seeder):
|
||||
mock_task = MagicMock()
|
||||
mock_task.get.return_value = {"id": "new_source_id"}
|
||||
mock_task.successful.return_value = True
|
||||
mock_ingest.delay.return_value = mock_task
|
||||
|
||||
config = {
|
||||
"name": "a",
|
||||
"source": {"url": "http://new.com", "name": "new_src", "loader": "web"},
|
||||
}
|
||||
result = seeder._handle_source(config)
|
||||
assert result == "new_source_id"
|
||||
mock_ingest.delay.assert_called_once_with(
|
||||
source_data="http://new.com",
|
||||
job_name="new_src",
|
||||
user="system",
|
||||
loader="web",
|
||||
)
|
||||
|
||||
@patch("application.seed.seeder.ingest_remote")
|
||||
def test_source_ingestion_failure_returns_false(self, mock_ingest, seeder):
|
||||
mock_task = MagicMock()
|
||||
mock_task.get.side_effect = RuntimeError("timeout")
|
||||
mock_ingest.delay.return_value = mock_task
|
||||
|
||||
config = {
|
||||
"name": "a",
|
||||
"source": {"url": "http://fail.com", "name": "fail_src"},
|
||||
}
|
||||
result = seeder._handle_source(config)
|
||||
assert result is False
|
||||
|
||||
@patch("application.seed.seeder.ingest_remote")
|
||||
def test_source_missing_id_returns_false(self, mock_ingest, seeder):
|
||||
mock_task = MagicMock()
|
||||
mock_task.get.return_value = {"no_id_key": True}
|
||||
mock_task.successful.return_value = True
|
||||
mock_ingest.delay.return_value = mock_task
|
||||
|
||||
config = {
|
||||
"name": "a",
|
||||
"source": {"url": "http://bad.com", "name": "bad_src"},
|
||||
}
|
||||
result = seeder._handle_source(config)
|
||||
assert result is False
|
||||
|
||||
@patch("application.seed.seeder.ingest_remote")
|
||||
def test_default_loader(self, mock_ingest, seeder):
|
||||
mock_task = MagicMock()
|
||||
mock_task.get.return_value = {"id": "sid"}
|
||||
mock_task.successful.return_value = True
|
||||
mock_ingest.delay.return_value = mock_task
|
||||
|
||||
config = {
|
||||
"name": "a",
|
||||
"source": {"url": "http://x.com", "name": "s"},
|
||||
}
|
||||
seeder._handle_source(config)
|
||||
call_kwargs = mock_ingest.delay.call_args[1]
|
||||
assert call_kwargs["loader"] == "url"
|
||||
|
||||
|
||||
# ── seed_initial_data ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSeedInitialData:
|
||||
def test_already_seeded_skips(self, seeder, mock_db):
|
||||
mock_db["agents"].insert_one({"user": "system", "name": "existing"})
|
||||
with patch.object(seeder, "_seed_from_config") as mock_seed:
|
||||
seeder.seed_initial_data()
|
||||
mock_seed.assert_not_called()
|
||||
|
||||
def test_force_reseeds(self, seeder, mock_db):
|
||||
mock_db["agents"].insert_one({"user": "system", "name": "existing"})
|
||||
yaml_content = "agents: []"
|
||||
with patch("builtins.open", mock_open(read_data=yaml_content)):
|
||||
with patch.object(seeder, "_seed_from_config") as mock_seed:
|
||||
seeder.seed_initial_data(force=True)
|
||||
mock_seed.assert_called_once()
|
||||
|
||||
def test_config_file_not_found_raises(self, seeder):
|
||||
with pytest.raises(Exception):
|
||||
seeder.seed_initial_data(config_path="/nonexistent/path.yaml")
|
||||
|
||||
def test_custom_config_path(self, seeder):
|
||||
yaml_content = "agents:\n - name: test_agent"
|
||||
with patch("builtins.open", mock_open(read_data=yaml_content)):
|
||||
with patch.object(seeder, "_seed_from_config") as mock_seed:
|
||||
seeder.seed_initial_data(config_path="/custom/path.yaml")
|
||||
mock_seed.assert_called_once()
|
||||
|
||||
|
||||
# ── _seed_from_config ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSeedFromConfig:
|
||||
def test_no_agents_in_config(self, seeder):
|
||||
seeder._seed_from_config({})
|
||||
assert seeder.agents_collection.count_documents({}) == 0
|
||||
|
||||
def test_empty_agents_list(self, seeder):
|
||||
seeder._seed_from_config({"agents": []})
|
||||
assert seeder.agents_collection.count_documents({}) == 0
|
||||
|
||||
@patch.object(DatabaseSeeder, "_handle_source", return_value=None)
|
||||
@patch.object(DatabaseSeeder, "_handle_tools", return_value=[])
|
||||
@patch.object(DatabaseSeeder, "_handle_prompt", return_value=None)
|
||||
def test_creates_agent(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
|
||||
config = {
|
||||
"agents": [
|
||||
{
|
||||
"name": "TestAgent",
|
||||
"description": "A test agent",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
]
|
||||
}
|
||||
seeder._seed_from_config(config)
|
||||
agent = mock_db["agents"].find_one({"name": "TestAgent"})
|
||||
assert agent is not None
|
||||
assert agent["user"] == "system"
|
||||
assert agent["agent_type"] == "classic"
|
||||
assert agent["status"] == "template"
|
||||
|
||||
@patch.object(DatabaseSeeder, "_handle_source", return_value=None)
|
||||
@patch.object(DatabaseSeeder, "_handle_tools", return_value=[])
|
||||
@patch.object(DatabaseSeeder, "_handle_prompt", return_value=None)
|
||||
def test_updates_existing_agent(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
|
||||
mock_db["agents"].insert_one(
|
||||
{"user": "system", "name": "TestAgent", "description": "old"}
|
||||
)
|
||||
config = {
|
||||
"agents": [
|
||||
{
|
||||
"name": "TestAgent",
|
||||
"description": "updated",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
]
|
||||
}
|
||||
seeder._seed_from_config(config)
|
||||
assert mock_db["agents"].count_documents({"name": "TestAgent"}) == 1
|
||||
agent = mock_db["agents"].find_one({"name": "TestAgent"})
|
||||
assert agent["description"] == "updated"
|
||||
|
||||
@patch.object(DatabaseSeeder, "_handle_source", return_value=False)
|
||||
def test_source_failure_skips_agent(self, mock_source, seeder, mock_db):
|
||||
config = {
|
||||
"agents": [
|
||||
{
|
||||
"name": "SkippedAgent",
|
||||
"description": "skip",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
]
|
||||
}
|
||||
seeder._seed_from_config(config)
|
||||
assert mock_db["agents"].count_documents({"name": "SkippedAgent"}) == 0
|
||||
|
||||
@patch.object(DatabaseSeeder, "_handle_source", side_effect=KeyError("name"))
|
||||
def test_agent_exception_continues(self, mock_source, seeder, mock_db):
|
||||
config = {
|
||||
"agents": [
|
||||
{"name": "Bad", "description": "x", "agent_type": "y"},
|
||||
{"name": "Good", "description": "x", "agent_type": "y"},
|
||||
]
|
||||
}
|
||||
with patch.object(seeder, "_handle_tools", return_value=[]):
|
||||
with patch.object(seeder, "_handle_prompt", return_value=None):
|
||||
seeder._seed_from_config(config)
|
||||
# Both agents should be attempted; first errors, second might too
|
||||
# Main assertion: no unhandled exception
|
||||
|
||||
@patch.object(DatabaseSeeder, "_handle_source", return_value=None)
|
||||
@patch.object(DatabaseSeeder, "_handle_tools")
|
||||
@patch.object(DatabaseSeeder, "_handle_prompt", return_value="prompt_id_123")
|
||||
def test_agent_with_source_and_tools(self, mock_prompt, mock_tools, mock_source, seeder, mock_db):
|
||||
tool_id = ObjectId()
|
||||
mock_tools.return_value = [tool_id]
|
||||
|
||||
source_id = ObjectId()
|
||||
mock_source.return_value = source_id
|
||||
|
||||
config = {
|
||||
"agents": [
|
||||
{
|
||||
"name": "FullAgent",
|
||||
"description": "full",
|
||||
"agent_type": "classic",
|
||||
"chunks": "5",
|
||||
"retriever": "classic",
|
||||
"image": "img.png",
|
||||
}
|
||||
]
|
||||
}
|
||||
seeder._seed_from_config(config)
|
||||
agent = mock_db["agents"].find_one({"name": "FullAgent"})
|
||||
assert agent is not None
|
||||
assert agent["prompt_id"] == "prompt_id_123"
|
||||
assert str(tool_id) in agent["tools"]
|
||||
assert agent["chunks"] == "5"
|
||||
assert agent["image"] == "img.png"
|
||||
|
||||
|
||||
# ── initialize_from_env ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInitializeFromEnv:
|
||||
@patch("application.seed.seeder.MongoClient")
|
||||
def test_creates_seeder_from_env(self, mock_client_cls, monkeypatch):
|
||||
monkeypatch.setenv("MONGO_URI", "mongodb://test:27017")
|
||||
monkeypatch.setenv("MONGO_DB_NAME", "testdb")
|
||||
|
||||
mock_db = mongomock.MongoClient()["testdb"]
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
seeder = DatabaseSeeder.initialize_from_env()
|
||||
mock_client_cls.assert_called_once_with("mongodb://test:27017")
|
||||
assert isinstance(seeder, DatabaseSeeder)
|
||||
|
||||
@patch("application.seed.seeder.MongoClient")
|
||||
def test_default_env_values(self, mock_client_cls, monkeypatch):
|
||||
monkeypatch.delenv("MONGO_URI", raising=False)
|
||||
monkeypatch.delenv("MONGO_DB_NAME", raising=False)
|
||||
|
||||
mock_db = mongomock.MongoClient()["docsgpt"]
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value=mock_db)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
DatabaseSeeder.initialize_from_env()
|
||||
mock_client_cls.assert_called_once_with("mongodb://localhost:27017")
|
||||
|
||||
|
||||
# ── seed CLI commands ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSeedCommands:
|
||||
@patch("application.seed.commands.DatabaseSeeder")
|
||||
@patch("application.seed.commands.MongoDB")
|
||||
@patch("application.seed.commands.settings")
|
||||
def test_init_command(self, mock_settings, mock_mongodb, mock_seeder_cls):
|
||||
from click.testing import CliRunner
|
||||
from application.seed.commands import seed
|
||||
|
||||
mock_settings.MONGO_DB_NAME = "testdb"
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value="mock_db")
|
||||
mock_mongodb.get_client.return_value = mock_client
|
||||
|
||||
mock_seeder = MagicMock()
|
||||
mock_seeder_cls.return_value = mock_seeder
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(seed, ["init"])
|
||||
assert result.exit_code == 0
|
||||
mock_seeder.seed_initial_data.assert_called_once_with(force=False)
|
||||
|
||||
@patch("application.seed.commands.DatabaseSeeder")
|
||||
@patch("application.seed.commands.MongoDB")
|
||||
@patch("application.seed.commands.settings")
|
||||
def test_init_command_with_force(self, mock_settings, mock_mongodb, mock_seeder_cls):
|
||||
from click.testing import CliRunner
|
||||
from application.seed.commands import seed
|
||||
|
||||
mock_settings.MONGO_DB_NAME = "testdb"
|
||||
mock_client = MagicMock()
|
||||
mock_client.__getitem__ = MagicMock(return_value="mock_db")
|
||||
mock_mongodb.get_client.return_value = mock_client
|
||||
|
||||
mock_seeder = MagicMock()
|
||||
mock_seeder_cls.return_value = mock_seeder
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(seed, ["init", "--force"])
|
||||
assert result.exit_code == 0
|
||||
mock_seeder.seed_initial_data.assert_called_once_with(force=True)
|
||||
tests/test_template_engine.py (new file, 186 lines)
@@ -0,0 +1,186 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
return TemplateEngine()
|
||||
|
||||
|
||||
# ── render ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRender:
|
||||
def test_simple_variable(self, engine):
|
||||
result = engine.render("Hello {{ name }}", {"name": "World"})
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_empty_template_returns_empty(self, engine):
|
||||
assert engine.render("", {"x": 1}) == ""
|
||||
|
||||
def test_none_template_returns_empty(self, engine):
|
||||
assert engine.render(None, {"x": 1}) == ""
|
||||
|
||||
def test_no_variables(self, engine):
|
||||
assert engine.render("plain text", {}) == "plain text"
|
||||
|
||||
def test_multiple_variables(self, engine):
|
||||
tpl = "{{ a }} and {{ b }}"
|
||||
assert engine.render(tpl, {"a": "X", "b": "Y"}) == "X and Y"
|
||||
|
||||
def test_nested_dict_access(self, engine):
|
||||
tpl = "{{ data.key }}"
|
||||
assert engine.render(tpl, {"data": {"key": "value"}}) == "value"
|
||||
|
||||
def test_loop(self, engine):
|
||||
tpl = "{% for i in items %}{{ i }} {% endfor %}"
|
||||
result = engine.render(tpl, {"items": ["a", "b", "c"]})
|
||||
assert result.strip() == "a b c"
|
||||
|
||||
def test_conditional(self, engine):
|
||||
tpl = "{% if show %}yes{% else %}no{% endif %}"
|
||||
assert engine.render(tpl, {"show": True}) == "yes"
|
||||
assert engine.render(tpl, {"show": False}) == "no"
|
||||
|
||||
def test_syntax_error_raises_template_render_error(self, engine):
|
||||
with pytest.raises(TemplateRenderError, match="syntax error"):
|
||||
engine.render("{% if %}", {})
|
||||
|
||||
def test_undefined_variable_chainable(self, engine):
|
||||
# ChainableUndefined should NOT raise; it silently returns empty
|
||||
result = engine.render("{{ missing }}", {})
|
||||
assert result == ""
|
||||
|
||||
def test_autoescape_html(self, engine):
|
||||
result = engine.render("{{ content }}", {"content": "<script>alert(1)</script>"})
|
||||
assert "<script>" not in result
|
||||
assert "<script>" in result
|
||||
|
||||
def test_trim_blocks(self, engine):
|
||||
tpl = "{% if True %}\nyes\n{% endif %}"
|
||||
result = engine.render(tpl, {})
|
||||
assert result.strip() == "yes"
|
||||
|
||||
def test_general_exception_raises_template_render_error(self, engine):
|
||||
with patch.object(engine._env, "from_string", side_effect=RuntimeError("boom")):
|
||||
with pytest.raises(TemplateRenderError, match="rendering failed"):
|
||||
engine.render("{{ x }}", {"x": 1})
|
||||
|
||||
|
||||
# ── validate_template ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateTemplate:
|
||||
def test_valid_template(self, engine):
|
||||
assert engine.validate_template("{{ name }}") is True
|
||||
|
||||
def test_empty_template_valid(self, engine):
|
||||
assert engine.validate_template("") is True
|
||||
|
||||
def test_none_template_valid(self, engine):
|
||||
assert engine.validate_template(None) is True
|
||||
|
||||
def test_invalid_syntax(self, engine):
|
||||
assert engine.validate_template("{% if %}") is False
|
||||
|
||||
def test_plain_text_valid(self, engine):
|
||||
assert engine.validate_template("just plain text") is True
|
||||
|
||||
def test_complex_valid_template(self, engine):
|
||||
tpl = "{% for x in items %}{{ x.name }}{% endfor %}"
|
||||
assert engine.validate_template(tpl) is True
|
||||
|
||||
def test_general_exception_returns_false(self, engine):
|
||||
with patch.object(engine._env, "from_string", side_effect=RuntimeError("boom")):
|
||||
assert engine.validate_template("{{ x }}") is False
|
||||
|
||||
|
||||
# ── extract_variables ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractVariables:
|
||||
def test_empty_template(self, engine):
|
||||
assert engine.extract_variables("") == set()
|
||||
|
||||
def test_none_template(self, engine):
|
||||
assert engine.extract_variables(None) == set()
|
||||
|
||||
def test_syntax_error_returns_empty(self, engine):
|
||||
assert engine.extract_variables("{% if %}") == set()
|
||||
|
||||
def test_general_exception_returns_empty(self, engine):
|
||||
with patch.object(engine._env, "parse", side_effect=RuntimeError("boom")):
|
||||
assert engine.extract_variables("{{ x }}") == set()
|
||||
|
||||
|
||||
# ── extract_tool_usages ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractToolUsages:
|
||||
def test_empty_template(self, engine):
|
||||
assert engine.extract_tool_usages("") == {}
|
||||
|
||||
def test_none_template(self, engine):
|
||||
assert engine.extract_tool_usages(None) == {}
|
||||
|
||||
def test_syntax_error_returns_empty(self, engine):
|
||||
assert engine.extract_tool_usages("{% if %}") == {}
|
||||
|
||||
def test_getattr_single_tool(self, engine):
|
||||
tpl = "{{ tools.memory.notes }}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "memory" in result
|
||||
assert "notes" in result["memory"]
|
||||
|
||||
def test_getattr_tool_without_action(self, engine):
|
||||
tpl = "{{ tools.search }}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "search" in result
|
||||
assert None in result["search"]
|
||||
|
||||
def test_getitem_bracket_notation(self, engine):
|
||||
tpl = '{{ tools["calendar"]["events"] }}'
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "calendar" in result
|
||||
assert "events" in result["calendar"]
|
||||
|
||||
def test_multiple_tools(self, engine):
|
||||
tpl = "{{ tools.memory.notes }} {{ tools.search.query }}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "memory" in result
|
||||
assert "search" in result
|
||||
|
||||
def test_same_tool_multiple_actions(self, engine):
|
||||
tpl = "{{ tools.memory.notes }} {{ tools.memory.tasks }}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "memory" in result
|
||||
assert "notes" in result["memory"]
|
||||
assert "tasks" in result["memory"]
|
||||
|
||||
def test_non_tools_getattr_ignored(self, engine):
|
||||
tpl = "{{ data.something }}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert result == {}
|
||||
|
||||
def test_general_parse_error_returns_empty(self, engine):
|
||||
with patch.object(engine._env, "parse", side_effect=RuntimeError("boom")):
|
||||
assert engine.extract_tool_usages("{{ tools.x }}") == {}
|
||||
|
||||
def test_tools_in_loop(self, engine):
|
||||
tpl = "{% for item in tools.memory.notes %}{{ item }}{% endfor %}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "memory" in result
|
||||
assert "notes" in result["memory"]
|
||||
|
||||
def test_tools_in_conditional(self, engine):
|
||||
tpl = "{% if tools.search.results %}found{% endif %}"
|
||||
result = engine.extract_tool_usages(tpl)
|
||||
assert "search" in result
|
||||
assert "results" in result["search"]
|
||||
tests/test_utils.py (new file, 533 lines)
@@ -0,0 +1,533 @@
|
||||
"""Tests for application/utils.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.utils import (
|
||||
calculate_compression_threshold,
|
||||
calculate_doc_token_budget,
|
||||
check_required_fields,
|
||||
clean_text_for_tts,
|
||||
convert_pdf_to_images,
|
||||
get_encoding,
|
||||
get_field_validation_errors,
|
||||
get_gpt_model,
|
||||
get_hash,
|
||||
get_missing_fields,
|
||||
generate_image_url,
|
||||
limit_chat_history,
|
||||
num_tokens_from_object_or_list,
|
||||
num_tokens_from_string,
|
||||
safe_filename,
|
||||
validate_function_name,
|
||||
validate_required_fields,
|
||||
)
|
||||
|
||||
|
||||
class TestGetEncoding:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_encoding(self):
|
||||
enc = get_encoding()
|
||||
assert enc is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_same_instance(self):
|
||||
enc1 = get_encoding()
|
||||
enc2 = get_encoding()
|
||||
assert enc1 is enc2
|
||||
|
||||
|
||||
class TestGetGptModel:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_llm_name_when_set(self):
|
||||
with patch("application.utils.settings") as s:
|
||||
s.LLM_NAME = "my-model"
|
||||
s.LLM_PROVIDER = "openai"
|
||||
assert get_gpt_model() == "my-model"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_falls_back_to_provider_map(self):
|
||||
with patch("application.utils.settings") as s:
|
||||
s.LLM_NAME = ""
|
||||
s.LLM_PROVIDER = "openai"
|
||||
assert get_gpt_model() == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_provider_returns_empty(self):
|
||||
with patch("application.utils.settings") as s:
|
||||
s.LLM_NAME = ""
|
||||
s.LLM_PROVIDER = "unknown"
|
||||
assert get_gpt_model() == ""
|
||||
|
||||
|
||||
class TestSafeFilename:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_normal_filename(self):
|
||||
assert safe_filename("test.pdf") == "test.pdf"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_filename_returns_uuid(self):
|
||||
result = safe_filename("")
|
||||
assert len(result) > 10 # UUID
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_filename_returns_uuid(self):
|
||||
result = safe_filename(None)
|
||||
assert len(result) > 10
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_non_latin_filename(self):
|
||||
result = safe_filename("документ.pdf")
|
||||
assert result.endswith(".pdf")
|
||||
|
||||
|
||||
class TestNumTokens:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_token_count(self):
|
||||
count = num_tokens_from_string("hello world")
|
||||
assert count > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_non_string_returns_zero(self):
|
||||
assert num_tokens_from_string(123) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_string(self):
|
||||
assert num_tokens_from_string("") == 0
|
||||
|
||||
|
||||
class TestNumTokensFromObjectOrList:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_list(self):
|
||||
result = num_tokens_from_object_or_list(["hello", "world"])
|
||||
assert result > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_dict(self):
|
||||
result = num_tokens_from_object_or_list({"key": "value"})
|
||||
assert result > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string(self):
|
||||
result = num_tokens_from_object_or_list("hello")
|
||||
assert result > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_number_returns_zero(self):
|
||||
assert num_tokens_from_object_or_list(42) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_nested(self):
|
||||
result = num_tokens_from_object_or_list({"a": ["b", "c"]})
|
||||
assert result > 0
|
||||
|
||||
|
||||
class TestCountTokensDocs:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_counts_doc_tokens(self):
|
||||
from application.utils import count_tokens_docs
|
||||
doc1 = MagicMock()
|
||||
doc1.page_content = "hello world"
|
||||
doc2 = MagicMock()
|
||||
doc2.page_content = " foo bar"
|
||||
result = count_tokens_docs([doc1, doc2])
|
||||
assert result > 0
|
||||
|
||||
|
||||
class TestCalculateDocTokenBudget:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_budget(self):
|
||||
with patch("application.utils.get_token_limit", return_value=128000), \
|
||||
patch("application.utils.settings") as s:
|
||||
s.RESERVED_TOKENS = {"system": 500, "history": 500}
|
||||
result = calculate_doc_token_budget("gpt-4o")
|
||||
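# 128000-token context minus 500 reserved for the system prompt and 500 for history leaves 127000 for documents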
assert result == 127000
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_minimum_budget(self):
|
||||
with patch("application.utils.get_token_limit", return_value=1000), \
|
||||
patch("application.utils.settings") as s:
|
||||
s.RESERVED_TOKENS = {"system": 500, "history": 500}
|
||||
result = calculate_doc_token_budget("small-model")
|
||||
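# the reserved tokens consume the entire 1000-token limit, so the function falls back to the minimum budget asserted below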
assert result == 1000
|
||||
|
||||
|
||||
class TestFieldValidation:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_missing_fields(self):
|
||||
assert get_missing_fields({"a": 1}, ["a", "b"]) == ["b"]
|
||||
assert get_missing_fields({"a": 1, "b": 2}, ["a", "b"]) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_check_required_fields_pass(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = check_required_fields({"a": 1, "b": 2}, ["a", "b"])
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_check_required_fields_fail(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = check_required_fields({"a": 1}, ["a", "b"])
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_field_validation_errors_none_when_valid(self):
|
||||
assert get_field_validation_errors({"a": 1}, ["a"]) is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_field_validation_errors_missing(self):
|
||||
result = get_field_validation_errors({}, ["a"])
|
||||
assert result["missing_fields"] == ["a"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_field_validation_errors_empty(self):
|
||||
result = get_field_validation_errors({"a": ""}, ["a"])
|
||||
assert result["empty_fields"] == ["a"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_required_fields_pass(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = validate_required_fields({"a": "v"}, ["a"])
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_required_fields_missing(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = validate_required_fields({}, ["a"])
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_required_fields_empty(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = validate_required_fields({"a": ""}, ["a"])
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_validate_required_fields_both_missing_and_empty(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
result = validate_required_fields({"a": ""}, ["a", "b"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestGetHash:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_hex_string(self):
|
||||
h = get_hash("test")
|
||||
assert len(h) == 32
|
||||
assert all(c in "0123456789abcdef" for c in h)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deterministic(self):
|
||||
assert get_hash("hello") == get_hash("hello")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_different_inputs(self):
|
||||
assert get_hash("a") != get_hash("b")
|
||||
|
||||
|
||||
class TestLimitChatHistory:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_history(self):
|
||||
assert limit_chat_history([]) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_history(self):
|
||||
assert limit_chat_history(None) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_keeps_recent_messages(self):
|
||||
history = [
|
||||
{"prompt": "q1", "response": "a1"},
|
||||
{"prompt": "q2", "response": "a2"},
|
||||
]
|
||||
result = limit_chat_history(history, max_token_limit=10000)
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trims_old_messages(self):
|
||||
history = [
|
||||
{"prompt": "x" * 5000, "response": "y" * 5000},
|
||||
{"prompt": "q", "response": "a"},
|
||||
]
|
||||
result = limit_chat_history(history, max_token_limit=100)
|
||||
assert len(result) <= 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_handles_tool_calls(self):
|
||||
history = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "a",
|
||||
"tool_calls": [
|
||||
{"tool_name": "t", "action_name": "a", "arguments": "{}", "result": "r"}
|
||||
],
|
||||
}
|
||||
]
|
||||
result = limit_chat_history(history, max_token_limit=10000)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestValidateFunctionName:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_valid_names(self):
|
||||
assert validate_function_name("hello") is True
|
||||
assert validate_function_name("hello_world") is True
|
||||
assert validate_function_name("hello-world") is True
|
||||
assert validate_function_name("test123") is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invalid_names(self):
|
||||
assert validate_function_name("hello world") is False
|
||||
assert validate_function_name("hello!") is False
|
||||
assert validate_function_name("") is False
|
||||
|
||||
|
||||
class TestGenerateImageUrl:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_http_url_passthrough(self):
|
||||
assert generate_image_url("https://example.com/img.png") == "https://example.com/img.png"
|
||||
assert generate_image_url("http://example.com/img.png") == "http://example.com/img.png"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_s3_strategy(self):
|
||||
with patch("application.utils.settings") as s:
|
||||
s.URL_STRATEGY = "s3"
|
||||
s.S3_BUCKET_NAME = "my-bucket"
|
||||
s.SAGEMAKER_REGION = "us-west-2"
|
||||
result = generate_image_url("path/to/img.png")
|
||||
assert "my-bucket.s3.us-west-2" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_backend_strategy(self):
|
||||
with patch("application.utils.settings") as s:
|
||||
s.URL_STRATEGY = "backend"
|
||||
s.API_URL = "http://localhost:7091"
|
||||
result = generate_image_url("path/to/img.png")
|
||||
assert result == "http://localhost:7091/api/images/path/to/img.png"
|
||||
|
||||
|
||||
class TestCalculateCompressionThreshold:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_threshold(self):
|
||||
with patch("application.utils.get_token_limit", return_value=100000):
|
||||
result = calculate_compression_threshold("gpt-4o")
|
||||
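# default compression threshold is 80% of the 100000-token limit (compare the custom-percentage test below)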
assert result == 80000
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_percentage(self):
|
||||
with patch("application.utils.get_token_limit", return_value=100000):
|
||||
result = calculate_compression_threshold("gpt-4o", 0.5)
|
||||
assert result == 50000
|
||||
|
||||
|
||||
class TestConvertPdfToImages:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_pdf2image_raises(self):
|
||||
with patch.dict("sys.modules", {"pdf2image": None}):
|
||||
# Force re-import to trigger ImportError
|
||||
# The function handles the import internally
|
||||
with pytest.raises(ImportError, match="pdf2image"):
|
||||
convert_pdf_to_images("test.pdf")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_converts_from_path(self):
|
||||
mock_image = MagicMock()
|
||||
mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"PNG_DATA"))
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.convert_from_path.return_value = [mock_image]
|
||||
mock_module.convert_from_bytes.return_value = [mock_image]
|
||||
|
||||
original_import = __import__
|
||||
|
||||
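# Intercept the dynamic "pdf2image" import so convert_pdf_to_images receives the mock module instead of the real package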
def patched_import(name, *args, **kwargs):
|
||||
if name == "pdf2image":
|
||||
return mock_module
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=patched_import):
|
||||
result = convert_pdf_to_images("/some/file.pdf")
|
||||
assert len(result) == 1
|
||||
assert result[0]["mime_type"] == "image/png"
|
||||
assert result[0]["page"] == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_storage(self):
|
||||
mock_image = MagicMock()
|
||||
mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"IMG"))
|
||||
|
||||
mock_storage = MagicMock()
|
||||
mock_file = MagicMock()
|
||||
mock_file.read.return_value = b"pdf_bytes"
|
||||
mock_file.__enter__ = MagicMock(return_value=mock_file)
|
||||
mock_file.__exit__ = MagicMock(return_value=False)
|
||||
mock_storage.get_file.return_value = mock_file
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.convert_from_bytes.return_value = [mock_image]
|
||||
|
||||
original_import = __import__
|
||||
|
||||
def patched_import(name, *args, **kwargs):
|
||||
if name == "pdf2image":
|
||||
return mock_module
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=patched_import):
|
||||
result = convert_pdf_to_images("test.pdf", storage=mock_storage)
|
||||
assert len(result) == 1
|
||||
mock_module.convert_from_bytes.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_file_not_found_raises(self):
|
||||
mock_module = MagicMock()
|
||||
mock_module.convert_from_path.side_effect = FileNotFoundError("not found")
|
||||
|
||||
# Patch the import inside the function
|
||||
original_import = __import__
|
||||
|
||||
def patched_import(name, *args, **kwargs):
|
||||
if name == "pdf2image":
|
||||
return mock_module
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=patched_import):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
convert_pdf_to_images("/nonexistent.pdf")
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generic_error_raises(self):
|
||||
mock_module = MagicMock()
|
||||
mock_module.convert_from_path.side_effect = RuntimeError("conversion failed")
|
||||
|
||||
original_import = __import__
|
||||
|
||||
def patched_import(name, *args, **kwargs):
|
||||
if name == "pdf2image":
|
||||
return mock_module
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=patched_import):
|
||||
with pytest.raises(RuntimeError, match="conversion failed"):
|
||||
convert_pdf_to_images("/some.pdf")
|
||||
|
||||
|
||||
class TestCleanTextForTts:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_code_blocks(self):
|
||||
result = clean_text_for_tts("before ```python\ncode\n``` after")
|
||||
assert "code block" in result
|
||||
assert "python" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_mermaid_blocks(self):
|
||||
result = clean_text_for_tts("```mermaid\ngraph TD\n```")
|
||||
assert "flowchart" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_markdown_links(self):
|
||||
result = clean_text_for_tts("[click here](https://example.com)")
|
||||
assert "click here" in result
|
||||
assert "https" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_images(self):
|
||||
result = clean_text_for_tts("")
|
||||
assert "image.png" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_inline_code(self):
|
||||
result = clean_text_for_tts("use `foo()` here")
|
||||
assert "foo()" in result
|
||||
assert "`" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_bold_italic(self):
|
||||
result = clean_text_for_tts("**bold** and *italic*")
|
||||
assert "bold" in result
|
||||
assert "italic" in result
|
||||
assert "*" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_headers(self):
|
||||
result = clean_text_for_tts("# Header\ntext")
|
||||
assert "Header" in result
|
||||
assert "#" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_blockquotes(self):
|
||||
result = clean_text_for_tts("> quoted text")
|
||||
assert "quoted text" in result
|
||||
assert ">" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_html_tags(self):
|
||||
result = clean_text_for_tts("<div>content</div>")
|
||||
assert "content" in result
|
||||
assert "<" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_arrows(self):
|
||||
result = clean_text_for_tts("a --> b <-- c => d")
|
||||
assert "-->" not in result
|
||||
assert "<--" not in result
|
||||
assert "=>" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_horizontal_rules(self):
|
||||
result = clean_text_for_tts("text\n---\nmore")
|
||||
assert "---" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_list_markers(self):
|
||||
result = clean_text_for_tts("- item1\n* item2\n1. item3")
|
||||
assert "item1" in result
|
||||
assert "item2" in result
|
||||
assert "item3" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_normalizes_whitespace(self):
|
||||
result = clean_text_for_tts(" lots of spaces ")
|
||||
assert " " not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_braces(self):
|
||||
result = clean_text_for_tts("{content} and [more]")
|
||||
assert "content" in result
|
||||
assert "more" in result
|
||||
assert "{" not in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_removes_double_colons(self):
|
||||
result = clean_text_for_tts("module::function")
|
||||
assert "::" not in result
|
||||
tests/test_worker.py (new file, 1718 lines; diff suppressed because it is too large)

tests/vectorstore/__init__.py (new empty file)

tests/vectorstore/test_base.py (new file, 361 lines)
@@ -0,0 +1,361 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.vectorstore.base import (
|
||||
BaseVectorStore,
|
||||
EmbeddingsSingleton,
|
||||
RemoteEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
# --- RemoteEmbeddings ---
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRemoteEmbeddings:
|
||||
def test_init_sets_url_and_headers(self):
|
||||
emb = RemoteEmbeddings(
|
||||
api_url="http://localhost:8080/", model_name="model-v1", api_key="sk-key"
|
||||
)
|
||||
assert emb.api_url == "http://localhost:8080"
|
||||
assert emb.model_name == "model-v1"
|
||||
assert emb.headers["Authorization"] == "Bearer sk-key"
|
||||
|
||||
def test_init_no_api_key(self):
|
||||
emb = RemoteEmbeddings(api_url="http://host", model_name="m")
|
||||
assert "Authorization" not in emb.headers
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_sends_correct_payload(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [{"index": 0, "embedding": [0.1, 0.2]}]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "model-v1")
|
||||
result = emb._embed("test input")
|
||||
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert call_kwargs[1]["json"]["input"] == "test input"
|
||||
assert call_kwargs[1]["json"]["model"] == "model-v1"
|
||||
assert result == [[0.1, 0.2]]
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_sorts_by_index(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [
|
||||
{"index": 1, "embedding": [0.3, 0.4]},
|
||||
{"index": 0, "embedding": [0.1, 0.2]},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
result = emb._embed(["a", "b"])
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_raises_on_error_response(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {"error": "rate limit exceeded"}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
with pytest.raises(ValueError, match="rate limit exceeded"):
|
||||
emb._embed("test")
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_raises_on_unexpected_format(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {"unexpected": True}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
with pytest.raises(ValueError, match="Unexpected response format"):
|
||||
emb._embed("test")
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_raises_on_non_dict_response(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = [1, 2, 3]
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
with pytest.raises(ValueError, match="Unexpected response format"):
|
||||
emb._embed("test")
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_query(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
emb.dimension = None # Reset so it gets set from response
|
||||
result = emb.embed_query("hello")
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
assert emb.dimension == 3
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_query_raises_on_bad_structure(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
# Return multiple embeddings for a single query
|
||||
mock_resp.json.return_value = {
|
||||
"data": [
|
||||
{"index": 0, "embedding": [0.1]},
|
||||
{"index": 1, "embedding": [0.2]},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
with pytest.raises(ValueError, match="Unexpected result structure"):
|
||||
emb.embed_query("hello")
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_embed_documents(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [
|
||||
{"index": 0, "embedding": [0.1, 0.2]},
|
||||
{"index": 1, "embedding": [0.3, 0.4]},
|
||||
]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
emb.dimension = None # Reset so it gets set from response
|
||||
result = emb.embed_documents(["doc1", "doc2"])
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
assert emb.dimension == 2
|
||||
|
||||
def test_embed_documents_empty(self):
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
assert emb.embed_documents([]) == []
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_call_with_string(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [{"index": 0, "embedding": [0.5]}]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
result = emb("hello")
|
||||
assert result == [0.5]
|
||||
|
||||
@patch("application.vectorstore.base.requests.post")
|
||||
def test_call_with_list(self, mock_post):
|
||||
mock_resp = Mock()
|
||||
mock_resp.json.return_value = {
|
||||
"data": [{"index": 0, "embedding": [0.5]}]
|
||||
}
|
||||
mock_resp.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
result = emb(["hello"])
|
||||
assert result == [[0.5]]
|
||||
|
||||
def test_call_with_invalid_type(self):
|
||||
emb = RemoteEmbeddings("http://host", "m")
|
||||
with pytest.raises(ValueError, match="Input must be a string or a list"):
|
||||
emb(123)
|
||||
|
||||
|
||||
# --- EmbeddingsSingleton ---
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEmbeddingsSingleton:
|
||||
def setup_method(self):
|
||||
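# Clear the cached singleton instances so each test constructs a fresh embeddings object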
EmbeddingsSingleton._instances = {}
|
||||
|
||||
@patch("application.vectorstore.base.OpenAIEmbeddings")
|
||||
def test_get_instance_openai(self, mock_openai_cls):
|
||||
mock_instance = Mock()
|
||||
mock_openai_cls.return_value = mock_instance
|
||||
|
||||
result = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
|
||||
assert result is mock_instance
|
||||
|
||||
@patch("application.vectorstore.base.OpenAIEmbeddings")
|
||||
def test_singleton_returns_same_instance(self, mock_openai_cls):
|
||||
mock_instance = Mock()
|
||||
mock_openai_cls.return_value = mock_instance
|
||||
|
||||
r1 = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
|
||||
r2 = EmbeddingsSingleton.get_instance("openai_text-embedding-ada-002")
|
||||
assert r1 is r2
|
||||
mock_openai_cls.assert_called_once()
|
||||
|
||||
@patch("application.vectorstore.base._get_embeddings_wrapper")
|
||||
def test_get_instance_huggingface(self, mock_get_wrapper):
|
||||
mock_wrapper_cls = Mock()
|
||||
mock_instance = Mock()
|
||||
mock_wrapper_cls.return_value = mock_instance
|
||||
mock_get_wrapper.return_value = mock_wrapper_cls
|
||||
|
||||
result = EmbeddingsSingleton.get_instance(
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
assert result is mock_instance
|
||||
|
||||
@patch("application.vectorstore.base._get_embeddings_wrapper")
|
||||
def test_get_instance_unknown_falls_back_to_wrapper(self, mock_get_wrapper):
|
||||
mock_wrapper_cls = Mock()
|
||||
mock_instance = Mock()
|
||||
mock_wrapper_cls.return_value = mock_instance
|
||||
mock_get_wrapper.return_value = mock_wrapper_cls
|
||||
|
||||
result = EmbeddingsSingleton.get_instance("custom_model_name")
|
||||
mock_wrapper_cls.assert_called_once_with("custom_model_name")
|
||||
assert result is mock_instance
|
||||
|
||||
|
||||
# --- BaseVectorStore ---
|
||||
|
||||
|
||||
class ConcreteVectorStore(BaseVectorStore):
|
||||
"""Concrete implementation for testing base class methods."""
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def add_texts(self, texts, metadatas=None, *args, **kwargs):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseVectorStore:
|
||||
def setup_method(self):
|
||||
EmbeddingsSingleton._instances = {}
|
||||
|
||||
def test_default_methods_are_noop(self):
|
||||
store = ConcreteVectorStore()
|
||||
assert store.delete_index() is None
|
||||
assert store.save_local() is None
|
||||
assert store.get_chunks() is None
|
||||
assert store.add_chunk("text") is None
|
||||
assert store.delete_chunk("id") is None
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
def test_is_azure_configured_true(self, mock_settings):
|
||||
mock_settings.OPENAI_API_BASE = "https://azure.openai.com"
|
||||
mock_settings.OPENAI_API_VERSION = "2023-05-15"
|
||||
mock_settings.AZURE_DEPLOYMENT_NAME = "my-deploy"
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
assert store.is_azure_configured()
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
def test_is_azure_configured_false(self, mock_settings):
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.OPENAI_API_VERSION = None
|
||||
mock_settings.AZURE_DEPLOYMENT_NAME = None
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
assert not store.is_azure_configured()
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
def test_get_embeddings_remote(self, mock_settings):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = "http://remote:8080"
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings("model-name", "api-key")
|
||||
|
||||
assert isinstance(result, RemoteEmbeddings)
|
||||
assert result.api_url == "http://remote:8080"
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
@patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
|
||||
def test_get_embeddings_openai(self, mock_get_instance, mock_settings):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = None
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.OPENAI_API_VERSION = None
|
||||
mock_settings.AZURE_DEPLOYMENT_NAME = None
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_get_instance.return_value = mock_emb
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings("openai_text-embedding-ada-002", "sk-key")
|
||||
assert result is mock_emb
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
@patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
|
||||
def test_get_embeddings_openai_azure(self, mock_get_instance, mock_settings):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = None
|
||||
mock_settings.OPENAI_API_BASE = "https://azure.openai.com"
|
||||
mock_settings.OPENAI_API_VERSION = "2023-05-15"
|
||||
mock_settings.AZURE_DEPLOYMENT_NAME = "deploy"
|
||||
mock_settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME = "embed-deploy"
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_get_instance.return_value = mock_emb
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings("openai_text-embedding-ada-002", "sk-key")
|
||||
assert result is mock_emb
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
@patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_get_embeddings_huggingface_no_local_model(
|
||||
self, mock_exists, mock_get_instance, mock_settings
|
||||
):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = None
|
||||
mock_emb = Mock()
|
||||
mock_get_instance.return_value = mock_emb
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings(
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
assert result is mock_emb
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
@patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
|
||||
@patch("os.path.exists")
|
||||
def test_get_embeddings_huggingface_local_model(
|
||||
self, mock_exists, mock_get_instance, mock_settings
|
||||
):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = None
|
||||
mock_exists.side_effect = lambda p: p == "/app/models/all-mpnet-base-v2"
|
||||
mock_emb = Mock()
|
||||
mock_get_instance.return_value = mock_emb
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings(
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
assert result is mock_emb
|
||||
mock_get_instance.assert_called_with("/app/models/all-mpnet-base-v2")
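        # With /app/models/all-mpnet-base-v2 present on disk, the store passes
        # the local path to the singleton instead of the Hugging Face model id.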
|
||||
|
||||
@patch("application.vectorstore.base.settings")
|
||||
@patch("application.vectorstore.base.EmbeddingsSingleton.get_instance")
|
||||
def test_get_embeddings_generic(self, mock_get_instance, mock_settings):
|
||||
mock_settings.EMBEDDINGS_BASE_URL = None
|
||||
mock_emb = Mock()
|
||||
mock_get_instance.return_value = mock_emb
|
||||
|
||||
store = ConcreteVectorStore()
|
||||
result = store._get_embeddings("some_custom_embedding")
|
||||
assert result is mock_emb
|
||||
mock_get_instance.assert_called_with("some_custom_embedding")
|
||||
39
tests/vectorstore/test_document_class.py
Normal file
@@ -0,0 +1,39 @@
import pytest
from application.vectorstore.document_class import Document


@pytest.mark.unit
class TestDocument:
    def test_create_document(self):
        doc = Document(page_content="hello world", metadata={"source": "test"})
        assert doc.page_content == "hello world"
        assert doc.metadata == {"source": "test"}

    def test_document_is_string(self):
        doc = Document(page_content="hello world", metadata={})
        assert isinstance(doc, str)
        assert str(doc) == "hello world"

    def test_document_string_equality(self):
        doc = Document(page_content="hello", metadata={"k": "v"})
        assert doc == "hello"

    def test_document_empty_metadata(self):
        doc = Document(page_content="text", metadata={})
        assert doc.metadata == {}

    def test_document_empty_content(self):
        doc = Document(page_content="", metadata={"a": 1})
        assert doc.page_content == ""
        assert doc == ""

    def test_document_preserves_complex_metadata(self):
        meta = {"source": "file.txt", "page": 3, "nested": {"key": "val"}}
        doc = Document(page_content="content", metadata=meta)
        assert doc.metadata["nested"]["key"] == "val"

    def test_document_string_operations(self):
        doc = Document(page_content="hello world", metadata={})
        assert doc.upper() == "HELLO WORLD"
        assert doc.split() == ["hello", "world"]
        assert "world" in doc
292
tests/vectorstore/test_elasticsearch.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_es_store(source_id="test-source"):
|
||||
"""Helper to create an ElasticsearchStore with mocked deps."""
|
||||
# Reset class-level connection
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
ElasticsearchStore._es_connection = None
|
||||
|
||||
with patch(
|
||||
"application.vectorstore.elasticsearch.settings"
|
||||
) as mock_settings, patch.dict(
|
||||
"sys.modules", {"elasticsearch": MagicMock(), "elasticsearch.helpers": MagicMock()}
|
||||
):
|
||||
mock_settings.ELASTIC_URL = "http://localhost:9200"
|
||||
mock_settings.ELASTIC_USERNAME = "elastic"
|
||||
mock_settings.ELASTIC_PASSWORD = "password"
|
||||
mock_settings.ELASTIC_CLOUD_ID = None
|
||||
mock_settings.ELASTIC_INDEX = "test_index"
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
|
||||
import elasticsearch
|
||||
|
||||
mock_es = MagicMock()
|
||||
elasticsearch.Elasticsearch.return_value = mock_es
|
||||
|
||||
store = ElasticsearchStore(
|
||||
source_id=source_id,
|
||||
embeddings_key="key",
|
||||
index_name="test_index",
|
||||
)
|
||||
|
||||
return store, mock_es, mock_settings
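# Typical usage in the tests below (a convenience unpacking, not part of the
# production API):
#     store, mock_es, _ = _make_es_store(source_id="src1")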
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreInit:
|
||||
def test_source_id_cleaned(self):
|
||||
store, _, _ = _make_es_store(source_id="application/indexes/abc123/")
|
||||
assert store.source_id == "abc123"
|
||||
|
||||
def test_init_with_url(self):
|
||||
store, mock_es, _ = _make_es_store()
|
||||
assert store.docsearch is mock_es
|
||||
assert store.index_name == "test_index"
|
||||
|
||||
def test_init_with_cloud_id(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
ElasticsearchStore._es_connection = None
|
||||
|
||||
with patch(
|
||||
"application.vectorstore.elasticsearch.settings"
|
||||
) as mock_settings, patch.dict(
|
||||
"sys.modules", {"elasticsearch": MagicMock()}
|
||||
):
|
||||
mock_settings.ELASTIC_URL = None
|
||||
mock_settings.ELASTIC_CLOUD_ID = "my-cloud-id"
|
||||
mock_settings.ELASTIC_USERNAME = "user"
|
||||
mock_settings.ELASTIC_PASSWORD = "pass"
|
||||
mock_settings.ELASTIC_INDEX = "idx"
|
||||
mock_settings.EMBEDDINGS_NAME = "model"
|
||||
|
||||
store = ElasticsearchStore(
|
||||
source_id="src", embeddings_key="k", index_name="idx"
|
||||
)
|
||||
assert store.docsearch is not None
|
||||
|
||||
def test_init_no_url_no_cloud_id_raises(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
ElasticsearchStore._es_connection = None
|
||||
|
||||
with patch(
|
||||
"application.vectorstore.elasticsearch.settings"
|
||||
) as mock_settings, patch.dict(
|
||||
"sys.modules", {"elasticsearch": MagicMock()}
|
||||
):
|
||||
mock_settings.ELASTIC_URL = None
|
||||
mock_settings.ELASTIC_CLOUD_ID = None
|
||||
mock_settings.ELASTIC_INDEX = "idx"
|
||||
mock_settings.EMBEDDINGS_NAME = "model"
|
||||
|
||||
with pytest.raises(ValueError, match="provide either"):
|
||||
ElasticsearchStore(source_id="src", embeddings_key="k")
|
||||
|
||||
def test_reuses_class_connection(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
ElasticsearchStore._es_connection = None
|
||||
|
||||
with patch(
|
||||
"application.vectorstore.elasticsearch.settings"
|
||||
) as mock_settings, patch.dict(
|
||||
"sys.modules", {"elasticsearch": MagicMock()}
|
||||
):
|
||||
mock_settings.ELASTIC_URL = "http://localhost:9200"
|
||||
mock_settings.ELASTIC_USERNAME = "user"
|
||||
mock_settings.ELASTIC_PASSWORD = "pass"
|
||||
mock_settings.ELASTIC_CLOUD_ID = None
|
||||
mock_settings.ELASTIC_INDEX = "idx"
|
||||
mock_settings.EMBEDDINGS_NAME = "model"
|
||||
|
||||
import elasticsearch
|
||||
|
||||
mock_es = MagicMock()
|
||||
elasticsearch.Elasticsearch.return_value = mock_es
|
||||
|
||||
store1 = ElasticsearchStore(source_id="src1", embeddings_key="k")
|
||||
store2 = ElasticsearchStore(source_id="src2", embeddings_key="k")
|
||||
|
||||
assert store1.docsearch is store2.docsearch
|
||||
elasticsearch.Elasticsearch.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreSearch:
|
||||
def test_search_builds_query(self):
|
||||
store, mock_es, mock_settings = _make_es_store()
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb):
|
||||
mock_es.search.return_value = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"text": "doc1",
|
||||
"metadata": {"source": "file.txt"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "doc2",
|
||||
"metadata": {"source": "file2.txt"},
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
results = store.search("query", k=2)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0].page_content == "doc1"
|
||||
assert results[1].metadata == {"source": "file2.txt"}
|
||||
|
||||
def test_search_empty_results(self):
|
||||
store, mock_es, _ = _make_es_store()
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_emb.embed_query = Mock(return_value=[0.1])
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb):
|
||||
mock_es.search.return_value = {"hits": {"hits": []}}
|
||||
results = store.search("query")
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreAddTexts:
|
||||
def test_add_texts(self):
|
||||
store, mock_es, mock_settings = _make_es_store()
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2], [0.3, 0.4]])
|
||||
|
||||
mock_bulk = Mock(return_value=(2, 0))
|
||||
mock_helpers = MagicMock()
|
||||
mock_helpers.bulk = mock_bulk
|
||||
|
||||
with patch.object(
|
||||
store, "_get_embeddings", return_value=mock_emb
|
||||
), patch.object(
|
||||
store, "_create_index_if_not_exists"
|
||||
), patch.dict(
|
||||
"sys.modules", {"elasticsearch.helpers": mock_helpers}
|
||||
):
|
||||
ids = store.add_texts(
|
||||
["text1", "text2"],
|
||||
metadatas=[{"a": 1}, {"b": 2}],
|
||||
)
|
||||
|
||||
assert len(ids) == 2
|
||||
|
||||
def test_add_texts_empty_raises(self):
|
||||
"""Empty texts causes IndexError because code accesses vectors[0] unconditionally."""
|
||||
store, _, _ = _make_es_store()
|
||||
|
||||
mock_emb = Mock()
|
||||
mock_emb.embed_documents = Mock(return_value=[])
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb):
|
||||
with pytest.raises(IndexError):
|
||||
store.add_texts([], metadatas=[])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreDeleteIndex:
|
||||
def test_delete_index_calls_delete_by_query(self):
|
||||
store, mock_es, _ = _make_es_store(source_id="src1")
|
||||
|
||||
store.delete_index()
|
||||
|
||||
mock_es.delete_by_query.assert_called_once_with(
|
||||
index="test_index",
|
||||
query={"match": {"metadata.source_id.keyword": "src1"}},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreIndex:
|
||||
def test_index_returns_mapping(self):
|
||||
store, _, _ = _make_es_store()
|
||||
|
||||
mapping = store.index(dims_length=768)
|
||||
|
||||
assert mapping["mappings"]["properties"]["vector"]["type"] == "dense_vector"
|
||||
assert mapping["mappings"]["properties"]["vector"]["dims"] == 768
|
||||
assert mapping["mappings"]["properties"]["vector"]["similarity"] == "cosine"
|
||||
|
||||
def test_create_index_if_not_exists_existing(self):
|
||||
store, mock_es, _ = _make_es_store()
|
||||
mock_es.indices.exists.return_value = True
|
||||
|
||||
store._create_index_if_not_exists("test_index", 768)
|
||||
|
||||
mock_es.indices.create.assert_not_called()
|
||||
|
||||
def test_create_index_if_not_exists_new(self):
|
||||
store, mock_es, _ = _make_es_store()
|
||||
mock_es.indices.exists.return_value = False
|
||||
|
||||
store._create_index_if_not_exists("test_index", 768)
|
||||
|
||||
mock_es.indices.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestElasticsearchStoreConnectToElasticsearch:
|
||||
def test_connect_with_url(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
|
||||
import elasticsearch
|
||||
|
||||
mock_es = MagicMock()
|
||||
elasticsearch.Elasticsearch.return_value = mock_es
|
||||
|
||||
result = ElasticsearchStore.connect_to_elasticsearch(
|
||||
es_url="http://localhost:9200",
|
||||
username="user",
|
||||
password="pass",
|
||||
)
|
||||
assert result is mock_es
|
||||
|
||||
def test_connect_with_both_raises(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
|
||||
with pytest.raises(ValueError, match="Both es_url and cloud_id"):
|
||||
ElasticsearchStore.connect_to_elasticsearch(
|
||||
es_url="http://localhost", cloud_id="cloud-123"
|
||||
)
|
||||
|
||||
def test_connect_with_neither_raises(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
|
||||
with pytest.raises(ValueError, match="provide either"):
|
||||
ElasticsearchStore.connect_to_elasticsearch()
|
||||
|
||||
def test_connect_with_api_key(self):
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
|
||||
with patch.dict("sys.modules", {"elasticsearch": MagicMock()}):
|
||||
import elasticsearch
|
||||
|
||||
mock_es = MagicMock()
|
||||
elasticsearch.Elasticsearch.return_value = mock_es
|
||||
|
||||
result = ElasticsearchStore.connect_to_elasticsearch(
|
||||
es_url="http://localhost:9200",
|
||||
api_key="my-api-key",
|
||||
)
|
||||
assert result is mock_es
|
||||
140
tests/vectorstore/test_embeddings_local.py
Normal file
@@ -0,0 +1,140 @@
from unittest.mock import MagicMock, Mock, patch

import pytest


@pytest.mark.unit
class TestEmbeddingsWrapper:
    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_init_success(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 768
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("test-model")

        mock_st_cls.assert_called_once()
        assert wrapper.dimension == 768

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_init_failure(self, mock_st_cls):
        mock_st_cls.side_effect = Exception("model not found")

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        with pytest.raises(Exception, match="model not found"):
            EmbeddingsWrapper("bad-model")

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_init_none_model(self, mock_st_cls):
        mock_st_cls.return_value = None

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        with pytest.raises((ValueError, AttributeError)):
            EmbeddingsWrapper("bad-model")

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_init_null_first_module(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = None
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        with pytest.raises(ValueError, match="failed to load properly"):
            EmbeddingsWrapper("bad-model")

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_embed_query(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 3
        mock_model.encode.return_value = MagicMock(tolist=Mock(return_value=[0.1, 0.2, 0.3]))
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("model")
        result = wrapper.embed_query("hello world")

        mock_model.encode.assert_called_once_with("hello world")
        assert result == [0.1, 0.2, 0.3]

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_embed_documents(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 3
        mock_model.encode.return_value = MagicMock(
            tolist=Mock(return_value=[[0.1, 0.2], [0.3, 0.4]])
        )
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("model")
        result = wrapper.embed_documents(["doc1", "doc2"])

        mock_model.encode.assert_called_with(["doc1", "doc2"])
        assert result == [[0.1, 0.2], [0.3, 0.4]]

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_call_with_string(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 3
        mock_model.encode.return_value = MagicMock(tolist=Mock(return_value=[0.1]))
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("model")
        result = wrapper("hello")
        assert result == [0.1]

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_call_with_list(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 3
        mock_model.encode.return_value = MagicMock(
            tolist=Mock(return_value=[[0.1], [0.2]])
        )
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("model")
        result = wrapper(["a", "b"])
        assert result == [[0.1], [0.2]]

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_call_with_invalid_type(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 3
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        wrapper = EmbeddingsWrapper("model")
        with pytest.raises(ValueError, match="Input must be a string or a list"):
            wrapper(123)

    @patch("application.vectorstore.embeddings_local.SentenceTransformer")
    def test_trust_remote_code_default(self, mock_st_cls):
        mock_model = MagicMock()
        mock_model._first_module.return_value = MagicMock()
        mock_model.get_sentence_embedding_dimension.return_value = 768
        mock_st_cls.return_value = mock_model

        from application.vectorstore.embeddings_local import EmbeddingsWrapper

        EmbeddingsWrapper("model")

        call_kwargs = mock_st_cls.call_args[1]
        assert call_kwargs["trust_remote_code"] is True
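# As exercised above, EmbeddingsWrapper accepts either a single string or a
# list of strings when called directly; a minimal usage sketch (assuming a
# loadable SentenceTransformer model name):
#     wrapper = EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2")
#     single_vector = wrapper("hello")        # -> list[float]
#     many_vectors = wrapper(["a", "b"])      # -> list[list[float]]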
359
tests/vectorstore/test_faiss.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import io
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings():
|
||||
emb = Mock()
|
||||
emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
|
||||
emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
|
||||
emb.dimension = 3
|
||||
return emb
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
storage = Mock()
|
||||
storage.file_exists = Mock(return_value=True)
|
||||
storage.get_file = Mock(return_value=io.BytesIO(b"fake data"))
|
||||
storage.save_file = Mock()
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_docsearch():
|
||||
ds = Mock()
|
||||
ds.similarity_search = Mock(return_value=[])
|
||||
ds.add_texts = Mock(return_value=["id1"])
|
||||
ds.add_documents = Mock(return_value=["id1"])
|
||||
ds.save_local = Mock()
|
||||
ds.delete = Mock()
|
||||
ds.index = Mock()
|
||||
ds.index.d = 3
|
||||
ds.docstore = Mock()
|
||||
ds.docstore._dict = {
|
||||
"doc1": Mock(page_content="text1", metadata={"source": "a"}),
|
||||
"doc2": Mock(page_content="text2", metadata={"source": "b"}),
|
||||
}
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreInit:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_init_with_docs(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="test", embeddings_key="key", docs_init=[Mock()])
|
||||
mock_faiss.from_documents.assert_called_once()
|
||||
assert store.docsearch is mock_ds
|
||||
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_init_missing_index_files(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_storage = Mock()
|
||||
mock_storage.file_exists.return_value = False
|
||||
mock_storage_creator.get_storage.return_value = mock_storage
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
with pytest.raises(Exception, match="Error loading FAISS index"):
|
||||
FaissStore(source_id="test", embeddings_key="key")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreSearch:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_search_delegates_to_docsearch(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_ds.similarity_search.return_value = ["doc1"]
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
result = store.search("query", k=5)
|
||||
mock_ds.similarity_search.assert_called_once_with("query", k=5)
|
||||
assert result == ["doc1"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreAddTexts:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_add_texts_delegates(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_ds.add_texts.return_value = ["id1", "id2"]
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
result = store.add_texts(["text1", "text2"])
|
||||
assert result == ["id1", "id2"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreGetChunks:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_get_chunks(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
|
||||
doc1 = Mock(page_content="text1", metadata={"source": "a"})
|
||||
doc2 = Mock(page_content="text2", metadata={"source": "b"})
|
||||
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_ds.docstore._dict = {"id1": doc1, "id2": doc2}
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
chunks = store.get_chunks()
|
||||
|
||||
assert len(chunks) == 2
|
||||
texts = {c["text"] for c in chunks}
|
||||
assert texts == {"text1", "text2"}
|
||||
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_get_chunks_empty(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_ds.docstore._dict = {}
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
assert store.get_chunks() == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreSaveLocal:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_save_local_with_path(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage = Mock()
|
||||
mock_storage_creator.get_storage.return_value = mock_storage
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
|
||||
# Mock _save_to_storage to avoid file I/O
|
||||
store._save_to_storage = Mock(return_value=True)
|
||||
|
||||
with patch("os.makedirs"):
|
||||
result = store.save_local(path="/tmp/test_save")
|
||||
|
||||
mock_ds.save_local.assert_called_once_with("/tmp/test_save")
|
||||
store._save_to_storage.assert_called_once()
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreDeleteIndex:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_delete_index_delegates(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
store.delete_index(["id1"])
|
||||
mock_ds.delete.assert_called_once_with(["id1"])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreAssertEmbeddingDimensions:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_dimension_mismatch_raises(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = (
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
mock_emb = Mock(dimension=768)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=512) # Mismatched dimension
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
|
||||
FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
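        # The store compares the FAISS index dimensionality (512 here) against
        # the embedding wrapper's reported dimension (768) and refuses to load
        # on a mismatch.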
|
||||
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_missing_dimension_attr_raises(
|
||||
self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator
|
||||
):
|
||||
mock_settings.EMBEDDINGS_NAME = (
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
mock_emb = Mock(spec=[]) # No dimension attribute
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=768)
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage_creator.get_storage.return_value = Mock()
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
with pytest.raises(AttributeError, match="dimension"):
|
||||
FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreDeleteChunk:
|
||||
@patch("application.vectorstore.faiss.StorageCreator")
|
||||
@patch("application.vectorstore.faiss.FAISS")
|
||||
@patch.object(
|
||||
__import__("application.vectorstore.base", fromlist=["BaseVectorStore"]).BaseVectorStore,
|
||||
"_get_embeddings",
|
||||
)
|
||||
@patch("application.vectorstore.faiss.settings")
|
||||
def test_delete_chunk(self, mock_settings, mock_get_emb, mock_faiss, mock_storage_creator):
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_emb = Mock(dimension=3)
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_ds = Mock()
|
||||
mock_ds.index = Mock(d=3)
|
||||
mock_faiss.from_documents.return_value = mock_ds
|
||||
mock_storage = Mock()
|
||||
mock_storage_creator.get_storage.return_value = mock_storage
|
||||
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
|
||||
store = FaissStore(source_id="t", embeddings_key="k", docs_init=[Mock()])
|
||||
store._save_to_storage = Mock(return_value=True)
|
||||
|
||||
result = store.delete_chunk("chunk_id")
|
||||
mock_ds.delete.assert_called_once_with(["chunk_id"])
|
||||
store._save_to_storage.assert_called_once()
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetVectorstore:
|
||||
def test_with_path(self):
|
||||
from application.vectorstore.faiss import get_vectorstore
|
||||
|
||||
assert get_vectorstore("abc123") == "indexes/abc123"
|
||||
|
||||
def test_without_path(self):
|
||||
from application.vectorstore.faiss import get_vectorstore
|
||||
|
||||
assert get_vectorstore("") == "indexes"
|
||||
assert get_vectorstore(None) == "indexes"
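        # get_vectorstore only builds the storage prefix: "indexes/<source_id>"
        # when a source id is given, plain "indexes" otherwise.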
|
||||
312
tests/vectorstore/test_lancedb.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_lancedb_store(source_id="test-source"):
|
||||
"""Helper to create a LanceDBVectorStore with mocked deps."""
|
||||
with patch(
|
||||
"application.vectorstore.lancedb.settings"
|
||||
) as mock_settings:
|
||||
mock_settings.LANCEDB_PATH = "/tmp/lancedb"
|
||||
mock_settings.LANCEDB_TABLE_NAME = "docs"
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
|
||||
from application.vectorstore.lancedb import LanceDBVectorStore
|
||||
|
||||
store = LanceDBVectorStore(
|
||||
path="/tmp/lancedb",
|
||||
table_name_prefix="docs",
|
||||
source_id=source_id,
|
||||
embeddings_key="key",
|
||||
)
|
||||
|
||||
return store, mock_settings
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreInit:
|
||||
def test_table_name_with_source_id(self):
|
||||
store, _ = _make_lancedb_store(source_id="src1")
|
||||
assert store.table_name == "docs_src1"
|
||||
|
||||
def test_table_name_without_source_id(self):
|
||||
with patch("application.vectorstore.lancedb.settings") as mock_settings:
|
||||
mock_settings.LANCEDB_PATH = "/tmp"
|
||||
mock_settings.LANCEDB_TABLE_NAME = "docs"
|
||||
|
||||
from application.vectorstore.lancedb import LanceDBVectorStore
|
||||
|
||||
store = LanceDBVectorStore(
|
||||
path="/tmp", table_name_prefix="docs", source_id=None
|
||||
)
|
||||
assert store.table_name == "docs"
|
||||
|
||||
def test_init_defaults(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
assert store.path == "/tmp/lancedb"
|
||||
assert store._lance_db is None
|
||||
assert store.docsearch is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreLazyLoading:
|
||||
def test_pa_lazy_load(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_pa = MagicMock()
|
||||
|
||||
with patch("importlib.import_module", return_value=mock_pa) as mock_import:
|
||||
result = store.pa
|
||||
mock_import.assert_called_with("pyarrow")
|
||||
assert result is mock_pa
|
||||
|
||||
def test_pa_cached(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_pa = MagicMock()
|
||||
store._pa = mock_pa
|
||||
|
||||
assert store.pa is mock_pa
|
||||
|
||||
def test_lancedb_lazy_load(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_ldb = MagicMock()
|
||||
|
||||
with patch("importlib.import_module", return_value=mock_ldb) as mock_import:
|
||||
result = store.lancedb
|
||||
mock_import.assert_called_with("lancedb")
|
||||
assert result is mock_ldb
|
||||
|
||||
def test_lance_db_connection(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_ldb_module = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_ldb_module.connect.return_value = mock_conn
|
||||
store._lancedb_module = mock_ldb_module
|
||||
|
||||
result = store.lance_db
|
||||
mock_ldb_module.connect.assert_called_once_with("/tmp/lancedb")
|
||||
assert result is mock_conn
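        # Both the lancedb module and the connection itself are created lazily
        # on first attribute access and then cached (see the *_cached tests).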
|
||||
|
||||
def test_lance_db_cached(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_conn = MagicMock()
|
||||
store._lance_db = mock_conn
|
||||
|
||||
assert store.lance_db is mock_conn
|
||||
|
||||
def test_table_opens_existing(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_conn = MagicMock()
|
||||
mock_table = MagicMock()
|
||||
mock_conn.table_names.return_value = [store.table_name]
|
||||
mock_conn.open_table.return_value = mock_table
|
||||
store._lance_db = mock_conn
|
||||
|
||||
result = store.table
|
||||
mock_conn.open_table.assert_called_once_with(store.table_name)
|
||||
assert result is mock_table
|
||||
|
||||
def test_table_returns_none_for_missing(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.table_names.return_value = []
|
||||
store._lance_db = mock_conn
|
||||
|
||||
result = store.table
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreEnsureTableExists:
|
||||
def test_creates_table_when_missing(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.table_names.return_value = []
|
||||
store._lance_db = mock_conn
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.dimension = 768
|
||||
mock_pa = MagicMock()
|
||||
store._pa = mock_pa
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb):
|
||||
store.ensure_table_exists()
|
||||
|
||||
mock_conn.create_table.assert_called_once()
|
||||
|
||||
def test_noop_when_table_exists(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.table_names.return_value = [store.table_name]
|
||||
mock_conn.open_table.return_value = mock_table
|
||||
store._lance_db = mock_conn
|
||||
|
||||
store.ensure_table_exists()
|
||||
mock_conn.create_table.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreAddTexts:
|
||||
def test_add_texts(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
|
||||
store, "ensure_table_exists"
|
||||
):
|
||||
store.add_texts(
|
||||
["text1", "text2"],
|
||||
metadatas=[{"a": "1"}, {"b": "2"}],
|
||||
source_id="src1",
|
||||
)
|
||||
|
||||
mock_table.add.assert_called_once()
|
||||
vectors = mock_table.add.call_args[0][0]
|
||||
assert len(vectors) == 2
|
||||
assert vectors[0]["text"] == "text1"
|
||||
assert vectors[0]["vector"] == [0.1, 0.2]
|
||||
|
||||
def test_add_texts_with_source_id_in_metadata(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.embed_documents.return_value = [[0.1]]
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
|
||||
store, "ensure_table_exists"
|
||||
):
|
||||
store.add_texts(["text1"], metadatas=[{"k": "v"}], source_id="src1")
|
||||
|
||||
vectors = mock_table.add.call_args[0][0]
|
||||
metadata_keys = [m["key"] for m in vectors[0]["metadata"]]
|
||||
assert "source_id" in metadata_keys
|
||||
|
||||
def test_add_texts_default_metadata(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.embed_documents.return_value = [[0.1]]
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
|
||||
store, "ensure_table_exists"
|
||||
):
|
||||
store.add_texts(["text1"])
|
||||
|
||||
mock_table.add.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreSearch:
|
||||
def test_search(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.embed_query.return_value = [0.1, 0.2]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.limit.return_value.to_list.return_value = [
|
||||
{"_distance": 0.1, "text": "result1", "metadata": {"k": "v"}},
|
||||
]
|
||||
mock_table.search.return_value = mock_result
|
||||
|
||||
with patch.object(store, "_get_embeddings", return_value=mock_emb), patch.object(
|
||||
store, "ensure_table_exists"
|
||||
):
|
||||
results = store.search("query", k=3)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0][1] == "result1"
|
||||
mock_result.limit.assert_called_once_with(3)
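        # The stubbed rows carry "_distance", "text" and "metadata" keys, and
        # search() returns sequence entries where index 1 holds the text,
        # which is what the assertion on results[0][1] checks.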
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreDeleteIndex:
|
||||
def test_delete_index_drops_table(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.table_names.return_value = [store.table_name]
|
||||
mock_conn.open_table.return_value = mock_table
|
||||
store._lance_db = mock_conn
|
||||
|
||||
store.delete_index()
|
||||
|
||||
mock_conn.drop_table.assert_called_once_with(store.table_name)
|
||||
|
||||
def test_delete_index_noop_when_no_table(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
store.docsearch = None
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.table_names.return_value = []
|
||||
store._lance_db = mock_conn
|
||||
|
||||
store.delete_index()
|
||||
mock_conn.drop_table.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreAssertEmbeddingDimensions:
|
||||
def test_matching_dimensions(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
mock_table.schema = {"vector": MagicMock()}
|
||||
mock_table.schema["vector"].type.value_type.__len__ = Mock(return_value=768)
|
||||
store.docsearch = mock_table
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.dimension = 768
|
||||
|
||||
# Should not raise
|
||||
store.assert_embedding_dimensions(mock_emb)
|
||||
|
||||
def test_mismatched_dimensions_raises(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
type_mock = MagicMock()
|
||||
type_mock.__len__ = Mock(return_value=512)
|
||||
mock_table.schema.__getitem__ = Mock(return_value=MagicMock())
|
||||
mock_table.schema["vector"].type.value_type = type_mock
|
||||
|
||||
mock_emb = MagicMock()
|
||||
mock_emb.dimension = 768
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
|
||||
store.assert_embedding_dimensions(mock_emb)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLanceDBVectorStoreFilterDocuments:
|
||||
def test_filter_documents(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
mock_table.filter.return_value.to_list.return_value = [{"text": "filtered"}]
|
||||
store.docsearch = mock_table
|
||||
|
||||
with patch.object(store, "ensure_table_exists"):
|
||||
results = store.filter_documents({"source_id": "src1"})
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_filter_documents_requires_source_id(self):
|
||||
store, _ = _make_lancedb_store()
|
||||
mock_table = MagicMock()
|
||||
store.docsearch = mock_table
|
||||
|
||||
with patch.object(store, "ensure_table_exists"):
|
||||
with pytest.raises(ValueError, match="must contain 'source_id'"):
|
||||
store.filter_documents({"other_key": "value"})
|
||||
95
tests/vectorstore/test_milvus.py
Normal file
@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID

import pytest


def _make_milvus_store(source_id="test-source"):
    """Helper to create a MilvusStore with mocked deps."""
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.milvus.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "langchain_milvus": MagicMock(),
        },
    ):
        mock_emb = Mock()
        mock_get_emb.return_value = mock_emb
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.MILVUS_URI = "http://localhost:19530"
        mock_settings.MILVUS_TOKEN = "token"
        mock_settings.MILVUS_COLLECTION_NAME = "test_collection"

        from langchain_milvus import Milvus

        mock_docsearch = MagicMock()
        Milvus.return_value = mock_docsearch

        from application.vectorstore.milvus import MilvusStore

        store = MilvusStore(source_id=source_id, embeddings_key="key")
        store._docsearch = mock_docsearch

        return store, mock_docsearch


@pytest.mark.unit
class TestMilvusStoreInit:
    def test_source_id_stored(self):
        store, _ = _make_milvus_store(source_id="src1")
        assert store._source_id == "src1"


@pytest.mark.unit
class TestMilvusStoreSearch:
    def test_search(self):
        store, mock_ds = _make_milvus_store(source_id="src1")
        mock_ds.similarity_search.return_value = ["doc1", "doc2"]

        results = store.search("query", k=3)

        mock_ds.similarity_search.assert_called_once()
        call_kwargs = mock_ds.similarity_search.call_args
        assert call_kwargs[1]["query"] == "query"
        assert call_kwargs[1]["k"] == 3
        assert call_kwargs[1]["expr"] == "source_id == 'src1'"
        assert results == ["doc1", "doc2"]


@pytest.mark.unit
class TestMilvusStoreAddTexts:
    def test_add_texts(self):
        store, mock_ds = _make_milvus_store()
        mock_ds.add_texts.return_value = ["id1", "id2"]

        result = store.add_texts(
            ["text1", "text2"], metadatas=[{"a": 1}, {"b": 2}]
        )

        mock_ds.add_texts.assert_called_once()
        call_kwargs = mock_ds.add_texts.call_args
        assert call_kwargs[1]["texts"] == ["text1", "text2"]
        # ids should be UUIDs
        ids = call_kwargs[1]["ids"]
        assert len(ids) == 2
        for uid in ids:
            UUID(uid)  # Validates it's a valid UUID

        assert result == ["id1", "id2"]


@pytest.mark.unit
class TestMilvusStoreSaveLocal:
    def test_save_local_is_noop(self):
        store, _ = _make_milvus_store()
        assert store.save_local() is None


@pytest.mark.unit
class TestMilvusStoreDeleteIndex:
    def test_delete_index_is_noop(self):
        store, _ = _make_milvus_store()
        assert store.delete_index() is None
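# Note: the search test above confirms that source filtering is pushed down to
# Milvus via the expr argument ("source_id == '<id>'") rather than being
# filtered in Python after retrieval.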
259
tests/vectorstore/test_mongodb.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_mongodb_store(source_id="test-source"):
|
||||
"""Helper to create a MongoDBVectorStore with all external deps mocked."""
|
||||
with patch(
|
||||
"application.vectorstore.base.BaseVectorStore._get_embeddings"
|
||||
) as mock_get_emb, patch(
|
||||
"application.vectorstore.mongodb.settings"
|
||||
) as mock_settings, patch.dict(
|
||||
"sys.modules", {"pymongo": MagicMock()}
|
||||
):
|
||||
mock_emb = Mock()
|
||||
mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
|
||||
mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_get_emb.return_value = mock_emb
|
||||
mock_settings.EMBEDDINGS_NAME = "test_model"
|
||||
mock_settings.MONGO_URI = "mongodb://localhost:27017"
|
||||
|
||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||
|
||||
store = MongoDBVectorStore(
|
||||
source_id=source_id,
|
||||
embeddings_key="key",
|
||||
collection="test_docs",
|
||||
database="test_db",
|
||||
)
|
||||
|
||||
mock_collection = MagicMock()
|
||||
store._collection = mock_collection
|
||||
|
||||
return store, mock_collection, mock_emb
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMongoDBVectorStoreInit:
|
||||
def test_source_id_cleaned(self):
|
||||
store, _, _ = _make_mongodb_store(source_id="application/indexes/abc123/")
|
||||
assert store._source_id == "abc123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMongoDBVectorStoreSearch:
|
||||
def test_search_builds_pipeline(self):
|
||||
store, mock_collection, mock_emb = _make_mongodb_store()
|
||||
|
||||
doc1 = {
|
||||
"_id": "id1",
|
||||
"text": "hello world",
|
||||
"embedding": [0.1, 0.2],
|
||||
"source": "test",
|
||||
}
|
||||
mock_collection.aggregate.return_value = iter([doc1])
|
||||
|
||||
results = store.search("query", k=3)
|
||||
|
||||
mock_emb.embed_query.assert_called_once_with("query")
|
||||
mock_collection.aggregate.assert_called_once()
|
||||
pipeline = mock_collection.aggregate.call_args[0][0]
|
||||
assert pipeline[0]["$vectorSearch"]["limit"] == 3
|
||||
assert pipeline[0]["$vectorSearch"]["numCandidates"] == 30
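        # numCandidates is 30 for k=3, i.e. the pipeline appears to widen the
        # candidate pool to 10x the requested result count before ranking.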
|
||||
|
||||
assert len(results) == 1
|
||||
assert str(results[0]) == "hello world"
|
||||
|
||||
def test_search_removes_id_text_embedding_from_metadata(self):
|
||||
store, mock_collection, _ = _make_mongodb_store()
|
||||
|
||||
doc = {
|
||||
"_id": "id1",
|
||||
"text": "content",
|
||||
"embedding": [0.1],
|
||||
"custom_key": "custom_val",
|
||||
}
|
||||
mock_collection.aggregate.return_value = iter([doc])
|
||||
|
||||
results = store.search("q", k=1)
|
||||
metadata = results[0].metadata
|
||||
assert "_id" not in metadata
|
||||
assert "text" not in metadata
|
||||
assert "embedding" not in metadata
|
||||
assert metadata["custom_key"] == "custom_val"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMongoDBVectorStoreAddTexts:
|
||||
def test_add_texts_batches(self):
|
||||
store, mock_collection, mock_emb = _make_mongodb_store()
|
||||
# Generate 150 texts to trigger batching at 100
|
||||
texts = [f"text_{i}" for i in range(150)]
|
||||
metadatas = [{"i": i} for i in range(150)]
|
||||
mock_emb.embed_documents.return_value = [[0.1]] * 100 # per batch
|
||||
|
||||
mock_collection.insert_many.return_value = Mock(
|
||||
inserted_ids=list(range(100))
|
||||
)
|
||||
|
||||
store.add_texts(texts, metadatas)
|
||||
|
||||
# Should have been called twice: batch of 100, then batch of 50
|
||||
assert mock_collection.insert_many.call_count == 2
|
||||
|
||||
def test_add_texts_default_metadata(self):
|
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_collection.insert_many.return_value = Mock(inserted_ids=["id1"])

        store.add_texts(["text1"])
        mock_collection.insert_many.assert_called_once()

    def test_add_texts_empty(self):
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = []
        mock_collection.insert_many.return_value = Mock(inserted_ids=[])

        result = store.add_texts([], [])
        # _insert_texts returns [] for empty input
        assert result == []


@pytest.mark.unit
class TestMongoDBVectorStoreInsertTexts:
    def test_insert_texts_empty_returns_empty(self):
        store, _, _ = _make_mongodb_store()
        result = store._insert_texts([], [])
        assert result == []

    def test_insert_texts_builds_correct_documents(self):
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_collection.insert_many.return_value = Mock(inserted_ids=["id1"])

        store._insert_texts(["hello"], [{"source": "test"}])

        inserted_docs = mock_collection.insert_many.call_args[0][0]
        assert len(inserted_docs) == 1
        assert inserted_docs[0]["text"] == "hello"
        assert inserted_docs[0]["embedding"] == [0.1, 0.2]
        assert inserted_docs[0]["source"] == "test"


@pytest.mark.unit
class TestMongoDBVectorStoreDeleteIndex:
    def test_delete_index_calls_delete_many(self):
        store, mock_collection, _ = _make_mongodb_store(source_id="src1")

        store.delete_index()

        mock_collection.delete_many.assert_called_once_with({"source_id": "src1"})


@pytest.mark.unit
class TestMongoDBVectorStoreGetChunks:
    def test_get_chunks(self):
        store, mock_collection, _ = _make_mongodb_store()

        docs = [
            {
                "_id": "id1",
                "text": "chunk1",
                "embedding": [0.1],
                "source_id": "src",
                "extra": "val",
            },
            {
                "_id": "id2",
                "text": "chunk2",
                "embedding": [0.2],
                "source_id": "src",
            },
        ]
        mock_collection.find.return_value = iter(docs)

        chunks = store.get_chunks()

        assert len(chunks) == 2
        assert chunks[0]["doc_id"] == "id1"
        assert chunks[0]["text"] == "chunk1"
        assert chunks[0]["metadata"] == {"extra": "val"}
        assert "embedding" not in chunks[0]["metadata"]
        assert "source_id" not in chunks[0]["metadata"]

    def test_get_chunks_skips_empty_text(self):
        store, mock_collection, _ = _make_mongodb_store()

        docs = [
            {"_id": "id1", "text": None, "embedding": [0.1], "source_id": "src"},
        ]
        mock_collection.find.return_value = iter(docs)

        chunks = store.get_chunks()
        assert len(chunks) == 0

    def test_get_chunks_returns_empty_on_error(self):
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.find.side_effect = Exception("connection error")

        assert store.get_chunks() == []


@pytest.mark.unit
class TestMongoDBVectorStoreAddChunk:
    def test_add_chunk(self):
        store, mock_collection, mock_emb = _make_mongodb_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_collection.insert_one.return_value = Mock(inserted_id="new_id")

        result = store.add_chunk("hello chunk", metadata={"key": "val"})

        assert result == "new_id"
        inserted = mock_collection.insert_one.call_args[0][0]
        assert inserted["text"] == "hello chunk"
        assert inserted["source_id"] == "src1"
        assert inserted["key"] == "val"

    def test_add_chunk_default_metadata(self):
        store, mock_collection, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_collection.insert_one.return_value = Mock(inserted_id="id")

        store.add_chunk("text")

        inserted = mock_collection.insert_one.call_args[0][0]
        assert "source_id" in inserted

    def test_add_chunk_raises_on_empty_embedding(self):
        store, _, mock_emb = _make_mongodb_store()
        mock_emb.embed_documents.return_value = []

        with pytest.raises(ValueError, match="Could not generate embedding"):
            store.add_chunk("text")


@pytest.mark.unit
class TestMongoDBVectorStoreDeleteChunk:
    def test_delete_chunk_success(self):
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.delete_one.return_value = Mock(deleted_count=1)

        with patch("application.vectorstore.mongodb.ObjectId", create=True):
            # We need to mock bson.objectid.ObjectId
            with patch.dict(
                "sys.modules", {"bson": MagicMock(), "bson.objectid": MagicMock()}
            ):
                from unittest.mock import MagicMock as MM

                mock_oid = MM()
                with patch("bson.objectid.ObjectId", return_value=mock_oid):
                    result = store.delete_chunk("507f1f77bcf86cd799439011")

        assert result is True

    def test_delete_chunk_returns_false_on_error(self):
        store, mock_collection, _ = _make_mongodb_store()
        mock_collection.delete_one.side_effect = Exception("fail")

        result = store.delete_chunk("bad_id")
        assert result is False
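For readers skimming the diff, the chunk-listing behaviour these MongoDB tests pin down can be summarised in a small sketch. This is only an illustrative reading of the assertions above, not the actual MongoDBVectorStore code; the helper name and the find() query are assumptions.

# Illustrative sketch only: the get_chunks() contract the tests above assert.
# Dropping "embedding"/"source_id" from metadata, skipping empty text, and
# returning [] on errors mirror the assertions; everything else is assumed.
def get_chunks_sketch(collection, source_id):
    chunks = []
    try:
        for doc in collection.find({"source_id": source_id}):
            text = doc.get("text")
            if not text:
                continue  # documents without text are skipped
            metadata = {
                k: v
                for k, v in doc.items()
                if k not in ("_id", "text", "embedding", "source_id")
            }
            chunks.append({"doc_id": str(doc["_id"]), "text": text, "metadata": metadata})
    except Exception:
        return []  # the tests expect an empty list on connection errors
    return chunks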
tests/vectorstore/test_pgvector.py (new file, 297 lines)
@@ -0,0 +1,297 @@
from unittest.mock import MagicMock, Mock, patch

import pytest


def _make_store(
    source_id="test-source",
    embeddings_key="key",
    connection_string="postgresql://user:pass@localhost/db",
):
    """Helper to create a PGVectorStore with all external deps mocked."""
    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.pgvector.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "psycopg2": MagicMock(),
            "psycopg2.extras": MagicMock(),
            "pgvector": MagicMock(),
            "pgvector.psycopg2": MagicMock(),
        },
    ):
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
        mock_emb.dimension = 768
        mock_get_emb.return_value = mock_emb
        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.PGVECTOR_CONNECTION_STRING = connection_string

        from application.vectorstore.pgvector import PGVectorStore

        # Patch _ensure_table_exists to avoid DB calls during init
        with patch.object(PGVectorStore, "_ensure_table_exists"):
            store = PGVectorStore(
                source_id=source_id,
                embeddings_key=embeddings_key,
                connection_string=connection_string,
            )
        # Provide a mock connection
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_conn.closed = False
        store._connection = mock_conn

        return store, mock_conn, mock_cursor, mock_emb


@pytest.mark.unit
class TestPGVectorStoreInit:
    def test_source_id_cleaned(self):
        store, _, _, _ = _make_store(source_id="application/indexes/abc123/")
        assert store._source_id == "abc123"

    def test_missing_connection_string_raises(self):
        with patch(
            "application.vectorstore.base.BaseVectorStore._get_embeddings"
        ) as mock_get_emb, patch(
            "application.vectorstore.pgvector.settings"
        ) as mock_settings, patch.dict(
            "sys.modules",
            {
                "psycopg2": MagicMock(),
                "psycopg2.extras": MagicMock(),
                "pgvector": MagicMock(),
                "pgvector.psycopg2": MagicMock(),
            },
        ):
            mock_get_emb.return_value = Mock(dimension=768)
            mock_settings.EMBEDDINGS_NAME = "test_model"
            mock_settings.PGVECTOR_CONNECTION_STRING = None

            from application.vectorstore.pgvector import PGVectorStore

            with pytest.raises(ValueError, match="connection string is required"):
                PGVectorStore(
                    source_id="test", embeddings_key="key", connection_string=None
                )


@pytest.mark.unit
class TestPGVectorStoreSearch:
    def test_search_returns_documents(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_cursor.fetchall.return_value = [
            ("hello world", {"source": "test.txt"}, 0.1),
            ("foo bar", {"source": "test2.txt"}, 0.2),
        ]

        results = store.search("query", k=2)

        mock_emb.embed_query.assert_called_once_with("query")
        assert len(results) == 2
        assert results[0].page_content == "hello world"
        assert results[0].metadata == {"source": "test.txt"}

    def test_search_returns_empty_on_error(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("connection lost")

        results = store.search("query")
        assert results == []

    def test_search_handles_null_metadata(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.fetchall.return_value = [("text", None, 0.5)]

        results = store.search("query")
        assert len(results) == 1
        assert results[0].metadata == {}


@pytest.mark.unit
class TestPGVectorStoreAddTexts:
    def test_add_texts_inserts_and_returns_ids(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_cursor.fetchone.side_effect = [(1,), (2,)]

        ids = store.add_texts(["text1", "text2"], [{"a": 1}, {"b": 2}])

        assert ids == ["1", "2"]
        assert mock_cursor.execute.call_count == 2
        mock_conn.commit.assert_called_once()

    def test_add_texts_empty_returns_empty(self):
        store, _, _, _ = _make_store()
        assert store.add_texts([]) == []

    def test_add_texts_default_metadatas(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (1,)

        ids = store.add_texts(["text1"])
        assert ids == ["1"]

    def test_add_texts_rolls_back_on_error(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = [[0.1]]
        mock_cursor.execute.side_effect = Exception("insert failed")

        with pytest.raises(Exception, match="insert failed"):
            store.add_texts(["text1"])

        mock_conn.rollback.assert_called_once()


@pytest.mark.unit
class TestPGVectorStoreDeleteIndex:
    def test_delete_index_deletes_by_source_id(self):
        store, mock_conn, mock_cursor, _ = _make_store(source_id="src123")

        store.delete_index()

        mock_cursor.execute.assert_called_once()
        sql = mock_cursor.execute.call_args[0][0]
        assert "DELETE FROM" in sql
        assert mock_cursor.execute.call_args[0][1] == ("src123",)
        mock_conn.commit.assert_called_once()

    def test_delete_index_rolls_back_on_error(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("fail")

        with pytest.raises(Exception):
            store.delete_index()

        mock_conn.rollback.assert_called_once()


@pytest.mark.unit
class TestPGVectorStoreSaveLocal:
    def test_save_local_is_noop(self):
        store, _, _, _ = _make_store()
        assert store.save_local() is None


@pytest.mark.unit
class TestPGVectorStoreGetChunks:
    def test_get_chunks(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.fetchall.return_value = [
            (1, "text1", {"key": "val"}),
            (2, "text2", None),
        ]

        chunks = store.get_chunks()
        assert len(chunks) == 2
        assert chunks[0] == {"doc_id": "1", "text": "text1", "metadata": {"key": "val"}}
        assert chunks[1] == {"doc_id": "2", "text": "text2", "metadata": {}}

    def test_get_chunks_returns_empty_on_error(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("fail")

        assert store.get_chunks() == []


@pytest.mark.unit
class TestPGVectorStoreAddChunk:
    def test_add_chunk(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (42,)

        chunk_id = store.add_chunk("hello", metadata={"key": "val"})

        assert chunk_id == "42"
        mock_conn.commit.assert_called_once()

    def test_add_chunk_raises_on_empty_embedding(self):
        store, _, _, mock_emb = _make_store()
        mock_emb.embed_documents.return_value = []

        with pytest.raises(ValueError, match="Could not generate embedding"):
            store.add_chunk("text")

    def test_add_chunk_includes_source_id_in_metadata(self):
        store, mock_conn, mock_cursor, mock_emb = _make_store(source_id="src1")
        mock_emb.embed_documents.return_value = [[0.1, 0.2]]
        mock_cursor.fetchone.return_value = (1,)

        store.add_chunk("hello", metadata={"key": "val"})

        # Verify source_id is passed as a parameter to the INSERT
        insert_call = mock_cursor.execute.call_args
        params = insert_call[0][1]
        # source_id is the 4th param in the insert
        assert params[3] == "src1"


@pytest.mark.unit
class TestPGVectorStoreDeleteChunk:
    def test_delete_chunk_success(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        mock_cursor.rowcount = 1

        result = store.delete_chunk("42")
        assert result is True
        mock_conn.commit.assert_called_once()

    def test_delete_chunk_not_found(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        mock_cursor.rowcount = 0

        result = store.delete_chunk("999")
        assert result is False

    def test_delete_chunk_returns_false_on_error(self):
        store, _, mock_cursor, _ = _make_store()
        mock_cursor.execute.side_effect = Exception("fail")

        result = store.delete_chunk("42")
        assert result is False


@pytest.mark.unit
class TestPGVectorStoreConnection:
    def test_get_connection_creates_new_when_closed(self):
        store, mock_conn, _, _ = _make_store()
        mock_conn.closed = True

        mock_psycopg2 = MagicMock()
        new_conn = MagicMock()
        mock_psycopg2.connect.return_value = new_conn
        store._psycopg2 = mock_psycopg2

        conn = store._get_connection()
        mock_psycopg2.connect.assert_called_once()
        assert conn is new_conn

    def test_get_connection_reuses_open(self):
        store, mock_conn, _, _ = _make_store()
        mock_conn.closed = False

        conn = store._get_connection()
        assert conn is mock_conn

    def test_ensure_table_exists(self):
        store, mock_conn, mock_cursor, _ = _make_store()
        # Call _ensure_table_exists directly
        store._ensure_table_exists()

        # Should execute CREATE EXTENSION, CREATE TABLE, and CREATE INDEX statements
        assert mock_cursor.execute.call_count >= 3
        mock_conn.commit.assert_called()

    def test_del_closes_connection(self):
        store, mock_conn, _, _ = _make_store()
        mock_conn.closed = False

        store.__del__()
        mock_conn.close.assert_called_once()
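The connection tests above imply a lazy-reconnect pattern: reuse the cached connection while it is open, and open a fresh one when it is missing or closed. A minimal sketch of that pattern, assuming only the attributes the tests themselves touch (_psycopg2, _connection, closed), not the real PGVectorStore code:

# Illustrative sketch only: the lazy-reconnect contract asserted by
# TestPGVectorStoreConnection. Names beyond _psycopg2/_connection/closed
# are assumptions for the example.
class ConnectionSketch:
    def __init__(self, psycopg2_module, connection_string):
        self._psycopg2 = psycopg2_module
        self._connection_string = connection_string
        self._connection = None

    def _get_connection(self):
        # Reuse an open connection; reconnect if it is missing or closed.
        if self._connection is None or self._connection.closed:
            self._connection = self._psycopg2.connect(self._connection_string)
        return self._connection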
tests/vectorstore/test_qdrant.py (new file, 230 lines)
@@ -0,0 +1,230 @@
from unittest.mock import MagicMock, Mock, patch

import pytest


def _make_qdrant_store(source_id="test-source"):
    """Helper to create a QdrantStore with all external deps mocked."""
    mock_models = MagicMock()
    mock_qdrant_langchain = MagicMock()

    with patch(
        "application.vectorstore.base.BaseVectorStore._get_embeddings"
    ) as mock_get_emb, patch(
        "application.vectorstore.qdrant.settings"
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "qdrant_client": MagicMock(),
            "qdrant_client.models": mock_models,
            "langchain_community": MagicMock(),
            "langchain_community.vectorstores": MagicMock(),
            "langchain_community.vectorstores.qdrant": mock_qdrant_langchain,
        },
    ):
        mock_emb = Mock()
        mock_emb.embed_query = Mock(return_value=[0.1, 0.2, 0.3])
        mock_emb.embed_documents = Mock(return_value=[[0.1, 0.2, 0.3]])
        mock_emb.client = [None, Mock(word_embedding_dimension=768)]
        mock_get_emb.return_value = mock_emb

        mock_settings.EMBEDDINGS_NAME = "test_model"
        mock_settings.QDRANT_COLLECTION_NAME = "test_collection"
        mock_settings.QDRANT_LOCATION = ":memory:"
        mock_settings.QDRANT_URL = None
        mock_settings.QDRANT_PORT = 6333
        mock_settings.QDRANT_GRPC_PORT = 6334
        mock_settings.QDRANT_HTTPS = False
        mock_settings.QDRANT_PREFER_GRPC = False
        mock_settings.QDRANT_API_KEY = None
        mock_settings.QDRANT_PREFIX = None
        mock_settings.QDRANT_TIMEOUT = None
        mock_settings.QDRANT_PATH = None
        mock_settings.QDRANT_DISTANCE_FUNC = "Cosine"

        mock_docsearch = MagicMock()
        mock_collections = MagicMock()
        mock_collections.collections = [MagicMock(name="test_collection")]
        mock_docsearch.client.get_collections.return_value = mock_collections
        mock_qdrant_langchain.Qdrant.construct_instance.return_value = mock_docsearch

        from application.vectorstore.qdrant import QdrantStore

        store = QdrantStore(source_id=source_id, embeddings_key="key")

        return store, mock_docsearch, mock_settings


@pytest.mark.unit
class TestQdrantStoreInit:
    def test_source_id_cleaned(self):
        store, _, _ = _make_qdrant_store(source_id="application/indexes/abc123/")
        assert store._source_id == "abc123"

    def test_filter_constructed(self):
        store, _, _ = _make_qdrant_store(source_id="src1")
        assert store._filter is not None


@pytest.mark.unit
class TestQdrantStoreSearch:
    def test_search_delegates(self):
        store, mock_ds, _ = _make_qdrant_store()
        mock_ds.similarity_search.return_value = ["result1"]

        results = store.search("query", k=5)

        mock_ds.similarity_search.assert_called_once()
        assert results == ["result1"]


@pytest.mark.unit
class TestQdrantStoreAddTexts:
    def test_add_texts_delegates(self):
        store, mock_ds, _ = _make_qdrant_store()
        mock_ds.add_texts.return_value = ["id1"]

        result = store.add_texts(["text1"], metadatas=[{"a": 1}])
        mock_ds.add_texts.assert_called_once_with(["text1"], metadatas=[{"a": 1}])
        assert result == ["id1"]


@pytest.mark.unit
class TestQdrantStoreSaveLocal:
    def test_save_local_is_noop(self):
        store, _, _ = _make_qdrant_store()
        assert store.save_local() is None


@pytest.mark.unit
class TestQdrantStoreDeleteIndex:
    def test_delete_index(self):
        store, mock_ds, _ = _make_qdrant_store()

        with patch("application.vectorstore.qdrant.settings") as ms:
            ms.QDRANT_COLLECTION_NAME = "test_collection"
            store.delete_index()

        mock_ds.client.delete.assert_called_once_with(
            collection_name="test_collection",
            points_selector=store._filter,
        )


@pytest.mark.unit
class TestQdrantStoreGetChunks:
    def test_get_chunks(self):
        store, mock_ds, _ = _make_qdrant_store()

        record1 = MagicMock()
        record1.id = "id1"
        record1.payload = {
            "page_content": "text1",
            "metadata": {"source": "test"},
        }
        record2 = MagicMock()
        record2.id = "id2"
        record2.payload = {
            "page_content": "text2",
            "metadata": {"source": "test2"},
        }

        # A single scroll call returns both records and no further offset
        mock_ds.client.scroll.side_effect = [
            ([record1, record2], None),
        ]

        chunks = store.get_chunks()

        assert len(chunks) == 2
        assert chunks[0] == {
            "doc_id": "id1",
            "text": "text1",
            "metadata": {"source": "test"},
        }

    def test_get_chunks_pagination(self):
        store, mock_ds, _ = _make_qdrant_store()

        record1 = MagicMock()
        record1.id = "id1"
        record1.payload = {"page_content": "text1", "metadata": {}}

        record2 = MagicMock()
        record2.id = "id2"
        record2.payload = {"page_content": "text2", "metadata": {}}

        mock_ds.client.scroll.side_effect = [
            ([record1], "offset_token"),
            ([record2], None),
        ]

        chunks = store.get_chunks()
        assert len(chunks) == 2
        assert mock_ds.client.scroll.call_count == 2

    def test_get_chunks_returns_empty_on_error(self):
        store, mock_ds, _ = _make_qdrant_store()
        mock_ds.client.scroll.side_effect = Exception("fail")

        assert store.get_chunks() == []


@pytest.mark.unit
class TestQdrantStoreAddChunk:
    def test_add_chunk(self):
        store, mock_ds, _ = _make_qdrant_store(source_id="src1")
        mock_ds.add_documents.return_value = ["new-id"]

        result = store.add_chunk("hello", metadata={"key": "val"})

        assert result == "new-id"
        mock_ds.add_documents.assert_called_once()
        doc = mock_ds.add_documents.call_args[0][0][0]
        assert doc.page_content == "hello"
        assert doc.metadata["source_id"] == "src1"
        assert doc.metadata["key"] == "val"

    def test_add_chunk_default_metadata(self):
        store, mock_ds, _ = _make_qdrant_store(source_id="src1")
        mock_ds.add_documents.return_value = ["id"]

        store.add_chunk("text")

        doc = mock_ds.add_documents.call_args[0][0][0]
        assert doc.metadata["source_id"] == "src1"

    def test_add_chunk_fallback_id(self):
        store, mock_ds, _ = _make_qdrant_store()
        mock_ds.add_documents.return_value = []

        result = store.add_chunk("text")
        # Should return the uuid that was generated
        assert result is not None
        assert isinstance(result, str)


@pytest.mark.unit
class TestQdrantStoreDeleteChunk:
    def test_delete_chunk_success(self):
        store, mock_ds, _ = _make_qdrant_store()

        with patch("application.vectorstore.qdrant.settings") as ms:
            ms.QDRANT_COLLECTION_NAME = "test_collection"
            result = store.delete_chunk("chunk-id")

        mock_ds.client.delete.assert_called_once_with(
            collection_name="test_collection",
            points_selector=["chunk-id"],
        )
        assert result is True

    def test_delete_chunk_returns_false_on_error(self):
        store, mock_ds, _ = _make_qdrant_store()

        with patch("application.vectorstore.qdrant.settings") as ms:
            ms.QDRANT_COLLECTION_NAME = "test_collection"
            mock_ds.client.delete.side_effect = Exception("fail")
            result = store.delete_chunk("bad-id")

        assert result is False
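The get_chunks pagination tests describe a scroll loop that keeps fetching pages until Qdrant returns no next offset. A minimal sketch under that assumption follows; the exact keyword arguments passed to client.scroll are not visible in the tests and are assumed here, so this is illustrative rather than the real QdrantStore code:

# Illustrative sketch only: the scroll/pagination loop the tests above assume.
def get_chunks_sketch(client, collection_name, scroll_filter):
    chunks, offset = [], None
    while True:
        records, offset = client.scroll(
            collection_name=collection_name,
            scroll_filter=scroll_filter,
            offset=offset,
        )
        for record in records:
            payload = record.payload or {}
            chunks.append(
                {
                    "doc_id": str(record.id),
                    "text": payload.get("page_content", ""),
                    "metadata": payload.get("metadata", {}),
                }
            )
        if offset is None:  # no more pages to scroll
            break
    return chunks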
tests/vectorstore/test_vector_creator.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from unittest.mock import patch

import pytest

from application.vectorstore.vector_creator import VectorCreator


@pytest.mark.unit
class TestVectorCreator:
    def test_registered_vectorstores(self):
        assert "faiss" in VectorCreator.vectorstores
        assert "elasticsearch" in VectorCreator.vectorstores
        assert "mongodb" in VectorCreator.vectorstores
        assert "qdrant" in VectorCreator.vectorstores
        assert "milvus" in VectorCreator.vectorstores
        assert "pgvector" in VectorCreator.vectorstores

    def test_create_vectorstore_invalid_type(self):
        with pytest.raises(ValueError, match="No vectorstore class found for type"):
            VectorCreator.create_vectorstore("nonexistent")

    def test_create_vectorstore_case_insensitive(self):
        with patch.object(
            VectorCreator.vectorstores["faiss"], "__init__", return_value=None
        ) as mock_init:
            mock_init.return_value = None
            VectorCreator.create_vectorstore(
                "FAISS", source_id="test", embeddings_key="key"
            )
            mock_init.assert_called_once_with(source_id="test", embeddings_key="key")

    def test_create_vectorstore_passes_args(self):
        with patch.object(
            VectorCreator.vectorstores["mongodb"], "__init__", return_value=None
        ) as mock_init:
            VectorCreator.create_vectorstore(
                "mongodb", source_id="src1", embeddings_key="ek", database="mydb"
            )
            mock_init.assert_called_once_with(
                source_id="src1", embeddings_key="ek", database="mydb"
            )
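These factory tests imply a case-insensitive registry lookup that raises on unknown types and forwards keyword arguments unchanged. A minimal sketch of that contract, not the actual VectorCreator implementation (the registry contents and parameter name are assumptions):

# Illustrative sketch only: the factory behaviour asserted above.
class VectorCreatorSketch:
    vectorstores = {}  # e.g. {"faiss": FaissStore, "mongodb": MongoDBVectorStore, ...}

    @classmethod
    def create_vectorstore(cls, vectorstore_type, *args, **kwargs):
        # Lookup is case-insensitive; unknown types raise ValueError.
        vectorstore_class = cls.vectorstores.get(vectorstore_type.lower())
        if not vectorstore_class:
            raise ValueError(f"No vectorstore class found for type {vectorstore_type}")
        return vectorstore_class(*args, **kwargs)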