mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-10 15:53:16 +00:00
Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b91ee8d008 | ||
|
|
6586f08584 | ||
|
|
f49e887fe6 | ||
|
|
344066fd11 | ||
|
|
bcb8092488 | ||
|
|
1efade8bdb | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
f957b8948c | ||
|
|
cd0b14dd2d | ||
|
|
894703a484 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
d08a2453f7 | ||
|
|
3f53eea1e0 | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
f3d1cc8dc1 | ||
|
|
e889efeda7 | ||
|
|
0a3a95521c | ||
|
|
4ebaf6f7a9 | ||
|
|
59ac1a3f60 | ||
|
|
3af24597ee | ||
|
|
e0be6c5786 | ||
|
|
88b101ebf5 | ||
|
|
923a5d6efb | ||
|
|
734b7e42ad | ||
|
|
d9a65745df | ||
|
|
97ab623d42 | ||
|
|
14aa6cc7e8 | ||
|
|
10e77fcf24 | ||
|
|
bbb21d7c2b | ||
|
|
3bc489254b | ||
|
|
4c07ea41c3 | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
c46099c5d7 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
407020de0c | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
188de4ff2a | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 | ||
|
|
8fb1f114bc | ||
|
|
6a4cff6699 | ||
|
|
d5310a3300 | ||
|
|
de0ea3ac49 | ||
|
|
12116b018d | ||
|
|
c3ed3b40ea | ||
|
|
b80c2aabb0 | ||
|
|
f0a3eb574e | ||
|
|
bb15855443 | ||
|
|
14ce6aebd1 | ||
|
|
2fe83723f2 | ||
|
|
e73b9e10a6 | ||
|
|
9c04c18c04 | ||
|
|
81ae09d0ec | ||
|
|
01cf221167 | ||
|
|
cd8c86c6fb | ||
|
|
52d5fd1a67 | ||
|
|
7ecc7aabda | ||
|
|
79033aee34 | ||
|
|
b6ad243e9e | ||
|
|
92ca5078c1 | ||
|
|
aca8523060 | ||
|
|
1ea0cff3a4 | ||
|
|
75793a18f0 | ||
|
|
58866b21cb | ||
|
|
660aabc437 | ||
|
|
db80b20bc2 | ||
|
|
566120e8d5 | ||
|
|
f3f0f1717d | ||
|
|
05b499fb83 | ||
|
|
7621ec609e | ||
|
|
9f511f0024 | ||
|
|
374faa2640 | ||
|
|
ba6aa5fbbe | ||
|
|
1c52a89535 | ||
|
|
e7cedbee6e | ||
|
|
15c3cc3a50 | ||
|
|
5ab3032335 | ||
|
|
1215c635a0 | ||
|
|
07d21463ca |
@@ -28,3 +28,4 @@ bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -32,6 +32,7 @@ GEMINI.md
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
.mcp/cache/
|
||||
|
||||
# macOS
|
||||
|
||||
@@ -11,7 +11,7 @@ The Plus release stays in lockstep with the mainline features.
|
||||
## Differences from the Mainline
|
||||
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
## 与主线版本版本差异
|
||||
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||
|
||||
## 贡献
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
@@ -50,6 +53,9 @@ usage-statistics-enabled: false
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -67,6 +73,7 @@ ws-auth: false
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -81,6 +88,7 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -95,6 +103,7 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -121,6 +130,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -135,6 +145,7 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
@@ -151,8 +162,8 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
@@ -1,392 +0,0 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
12
go.mod
12
go.mod
@@ -18,10 +18,10 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -69,9 +69,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -160,23 +160,23 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -3,6 +3,9 @@ package management
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -23,9 +26,11 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -289,6 +294,54 @@ func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetAuthFileModels returns the models supported by a specific auth file
|
||||
func (h *Handler) GetAuthFileModels(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
c.JSON(400, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Try to find auth ID via authManager
|
||||
var authID string
|
||||
if h.authManager != nil {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name || auth.ID == name {
|
||||
authID = auth.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if authID == "" {
|
||||
authID = name // fallback to filename as ID
|
||||
}
|
||||
|
||||
// Get models from registry
|
||||
reg := registry.GetGlobalRegistry()
|
||||
models := reg.GetModelsForClient(authID)
|
||||
|
||||
result := make([]gin.H, 0, len(models))
|
||||
for _, m := range models {
|
||||
entry := gin.H{
|
||||
"id": m.ID,
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
entry["display_name"] = m.DisplayName
|
||||
}
|
||||
if m.Type != "" {
|
||||
entry["type"] = m.Type
|
||||
}
|
||||
if m.OwnedBy != "" {
|
||||
entry["owned_by"] = m.OwnedBy
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"models": result})
|
||||
}
|
||||
|
||||
// List auth files from disk when the auth manager is unavailable.
|
||||
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||
@@ -1745,6 +1798,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflowauth.ExtractBXAuth(cookieValue)
|
||||
if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
existingFileName := filepath.Base(existingFile)
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
|
||||
return
|
||||
}
|
||||
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
||||
if errAuth != nil {
|
||||
@@ -1767,11 +1831,12 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
tokenStorage.Email = email
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Provider: "iflow",
|
||||
FileName: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": email,
|
||||
@@ -2142,9 +2207,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := getOAuthStatus(state); ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
if statusValue, ok := getOAuthStatus(state); ok {
|
||||
if statusValue != "" {
|
||||
// Check for device_code prefix (Kiro AWS Builder ID flow)
|
||||
// Format: "device_code|verification_url|user_code"
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
if strings.HasPrefix(statusValue, "device_code|") {
|
||||
parts := strings.SplitN(statusValue, "|", 3)
|
||||
if len(parts) == 3 {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "device_code",
|
||||
"verification_url": parts[1],
|
||||
"user_code": parts[2],
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
// Check for auth_url prefix (Kiro social auth flow)
|
||||
// Format: "auth_url|url"
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
if strings.HasPrefix(statusValue, "auth_url|") {
|
||||
authURL := strings.TrimPrefix(statusValue, "auth_url|")
|
||||
c.JSON(200, gin.H{
|
||||
"status": "auth_url",
|
||||
"url": authURL,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Otherwise treat as error
|
||||
c.JSON(200, gin.H{"status": "error", "error": statusValue})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
@@ -2154,3 +2245,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
}
|
||||
|
||||
const kiroCallbackPort = 9876
|
||||
|
||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get the login method from query parameter (default: aws for device code flow)
|
||||
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
||||
if method == "" {
|
||||
method = "aws"
|
||||
}
|
||||
|
||||
fmt.Println("Initializing Kiro authentication...")
|
||||
|
||||
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
||||
|
||||
switch method {
|
||||
case "aws", "builder-id":
|
||||
// AWS Builder ID uses device code flow (no callback needed)
|
||||
go func() {
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
||||
|
||||
// Step 1: Register client
|
||||
fmt.Println("Registering client...")
|
||||
regResp, err := ssoClient.RegisterClient(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to register client: %v", err)
|
||||
setOAuthStatus(state, "Failed to register client")
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Start device authorization
|
||||
fmt.Println("Starting device authorization...")
|
||||
authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to start device auth: %v", err)
|
||||
setOAuthStatus(state, "Failed to start device authorization")
|
||||
return
|
||||
}
|
||||
|
||||
// Store the verification URL for the frontend to display
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
||||
|
||||
// Step 3: Poll for token
|
||||
fmt.Println("Waiting for authorization...")
|
||||
interval := 5 * time.Second
|
||||
if authResp.Interval > 0 {
|
||||
interval = time.Duration(authResp.Interval) * time.Second
|
||||
}
|
||||
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
setOAuthStatus(state, "Authorization cancelled")
|
||||
return
|
||||
case <-time.After(interval):
|
||||
tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "authorization_pending") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(errStr, "slow_down") {
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
}
|
||||
log.Errorf("Token creation failed: %v", err)
|
||||
setOAuthStatus(state, "Token creation failed")
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Save the token
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||
if idPart == "" {
|
||||
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenResp.AccessToken,
|
||||
"refresh_token": tokenResp.RefreshToken,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"auth_method": "builder-id",
|
||||
"provider": "AWS",
|
||||
"client_id": regResp.ClientID,
|
||||
"client_secret": regResp.ClientSecret,
|
||||
"email": email,
|
||||
"last_refresh": now.Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if email != "" {
|
||||
fmt.Printf("Authenticated as: %s\n", email)
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOAuthStatus(state, "Authorization timed out")
|
||||
}()
|
||||
|
||||
// Return immediately with the state for polling
|
||||
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
||||
|
||||
case "google", "github":
|
||||
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
||||
provider := "Google"
|
||||
if method == "github" {
|
||||
provider = "Github"
|
||||
}
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(kiroCallbackPort)
|
||||
}
|
||||
|
||||
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generateKiroPKCE()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate PKCE: %v", err)
|
||||
setOAuthStatus(state, "Failed to generate PKCE")
|
||||
return
|
||||
}
|
||||
|
||||
// Build login URL
|
||||
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
||||
provider,
|
||||
url.QueryEscape(kiroauth.KiroRedirectURI),
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
|
||||
// Store auth URL for frontend
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
setOAuthStatus(state, "auth_url|"+authURL)
|
||||
|
||||
// Wait for callback file
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
setOAuthStatus(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
var m map[string]string
|
||||
_ = json.Unmarshal(data, &m)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
setOAuthStatus(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
log.Errorf("State mismatch")
|
||||
setOAuthStatus(state, "State mismatch")
|
||||
return
|
||||
}
|
||||
code := m["code"]
|
||||
if code == "" {
|
||||
log.Error("No authorization code received")
|
||||
setOAuthStatus(state, "No authorization code received")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenReq := &kiroauth.CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: kiroauth.KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
||||
if errToken != nil {
|
||||
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
||||
setOAuthStatus(state, "Failed to exchange code for tokens")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the token
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||
if idPart == "" {
|
||||
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenResp.AccessToken,
|
||||
"refresh_token": tokenResp.RefreshToken,
|
||||
"profile_arn": tokenResp.ProfileArn,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"auth_method": "social",
|
||||
"provider": provider,
|
||||
"email": email,
|
||||
"last_refresh": now.Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if email != "" {
|
||||
fmt.Printf("Authenticated as: %s\n", email)
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
setOAuthStatus(state, "")
|
||||
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"})
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
||||
}
|
||||
}
|
||||
|
||||
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
|
||||
func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
@@ -71,22 +71,64 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
||||
if w.logger != nil && w.logger.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
if !w.logOnErrorOnly {
|
||||
return false
|
||||
}
|
||||
status := w.statusCode
|
||||
if status == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
||||
status = statusWriter.Status()
|
||||
} else {
|
||||
status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return status >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
||||
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
||||
// bypass Write() and would be missing from request logs.
|
||||
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.WriteString(data)
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.WriteString(data)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
@@ -160,12 +202,16 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
||||
// treat it as non-streaming instead of inferring from the request payload.
|
||||
if strings.TrimSpace(contentType) != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -221,7 +267,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.streamWriter != nil {
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -233,24 +279,19 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
if w.streamWriter != nil {
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -335,26 +376,3 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseErrors,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
|
||||
@@ -137,7 +137,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
@@ -187,9 +188,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@@ -43,6 +44,11 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -52,7 +58,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
@@ -64,7 +70,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||
if resp.StatusCode >= 500 {
|
||||
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
} else if resp.StatusCode >= 400 {
|
||||
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
}
|
||||
|
||||
// Only process successful responses for gzip decompression
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
@@ -148,15 +162,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
// Check if this is a client-side cancellation (normal behavior)
|
||||
// Classify the error type for better diagnostics
|
||||
var errType string
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
errType = "timeout"
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
errType = "canceled"
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
errType = "dial_timeout"
|
||||
} else if _, ok := err.(net.Error); ok {
|
||||
errType = "network_error"
|
||||
} else {
|
||||
errType = "connection_error"
|
||||
}
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path)
|
||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
|
||||
@@ -29,17 +29,79 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
|
||||
}
|
||||
}
|
||||
|
||||
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||
|
||||
func looksLikeSSEChunk(data []byte) bool {
|
||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||
// Heuristics are intentionally simple and cheap.
|
||||
return bytes.Contains(data, []byte("data:")) ||
|
||||
bytes.Contains(data, []byte("event:")) ||
|
||||
bytes.Contains(data, []byte("message_start")) ||
|
||||
bytes.Contains(data, []byte("message_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_start")) ||
|
||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
||||
bytes.Contains(data, []byte("\n\n"))
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||
if rw.isStreaming {
|
||||
return nil
|
||||
}
|
||||
rw.isStreaming = true
|
||||
|
||||
// Flush any previously buffered data to avoid reordering or data loss.
|
||||
if rw.body != nil && rw.body.Len() > 0 {
|
||||
buf := rw.body.Bytes()
|
||||
// Copy before Reset() to keep bytes stable.
|
||||
toFlush := make([]byte, len(buf))
|
||||
copy(toFlush, buf)
|
||||
rw.body.Reset()
|
||||
|
||||
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement
|
||||
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
// Detect streaming on first write
|
||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||
// Detect streaming on first write (header-based)
|
||||
if !rw.isStreaming && rw.body.Len() == 0 {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if !rw.isStreaming {
|
||||
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
|
||||
if looksLikeSSEChunk(data) {
|
||||
if err := rw.enableStreaming("sse heuristic"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
|
||||
// Safety cap: avoid unbounded buffering on large responses.
|
||||
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
|
||||
if err := rw.enableStreaming("buffer limit"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
@@ -98,7 +98,8 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
@@ -107,8 +108,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
@@ -154,6 +156,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if auth != nil {
|
||||
rootMiddleware = append(rootMiddleware, auth)
|
||||
}
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
@@ -262,7 +267,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -230,13 +230,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -334,8 +330,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -421,6 +417,18 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||
}
|
||||
|
||||
@@ -574,6 +582,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
@@ -586,6 +595,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -614,7 +624,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -924,17 +934,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.SetOpenAICompatProviders(providerNames)
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -36,3 +39,61 @@ func SanitizeIFlowFileName(raw string) string {
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// ExtractBXAuth extracts the BXAuth value from a cookie string.
|
||||
func ExtractBXAuth(cookie string) string {
|
||||
parts := strings.Split(cookie, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "BXAuth=") {
|
||||
return strings.TrimPrefix(part, "BXAuth=")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CheckDuplicateBXAuth checks if the given BXAuth value already exists in any iflow auth file.
|
||||
// Returns the path of the existing file if found, empty string otherwise.
|
||||
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
|
||||
if bxAuth == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read auth dir failed: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var tokenData struct {
|
||||
Cookie string `json:"cookie"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingBXAuth := ExtractBXAuth(tokenData.Cookie)
|
||||
if existingBXAuth != "" && existingBXAuth == bxAuth {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -506,11 +506,18 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only save the BXAuth field from the cookie
|
||||
bxAuth := ExtractBXAuth(data.Cookie)
|
||||
cookieToSave := ""
|
||||
if bxAuth != "" {
|
||||
cookieToSave = "BXAuth=" + bxAuth + ";"
|
||||
}
|
||||
|
||||
return &IFlowTokenStorage{
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
Expire: data.Expire,
|
||||
Cookie: data.Cookie,
|
||||
Cookie: cookieToSave,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Type: "iflow",
|
||||
}
|
||||
|
||||
@@ -471,7 +471,7 @@ foreach ($port in $ports) {
|
||||
|
||||
// Create batch wrapper
|
||||
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
|
||||
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath)
|
||||
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
|
||||
|
||||
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write batch wrapper: %w", err)
|
||||
|
||||
@@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s
|
||||
)
|
||||
}
|
||||
|
||||
// createToken exchanges the authorization code for tokens.
|
||||
func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||
// CreateToken exchanges the authorization code for tokens.
|
||||
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||
@@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
RedirectURI: KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := c.createToken(ctx, tokenReq)
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -37,6 +39,16 @@ func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflow.ExtractBXAuth(cookie)
|
||||
if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil {
|
||||
fmt.Printf("Failed to check duplicate: %v\n", err)
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile))
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate with cookie
|
||||
auth := iflow.NewIFlowAuth(cfg)
|
||||
ctx := context.Background()
|
||||
@@ -82,5 +94,5 @@ func promptForCookie(promptFn func(string) (string, error)) (string, error) {
|
||||
// getAuthFilePath returns the auth file path for the given provider and email
|
||||
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
||||
fileName := iflow.SanitizeIFlowFileName(email)
|
||||
return fmt.Sprintf("%s/%s-%s.json", cfg.AuthDir, provider, fileName)
|
||||
return fmt.Sprintf("%s/%s-%s-%d.json", cfg.AuthDir, provider, fileName, time.Now().Unix())
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
@@ -116,6 +118,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -151,7 +156,7 @@ type AmpCode struct {
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
|
||||
|
||||
// ModelMappings defines model name mappings for Amp CLI requests.
|
||||
@@ -194,6 +199,9 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -226,6 +234,9 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -246,6 +257,9 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
@@ -294,6 +308,9 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -368,7 +385,8 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.LoggingToFile = false
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
@@ -405,6 +423,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
_ = SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed)
|
||||
}
|
||||
|
||||
cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
|
||||
if cfg.RemoteManagement.PanelGitHubRepository == "" {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -456,6 +479,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -476,6 +500,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -494,6 +519,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -530,6 +556,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -543,6 +570,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
|
||||
@@ -13,6 +13,9 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -53,6 +56,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,10 +24,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
managementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -97,7 +98,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
|
||||
runOnce()
|
||||
@@ -181,7 +182,7 @@ func FilePath(configFilePath string) string {
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string) {
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
@@ -214,6 +215,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
@@ -225,7 +227,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client)
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
@@ -254,8 +256,44 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client) (*releaseAsset, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, managementReleaseURL, nil)
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(repo)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
|
||||
|
||||
if host == "api.github.com" {
|
||||
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
|
||||
parsed.Path = parsed.Path + "/releases/latest"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
if host == "github.com" {
|
||||
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
|
||||
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
|
||||
repoName := strings.TrimSuffix(parts[1], ".git")
|
||||
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
|
||||
if strings.TrimSpace(releaseURL) == "" {
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
}
|
||||
|
||||
@@ -630,6 +630,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,19 +652,20 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2", Created: 1764576000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -895,6 +903,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-5",
|
||||
@@ -906,6 +915,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4",
|
||||
@@ -917,6 +927,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-haiku-4-5",
|
||||
@@ -928,6 +939,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
@@ -940,6 +952,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-5-agentic",
|
||||
@@ -951,6 +964,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-agentic",
|
||||
@@ -962,6 +976,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-haiku-4-5-agentic",
|
||||
@@ -973,6 +988,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ type ModelRegistry struct {
|
||||
models map[string]*ModelRegistration
|
||||
// clientModels maps client ID to the models it provides
|
||||
clientModels map[string][]string
|
||||
// clientModelInfos maps client ID to a map of model ID -> ModelInfo
|
||||
// This preserves the original model info provided by each client
|
||||
clientModelInfos map[string]map[string]*ModelInfo
|
||||
// clientProviders maps client ID to its provider identifier
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
@@ -104,10 +107,11 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
@@ -144,6 +148,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// No models supplied; unregister existing client state if present.
|
||||
r.unregisterClientInternal(clientID)
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
@@ -152,7 +157,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
now := time.Now()
|
||||
|
||||
oldModels, hadExisting := r.clientModels[clientID]
|
||||
oldProvider, _ := r.clientProviders[clientID]
|
||||
oldProvider := r.clientProviders[clientID]
|
||||
providerChanged := oldProvider != provider
|
||||
if !hadExisting {
|
||||
// Pure addition path.
|
||||
@@ -161,6 +166,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
r.addModelRegistration(modelID, provider, model, now)
|
||||
}
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
// Store client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -287,6 +298,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if len(rawModelIDs) > 0 {
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
}
|
||||
// Update client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -436,6 +453,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
}
|
||||
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
if hasProvider {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
@@ -748,7 +766,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
}
|
||||
return result
|
||||
|
||||
case "claude":
|
||||
case "claude", "kiro", "antigravity":
|
||||
// Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
|
||||
result := map[string]any{
|
||||
"id": model.ID,
|
||||
"object": "model",
|
||||
@@ -763,6 +782,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
if model.DisplayName != "" {
|
||||
result["display_name"] = model.DisplayName
|
||||
}
|
||||
// Add thinking support for Claude Code client
|
||||
// Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
|
||||
// Also add "extended_thinking" for detailed budget info
|
||||
if model.Thinking != nil {
|
||||
result["thinking"] = true
|
||||
result["extended_thinking"] = map[string]any{
|
||||
"supported": true,
|
||||
"min": model.Thinking.Min,
|
||||
"max": model.Thinking.Max,
|
||||
"zero_allowed": model.Thinking.ZeroAllowed,
|
||||
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
case "gemini":
|
||||
@@ -871,3 +903,44 @@ func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, erro
|
||||
|
||||
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
// GetModelsForClient returns the models registered for a specific client.
|
||||
// Parameters:
|
||||
// - clientID: The client identifier (typically auth file name or auth ID)
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of models registered for this client, nil if client not found
|
||||
func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
modelIDs, exists := r.clientModels[clientID]
|
||||
if !exists || len(modelIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to use client-specific model infos first
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
if _, dup := seen[modelID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
@@ -384,8 +384,16 @@ func ensureColonSpacedJSON(payload []byte) []byte {
|
||||
|
||||
for i := 0; i < len(indented); i++ {
|
||||
ch := indented[i]
|
||||
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
|
||||
inString = !inString
|
||||
if ch == '"' {
|
||||
// A quote is escaped only when preceded by an odd number of consecutive backslashes.
|
||||
// For example: "\\\"" keeps the quote inside the string, but "\\\\" closes the string.
|
||||
backslashes := 0
|
||||
for j := i - 1; j >= 0 && indented[j] == '\\'; j-- {
|
||||
backslashes++
|
||||
}
|
||||
if backslashes%2 == 0 {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
|
||||
if !inString {
|
||||
|
||||
@@ -54,9 +54,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -152,9 +152,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -254,7 +254,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -784,20 +786,45 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
// Try to parse the retryDelay from the error response
|
||||
// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
|
||||
details := gjson.GetBytes(errorBody, "error.details")
|
||||
if !details.Exists() || !details.IsArray() {
|
||||
return nil, fmt.Errorf("no error.details found")
|
||||
if details.Exists() && details.IsArray() {
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try ErrorInfo.metadata.quotaResetDelay (e.g., "373.801628ms")
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||
quotaResetDelay := detail.Get("metadata.quotaResetDelay").String()
|
||||
if quotaResetDelay != "" {
|
||||
duration, err := time.ParseDuration(quotaResetDelay)
|
||||
if err == nil {
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
// Fallback: parse from error.message "Your quota will reset after Xs."
|
||||
message := gjson.GetBytes(errorBody, "error.message").String()
|
||||
if message != "" {
|
||||
re := regexp.MustCompile(`after\s+(\d+)s\.?`)
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -178,7 +178,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -290,7 +290,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
@@ -57,15 +57,16 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -148,15 +149,16 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -442,3 +444,21 @@ func ensureToolsArray(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -54,17 +54,19 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -148,17 +150,19 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
@@ -323,6 +327,27 @@ func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxy
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) allowCompatReasoningEffort(model string, auth *cliproxyauth.Auth) bool {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" || e == nil || e.cfg == nil {
|
||||
return false
|
||||
}
|
||||
compat := e.resolveCompatConfig(auth)
|
||||
if compat == nil || len(compat.Models) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range compat.Models {
|
||||
entry := compat.Models[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Alias), trimmed) {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), trimmed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
@@ -45,22 +45,38 @@ func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model str
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
|
||||
// overwriting caller-provided values to honor suffix/default metadata priority.
|
||||
func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string) []byte {
|
||||
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
||||
if len(metadata) == 0 {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
}
|
||||
if field == "" {
|
||||
return payload
|
||||
}
|
||||
baseModel := util.ResolveOriginalModel(model, metadata)
|
||||
if baseModel == "" {
|
||||
baseModel = model
|
||||
}
|
||||
if !util.ModelSupportsThinking(baseModel) && !allowCompat {
|
||||
return payload
|
||||
}
|
||||
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
@@ -216,34 +232,43 @@ func matchModelPattern(pattern, model string) bool {
|
||||
return pi == len(pattern)
|
||||
}
|
||||
|
||||
// normalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// based on model capabilities. For models without thinking support, it strips
|
||||
// reasoning fields. For models with level-based thinking, it validates and
|
||||
// normalizes the reasoning effort level.
|
||||
func normalizeThinkingConfig(payload []byte, model string) []byte {
|
||||
// normalizes the reasoning effort level. For models with numeric budget thinking,
|
||||
// it strips the effort string fields.
|
||||
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return payload
|
||||
}
|
||||
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return stripThinkingFields(payload)
|
||||
if allowCompat {
|
||||
return payload
|
||||
}
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
|
||||
if util.ModelUsesThinkingLevels(model) {
|
||||
return normalizeReasoningEffortLevel(payload, model)
|
||||
return NormalizeReasoningEffortLevel(payload, model)
|
||||
}
|
||||
|
||||
return payload
|
||||
// Model supports thinking but uses numeric budgets, not levels.
|
||||
// Strip effort string fields since they are not applicable.
|
||||
return StripThinkingFields(payload, true)
|
||||
}
|
||||
|
||||
// stripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking.
|
||||
func stripThinkingFields(payload []byte) []byte {
|
||||
// StripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking. If effortOnly is true, only removes
|
||||
// effort string fields (for models using numeric budgets).
|
||||
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
fieldsToRemove := []string{
|
||||
"reasoning",
|
||||
"reasoning_effort",
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
if gjson.GetBytes(out, field).Exists() {
|
||||
@@ -253,9 +278,9 @@ func stripThinkingFields(payload []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// or reasoning.effort field for level-based thinking models.
|
||||
func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
out := payload
|
||||
|
||||
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
|
||||
@@ -273,10 +298,10 @@ func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// validateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
|
||||
// downgrading requests.
|
||||
func validateThinkingConfig(payload []byte, model string) error {
|
||||
func ValidateThinkingConfig(payload []byte, model string) error {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -131,13 +131,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -2,43 +2,107 @@ package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||
type TokenizerWrapper struct {
|
||||
Codec tokenizer.Codec
|
||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||
}
|
||||
|
||||
// Count returns the token count with adjustment factor applied
|
||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getTokenizer returns a cached tokenizer for the given model.
|
||||
// This improves performance by avoiding repeated tokenizer creation.
|
||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||
// Check cache first
|
||||
if cached, ok := tokenizerCache.Load(model); ok {
|
||||
return cached.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// Cache miss, create new tokenizer
|
||||
wrapper, err := tokenizerForModel(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (use LoadOrStore to handle race conditions)
|
||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||
return actual.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
|
||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||
// because tiktoken may underestimate Claude's actual token count
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
var enc tokenizer.Codec
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||
// This handles Claude's message format with system, messages, and tools.
|
||||
// Image tokens are estimated based on image dimensions when available.
|
||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
// Collect tools
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||
|
||||
// extractImageTokens extracts image token estimates from placeholder text.
|
||||
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||
func extractImageTokens(text string) int {
|
||||
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||
total := 0
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||
total += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image)
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
|
||||
// Apply bounds
|
||||
if tokens < 85 {
|
||||
tokens = 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
tokens = 1590
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// collectClaudeSystem extracts text from Claude's system field.
|
||||
// System can be a string or an array of content blocks.
|
||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||
if !system.Exists() {
|
||||
return
|
||||
}
|
||||
if system.Type == gjson.String {
|
||||
addIfNotEmpty(segments, system.String())
|
||||
return
|
||||
}
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "text" || blockType == "" {
|
||||
addIfNotEmpty(segments, block.Get("text").String())
|
||||
}
|
||||
// Also handle plain string blocks
|
||||
if block.Type == gjson.String {
|
||||
addIfNotEmpty(segments, block.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeMessages extracts text from Claude's messages array.
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// collectClaudeContent extracts text from Claude's content field.
|
||||
// Content can be a string or an array of content blocks.
|
||||
// For images, estimates token count based on dimensions when available.
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
// Estimate image tokens based on dimensions if available
|
||||
source := part.Get("source")
|
||||
if source.Exists() {
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if width > 0 && height > 0 {
|
||||
tokens := estimateImageTokens(width, height)
|
||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||
} else {
|
||||
// No dimensions available, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
} else {
|
||||
// No source info, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
// For unknown types, try to extract any text content
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeTools extracts text from Claude's tools array.
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
|
||||
@@ -84,13 +84,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||
prompt := contentResult.Get("thinking").String()
|
||||
// Claude "thinking" blocks are internal-only. They also require a valid provider signature
|
||||
// when replayed as conversation history. Since we cannot mint signatures, only forward
|
||||
// thinking blocks when the client provides a non-empty signature; otherwise, drop them.
|
||||
signatureResult := contentResult.Get("signature")
|
||||
signature := geminiCLIClaudeThoughtSignature
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
if signatureResult.Type == gjson.String && signatureResult.String() != "" {
|
||||
prompt := contentResult.Get("thinking").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
Text: prompt,
|
||||
Thought: true,
|
||||
ThoughtSignature: signatureResult.String(),
|
||||
})
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
@@ -117,9 +122,17 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
} else {
|
||||
responseData = contentResult.Get("content").Raw
|
||||
}
|
||||
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
}
|
||||
@@ -134,7 +147,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
if len(clientContent.Parts) > 0 {
|
||||
contents = append(contents, clientContent)
|
||||
}
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
|
||||
@@ -114,44 +114,54 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Extract the different types of content from each part
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
isThought := partResult.Get("thought").Bool()
|
||||
|
||||
// Some Antigravity/Vertex Claude streams emit the thought signature as a standalone part
|
||||
// (no text payload). Claude requires this signature to be replayed verbatim on subsequent turns.
|
||||
if isThought && hasThoughtSignature && !partTextResult.Exists() && !functionCallResult.Exists() {
|
||||
if params.ResponseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignatureResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle text content (both regular content and thinking)
|
||||
if partTextResult.Exists() {
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
// First, close any existing content block
|
||||
if isThought {
|
||||
// Ensure we have an open thinking block to attach thinking/signature deltas to.
|
||||
if params.ResponseType != 2 {
|
||||
if params.ResponseType != 0 {
|
||||
if params.ResponseType == 2 {
|
||||
// output = output + "event: content_block_delta\n"
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
params.ResponseType = 2
|
||||
}
|
||||
|
||||
if partTextResult.String() != "" {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
params.HasContent = true
|
||||
}
|
||||
|
||||
if hasThoughtSignature {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignatureResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
}
|
||||
} else {
|
||||
@@ -368,6 +378,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
var contentBlocks []interface{}
|
||||
textBuilder := strings.Builder{}
|
||||
thinkingBuilder := strings.Builder{}
|
||||
thinkingSignature := ""
|
||||
toolIDCounter := 0
|
||||
hasToolCall := false
|
||||
|
||||
@@ -386,19 +397,37 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
block := map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
}
|
||||
if thinkingSignature != "" {
|
||||
block["signature"] = thinkingSignature
|
||||
}
|
||||
contentBlocks = append(contentBlocks, block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
|
||||
if parts.IsArray() {
|
||||
for _, part := range parts.Array() {
|
||||
thoughtSignatureResult := part.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = part.Get("thought_signature")
|
||||
}
|
||||
if part.Get("thought").Bool() && thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" && (!part.Get("text").Exists() || part.Get("text").String() == "") {
|
||||
// Signature-only thought part (no text payload).
|
||||
thinkingSignature = thoughtSignatureResult.String()
|
||||
continue
|
||||
}
|
||||
|
||||
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
||||
if part.Get("thought").Bool() {
|
||||
flushText()
|
||||
thinkingBuilder.WriteString(text.String())
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
thinkingSignature = thoughtSignatureResult.String()
|
||||
}
|
||||
continue
|
||||
}
|
||||
flushThinking()
|
||||
|
||||
@@ -122,6 +122,38 @@ type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// parseFunctionResponse attempts to unmarshal a function response part.
|
||||
// Falls back to gjson extraction if standard json.Unmarshal fails.
|
||||
func parseFunctionResponse(response gjson.Result) map[string]interface{} {
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if err == nil {
|
||||
return responseMap
|
||||
}
|
||||
|
||||
log.Debugf("unmarshal function response failed, using fallback: %v", err)
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := map[string]interface{}{
|
||||
"name": funcResp.Get("name").String(),
|
||||
"response": map[string]interface{}{
|
||||
"result": funcResp.Get("response").String(),
|
||||
},
|
||||
}
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr["id"] = id
|
||||
}
|
||||
return map[string]interface{}{"functionResponse": fr}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"functionResponse": map[string]interface{}{
|
||||
"name": "unknown",
|
||||
"response": map[string]interface{}{"result": response.String()},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
@@ -180,13 +212,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
@@ -265,13 +291,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -114,14 +114,16 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() {
|
||||
if includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", thinkingBudget.Int())
|
||||
}
|
||||
}
|
||||
// Only apply for models that support thinking and use numeric budgets, not discrete levels.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
// Check for thinkingBudget first - if present, enable thinking with budget
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() && thinkingBudget.Int() > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
normalizedBudget := util.NormalizeThinkingBudget(modelName, int(thinkingBudget.Int()))
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", normalizedBudget)
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
// Fallback to include_thoughts if no budget specified
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -65,18 +66,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning_effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -52,20 +53,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 4096)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning.effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -214,7 +215,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.Set(template, "reasoning.effort", "medium")
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
switch thinking.Get("type").String() {
|
||||
case "enabled":
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
|
||||
@@ -245,7 +245,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
|
||||
// Convert thinkingBudget to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.effort", reasoningEffort)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -154,7 +154,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
|
||||
@@ -37,33 +37,17 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Reasoning effort -> thinkingBudget/include_thoughts
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
// Only convert for models that use numeric budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -389,36 +389,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
// OpenAI official reasoning fields take precedence
|
||||
// Only convert for models that use numeric budgets (not discrete levels).
|
||||
hasOfficialThinking := root.Get("reasoning.effort").Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
reasoningEffort := root.Get("reasoning.effort")
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", false)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
out = string(util.ApplyReasoningEffortToGemini([]byte(out), reasoningEffort.String()))
|
||||
}
|
||||
|
||||
// Cherry Studio extension (applies only when official fields are missing)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -35,5 +35,5 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
||||
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package claude provides translation between Kiro and Claude formats.
|
||||
package claude
|
||||
|
||||
import (
|
||||
@@ -12,8 +13,8 @@ func init() {
|
||||
Kiro,
|
||||
ConvertClaudeRequestToKiro,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertKiroResponseToClaude,
|
||||
NonStream: ConvertKiroResponseToClaudeNonStream,
|
||||
Stream: ConvertKiroStreamToClaude,
|
||||
NonStream: ConvertKiroNonStreamToClaude,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,21 @@
|
||||
// Package claude provides translation between Kiro and Claude formats.
|
||||
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
|
||||
// translations are pass-through.
|
||||
// translations are pass-through for streaming, but responses need proper formatting.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
)
|
||||
|
||||
// ConvertClaudeRequestToKiro converts Claude request to Kiro format.
|
||||
// Since Kiro uses Claude format internally, this is mostly a pass-through.
|
||||
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
return bytes.Clone(inputRawJSON)
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format.
|
||||
// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format.
|
||||
// Kiro executor already generates complete SSE format with "event:" prefix,
|
||||
// so this is a simple pass-through.
|
||||
func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
return []string{string(rawResponse)}
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format.
|
||||
func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format.
|
||||
// The response is already in Claude format, so this is a pass-through.
|
||||
func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
return string(rawResponse)
|
||||
}
|
||||
|
||||
810
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
810
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
@@ -0,0 +1,810 @@
|
||||
// Package claude provides request translation functionality for Claude API to Kiro format.
|
||||
// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
|
||||
// Kiro API request structs - field order determines JSON key order
|
||||
|
||||
// KiroPayload is the top-level request structure for Kiro API
|
||||
type KiroPayload struct {
|
||||
ConversationState KiroConversationState `json:"conversationState"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
}
|
||||
|
||||
// KiroInferenceConfig contains inference parameters for the Kiro API.
|
||||
type KiroInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
// KiroConversationState holds the conversation context
|
||||
type KiroConversationState struct {
|
||||
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
|
||||
ConversationID string `json:"conversationId"`
|
||||
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
|
||||
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||
}
|
||||
|
||||
// KiroCurrentMessage wraps the current user message
|
||||
type KiroCurrentMessage struct {
|
||||
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||
}
|
||||
|
||||
// KiroHistoryMessage represents a message in the conversation history
|
||||
type KiroHistoryMessage struct {
|
||||
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||
}
|
||||
|
||||
// KiroImage represents an image in Kiro API format
|
||||
type KiroImage struct {
|
||||
Format string `json:"format"`
|
||||
Source KiroImageSource `json:"source"`
|
||||
}
|
||||
|
||||
// KiroImageSource contains the image data
|
||||
type KiroImageSource struct {
|
||||
Bytes string `json:"bytes"` // base64 encoded image data
|
||||
}
|
||||
|
||||
// KiroUserInputMessage represents a user message
|
||||
type KiroUserInputMessage struct {
|
||||
Content string `json:"content"`
|
||||
ModelID string `json:"modelId"`
|
||||
Origin string `json:"origin"`
|
||||
Images []KiroImage `json:"images,omitempty"`
|
||||
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||
}
|
||||
|
||||
// KiroUserInputMessageContext contains tool-related context
|
||||
type KiroUserInputMessageContext struct {
|
||||
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolResult represents a tool execution result
|
||||
type KiroToolResult struct {
|
||||
Content []KiroTextContent `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
}
|
||||
|
||||
// KiroTextContent represents text content
|
||||
type KiroTextContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// KiroToolWrapper wraps a tool specification
|
||||
type KiroToolWrapper struct {
|
||||
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
|
||||
}
|
||||
|
||||
// KiroToolSpecification defines a tool's schema
|
||||
type KiroToolSpecification struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema KiroInputSchema `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// KiroInputSchema wraps the JSON schema for tool input
|
||||
type KiroInputSchema struct {
|
||||
JSON interface{} `json:"json"`
|
||||
}
|
||||
|
||||
// KiroAssistantResponseMessage represents an assistant message
|
||||
type KiroAssistantResponseMessage struct {
|
||||
Content string `json:"content"`
|
||||
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolUse represents a tool invocation by the assistant
|
||||
type KiroToolUse struct {
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
|
||||
// This is the main entry point for request translation.
|
||||
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
// For Kiro, we pass through the Claude format since buildKiroPayload
|
||||
// expects Claude format and does the conversion internally.
|
||||
// The actual conversion happens in the executor when building the HTTP request.
|
||||
return inputRawJSON
|
||||
}
|
||||
|
||||
// BuildKiroPayload constructs the Kiro API request payload from Claude format.
|
||||
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
|
||||
// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
|
||||
// Supports thinking mode - when enabled, injects thinking tags into system prompt.
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
var maxTokens int64
|
||||
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
|
||||
maxTokens = mt.Int()
|
||||
if maxTokens == -1 {
|
||||
maxTokens = kiroMaxOutputTokens
|
||||
log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract temperature if specified
|
||||
var temperature float64
|
||||
var hasTemperature bool
|
||||
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
|
||||
temperature = temp.Float()
|
||||
hasTemperature = true
|
||||
}
|
||||
|
||||
// Extract top_p if specified
|
||||
var topP float64
|
||||
var hasTopP bool
|
||||
if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
|
||||
topP = tp.Float()
|
||||
hasTopP = true
|
||||
log.Debugf("kiro: extracted top_p: %.2f", topP)
|
||||
}
|
||||
|
||||
// Normalize origin value for Kiro API compatibility
|
||||
origin = normalizeOrigin(origin)
|
||||
log.Debugf("kiro: normalized origin value: %s", origin)
|
||||
|
||||
messages := gjson.GetBytes(claudeBody, "messages")
|
||||
|
||||
// For chat-only mode, don't include tools
|
||||
var tools gjson.Result
|
||||
if !isChatOnly {
|
||||
tools = gjson.GetBytes(claudeBody, "tools")
|
||||
}
|
||||
|
||||
// Extract system prompt
|
||||
systemPrompt := extractSystemPrompt(claudeBody)
|
||||
|
||||
// Check for thinking mode using the comprehensive IsThinkingEnabledWithHeaders function
|
||||
// This supports Claude API format, OpenAI reasoning_effort, AMP/Cursor format, and Anthropic-Beta header
|
||||
thinkingEnabled := IsThinkingEnabledWithHeaders(claudeBody, headers)
|
||||
|
||||
// Inject timestamp context
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = timestampContext
|
||||
}
|
||||
log.Debugf("kiro: injected timestamp context: %s", timestamp)
|
||||
|
||||
// Inject agentic optimization prompt for -agentic model variants
|
||||
if isAgentic {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||
}
|
||||
|
||||
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// Claude tool_choice values: {"type": "auto/any/tool", "name": "..."}
|
||||
toolChoiceHint := extractClaudeToolChoiceHint(claudeBody)
|
||||
if toolChoiceHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += toolChoiceHint
|
||||
log.Debugf("kiro: injected tool_choice hint into system prompt")
|
||||
}
|
||||
|
||||
// Convert Claude tools to Kiro format
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Infof("kiro: injected thinking prompt, has_tools: %v", len(kiroTools) > 0)
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||
|
||||
// Build content with system prompt
|
||||
if currentUserMsg != nil {
|
||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||
|
||||
// Deduplicate currentToolResults
|
||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||
|
||||
// Build userInputMessageContext with tools and tool results
|
||||
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
Tools: kiroTools,
|
||||
ToolResults: currentToolResults,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build payload
|
||||
var currentMessage KiroCurrentMessage
|
||||
if currentUserMsg != nil {
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||
} else {
|
||||
fallbackContent := ""
|
||||
if systemPrompt != "" {
|
||||
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||
}
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||
Content: fallbackContent,
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}}
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
// Note: Kiro API doesn't actually use max_tokens for thinking budget
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
if maxTokens > 0 {
|
||||
inferenceConfig.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if hasTemperature {
|
||||
inferenceConfig.Temperature = temperature
|
||||
}
|
||||
if hasTopP {
|
||||
inferenceConfig.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
payload := KiroPayload{
|
||||
ConversationState: KiroConversationState{
|
||||
ChatTriggerType: "MANUAL",
|
||||
ConversationID: uuid.New().String(),
|
||||
CurrentMessage: currentMessage,
|
||||
History: history,
|
||||
},
|
||||
ProfileArn: profileArn,
|
||||
InferenceConfig: inferenceConfig,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Debugf("kiro: failed to marshal payload: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return result, thinkingEnabled
|
||||
}
|
||||
|
||||
// normalizeOrigin normalizes origin value for Kiro API compatibility
|
||||
func normalizeOrigin(origin string) string {
|
||||
switch origin {
|
||||
case "KIRO_CLI":
|
||||
return "CLI"
|
||||
case "KIRO_AI_EDITOR":
|
||||
return "AI_EDITOR"
|
||||
case "AMAZON_Q":
|
||||
return "CLI"
|
||||
case "KIRO_IDE":
|
||||
return "AI_EDITOR"
|
||||
default:
|
||||
return origin
|
||||
}
|
||||
}
|
||||
|
||||
// extractSystemPrompt extracts system prompt from Claude request
|
||||
func extractSystemPrompt(claudeBody []byte) string {
|
||||
systemField := gjson.GetBytes(claudeBody, "system")
|
||||
if systemField.IsArray() {
|
||||
var sb strings.Builder
|
||||
for _, block := range systemField.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
sb.WriteString(block.Get("text").String())
|
||||
} else if block.Type == gjson.String {
|
||||
sb.WriteString(block.String())
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
return systemField.String()
|
||||
}
|
||||
|
||||
// checkThinkingMode checks if thinking mode is enabled in the Claude request
|
||||
func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||
thinkingEnabled := false
|
||||
var budgetTokens int64 = 24000
|
||||
|
||||
thinkingField := gjson.GetBytes(claudeBody, "thinking")
|
||||
if thinkingField.Exists() {
|
||||
thinkingType := thinkingField.Get("type").String()
|
||||
if thinkingType == "enabled" {
|
||||
thinkingEnabled = true
|
||||
if bt := thinkingField.Get("budget_tokens"); bt.Exists() {
|
||||
budgetTokens = bt.Int()
|
||||
if budgetTokens <= 0 {
|
||||
thinkingEnabled = false
|
||||
log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0")
|
||||
}
|
||||
}
|
||||
if thinkingEnabled {
|
||||
log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return thinkingEnabled, budgetTokens
|
||||
}
|
||||
|
||||
// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
|
||||
// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
|
||||
func hasThinkingTagInBody(body []byte) bool {
|
||||
bodyStr := string(body)
|
||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||
}
|
||||
|
||||
|
||||
// IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header.
|
||||
// Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking.
|
||||
func IsThinkingEnabledFromHeader(headers http.Header) bool {
|
||||
if headers == nil {
|
||||
return false
|
||||
}
|
||||
betaHeader := headers.Get("Anthropic-Beta")
|
||||
if betaHeader == "" {
|
||||
return false
|
||||
}
|
||||
// Check for interleaved-thinking beta feature
|
||||
if strings.Contains(betaHeader, "interleaved-thinking") {
|
||||
log.Debugf("kiro: thinking mode enabled via Anthropic-Beta header: %s", betaHeader)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
|
||||
// This is used by the executor to determine whether to parse <thinking> tags in responses.
|
||||
// When thinking is NOT enabled in the request, <thinking> tags in responses should be
|
||||
// treated as regular text content, not as thinking blocks.
|
||||
//
|
||||
// Supports multiple formats:
|
||||
// - Claude API format: thinking.type = "enabled"
|
||||
// - OpenAI format: reasoning_effort parameter
|
||||
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
func IsThinkingEnabled(body []byte) bool {
|
||||
return IsThinkingEnabledWithHeaders(body, nil)
|
||||
}
|
||||
|
||||
// IsThinkingEnabledWithHeaders checks if thinking mode is enabled from body or headers.
|
||||
// This is the comprehensive check that supports all thinking detection methods:
|
||||
// - Claude API format: thinking.type = "enabled"
|
||||
// - OpenAI format: reasoning_effort parameter
|
||||
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
// - Anthropic-Beta header: interleaved-thinking-2025-05-14
|
||||
func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
|
||||
// Check Anthropic-Beta header first (Claude Code uses this)
|
||||
if IsThinkingEnabledFromHeader(headers) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Claude API format first (thinking.type = "enabled")
|
||||
enabled, _ := checkThinkingMode(body)
|
||||
if enabled {
|
||||
log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)")
|
||||
return true
|
||||
}
|
||||
|
||||
// Check OpenAI format: reasoning_effort parameter
|
||||
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||
reasoningEffort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if reasoningEffort.Exists() {
|
||||
effort := reasoningEffort.String()
|
||||
if effort != "" && effort != "none" {
|
||||
log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
// This is how AMP client passes thinking configuration
|
||||
bodyStr := string(body)
|
||||
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||
// Extract thinking mode value
|
||||
startTag := "<thinking_mode>"
|
||||
endTag := "</thinking_mode>"
|
||||
startIdx := strings.Index(bodyStr, startTag)
|
||||
if startIdx >= 0 {
|
||||
startIdx += len(startTag)
|
||||
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||
if endIdx >= 0 {
|
||||
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||
log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI format: max_completion_tokens with reasoning (o1-style)
|
||||
// Some clients use this to indicate reasoning mode
|
||||
if gjson.GetBytes(body, "max_completion_tokens").Exists() {
|
||||
// If max_completion_tokens is set, check if model name suggests reasoning
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(model), "thinking") ||
|
||||
strings.Contains(strings.ToLower(model), "reason") {
|
||||
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
||||
return false
|
||||
}
|
||||
|
||||
// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
|
||||
// MCP tools often have long names like "mcp__server-name__tool-name".
|
||||
// This preserves the "mcp__" prefix and last segment when possible.
|
||||
func shortenToolNameIfNeeded(name string) string {
|
||||
const limit = 64
|
||||
if len(name) <= limit {
|
||||
return name
|
||||
}
|
||||
// For MCP tools, try to preserve prefix and last segment
|
||||
if strings.HasPrefix(name, "mcp__") {
|
||||
idx := strings.LastIndex(name, "__")
|
||||
if idx > 0 {
|
||||
cand := "mcp__" + name[idx+2:]
|
||||
if len(cand) > limit {
|
||||
return cand[:limit]
|
||||
}
|
||||
return cand
|
||||
}
|
||||
}
|
||||
return name[:limit]
|
||||
}
|
||||
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
if !tools.IsArray() {
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
description := tool.Get("description").String()
|
||||
inputSchema := tool.Get("input_schema").Value()
|
||||
|
||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||
originalName := name
|
||||
name = shortenToolNameIfNeeded(name)
|
||||
if name != originalName {
|
||||
log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name)
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty description
|
||||
if strings.TrimSpace(description) == "" {
|
||||
description = fmt.Sprintf("Tool: %s", name)
|
||||
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Truncate long descriptions
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
description = description[:truncLen] + "... (description truncated)"
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: inputSchema},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processMessages processes Claude messages and builds Kiro history
|
||||
func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||
var history []KiroHistoryMessage
|
||||
var currentUserMsg *KiroUserInputMessage
|
||||
var currentToolResults []KiroToolResult
|
||||
|
||||
// Merge adjacent messages with the same role
|
||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||
for i, msg := range messagesArray {
|
||||
role := msg.Get("role").String()
|
||||
isLastMessage := i == len(messagesArray)-1
|
||||
|
||||
if role == "user" {
|
||||
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
||||
if isLastMessage {
|
||||
currentUserMsg = &userMsg
|
||||
currentToolResults = toolResults
|
||||
} else {
|
||||
// CRITICAL: Kiro API requires content to be non-empty for history messages too
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = "Tool results provided."
|
||||
} else {
|
||||
userMsg.Content = "Continue"
|
||||
}
|
||||
}
|
||||
// For history messages, embed tool results in context
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
}
|
||||
history = append(history, KiroHistoryMessage{
|
||||
UserInputMessage: &userMsg,
|
||||
})
|
||||
}
|
||||
} else if role == "assistant" {
|
||||
assistantMsg := BuildAssistantMessageStruct(msg)
|
||||
if isLastMessage {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
// Create a "Continue" user message as currentMessage
|
||||
currentUserMsg = &KiroUserInputMessage{
|
||||
Content: "Continue",
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
} else {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// buildFinalContent builds the final content with system prompt
|
||||
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
if systemPrompt != "" {
|
||||
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||
contentBuilder.WriteString(systemPrompt)
|
||||
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||
}
|
||||
|
||||
contentBuilder.WriteString(content)
|
||||
finalContent := contentBuilder.String()
|
||||
|
||||
// CRITICAL: Kiro API requires content to be non-empty
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
finalContent = "Tool results provided."
|
||||
} else {
|
||||
finalContent = "Continue"
|
||||
}
|
||||
log.Debugf("kiro: content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return finalContent
|
||||
}
|
||||
|
||||
// deduplicateToolResults removes duplicate tool results
|
||||
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||
if len(toolResults) == 0 {
|
||||
return toolResults
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if !seenIDs[tr.ToolUseID] {
|
||||
seenIDs[tr.ToolUseID] = true
|
||||
unique = append(unique, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
|
||||
// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint.
|
||||
// Claude tool_choice values:
|
||||
// - {"type": "auto"}: Model decides (default, no hint needed)
|
||||
// - {"type": "any"}: Must use at least one tool
|
||||
// - {"type": "tool", "name": "..."}: Must use specific tool
|
||||
func extractClaudeToolChoiceHint(claudeBody []byte) string {
|
||||
toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
toolChoiceType := toolChoice.Get("type").String()
|
||||
switch toolChoiceType {
|
||||
case "any":
|
||||
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||
case "tool":
|
||||
toolName := toolChoice.Get("name").String()
|
||||
if toolName != "" {
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||
}
|
||||
case "auto":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildUserMessageStruct builds a user message and extracts tool results
|
||||
func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolResults []KiroToolResult
|
||||
var images []KiroImage
|
||||
|
||||
// Track seen toolUseIds to deduplicate
|
||||
seenToolUseIDs := make(map[string]bool)
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "image":
|
||||
mediaType := part.Get("source.media_type").String()
|
||||
data := part.Get("source.data").String()
|
||||
|
||||
format := ""
|
||||
if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
|
||||
format = mediaType[idx+1:]
|
||||
}
|
||||
|
||||
if format != "" && data != "" {
|
||||
images = append(images, KiroImage{
|
||||
Format: format,
|
||||
Source: KiroImageSource{
|
||||
Bytes: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
case "tool_result":
|
||||
toolUseID := part.Get("tool_use_id").String()
|
||||
|
||||
// Skip duplicate toolUseIds
|
||||
if seenToolUseIDs[toolUseID] {
|
||||
log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID)
|
||||
continue
|
||||
}
|
||||
seenToolUseIDs[toolUseID] = true
|
||||
|
||||
isError := part.Get("is_error").Bool()
|
||||
resultContent := part.Get("content")
|
||||
|
||||
var textContents []KiroTextContent
|
||||
if resultContent.IsArray() {
|
||||
for _, item := range resultContent.Array() {
|
||||
if item.Get("type").String() == "text" {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
|
||||
} else if item.Type == gjson.String {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.String()})
|
||||
}
|
||||
}
|
||||
} else if resultContent.Type == gjson.String {
|
||||
textContents = append(textContents, KiroTextContent{Text: resultContent.String()})
|
||||
}
|
||||
|
||||
if len(textContents) == 0 {
|
||||
textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"})
|
||||
}
|
||||
|
||||
status := "success"
|
||||
if isError {
|
||||
status = "error"
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, KiroToolResult{
|
||||
ToolUseID: toolUseID,
|
||||
Content: textContents,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
userMsg := KiroUserInputMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
userMsg.Images = images
|
||||
}
|
||||
|
||||
return userMsg, toolResults
|
||||
}
|
||||
|
||||
// BuildAssistantMessageStruct builds an assistant message with tool uses
|
||||
func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "tool_use":
|
||||
toolUseID := part.Get("id").String()
|
||||
toolName := part.Get("name").String()
|
||||
toolInput := part.Get("input")
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if toolInput.IsObject() {
|
||||
inputMap = make(map[string]interface{})
|
||||
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||
inputMap[key.String()] = value.Value()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
204
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
204
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Package claude provides response translation functionality for Kiro API to Claude format.
|
||||
// This package handles the conversion of Kiro API responses into Claude-compatible format,
|
||||
// including support for thinking blocks and tool use.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
)
|
||||
|
||||
// generateThinkingSignature generates a signature for thinking content.
|
||||
// This is required by Claude API for thinking blocks in non-streaming responses.
|
||||
// The signature is a base64-encoded hash of the thinking content.
|
||||
func generateThinkingSignature(thinkingContent string) string {
|
||||
if thinkingContent == "" {
|
||||
return ""
|
||||
}
|
||||
// Generate a deterministic signature based on content hash
|
||||
hash := sha256.Sum256([]byte(thinkingContent))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Local references to kirocommon constants for thinking block parsing
|
||||
var (
|
||||
thinkingStartTag = kirocommon.ThinkingStartTag
|
||||
thinkingEndTag = kirocommon.ThinkingEndTag
|
||||
)
|
||||
|
||||
// BuildClaudeResponse constructs a Claude-compatible response.
|
||||
// Supports tool_use blocks when tools are present in the response.
|
||||
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
var contentBlocks []map[string]interface{}
|
||||
|
||||
// Extract thinking blocks and text from content
|
||||
if content != "" {
|
||||
blocks := ExtractThinkingFromContent(content)
|
||||
contentBlocks = append(contentBlocks, blocks...)
|
||||
|
||||
// Log if thinking blocks were extracted
|
||||
for _, block := range blocks {
|
||||
if block["type"] == "thinking" {
|
||||
thinkingContent := block["thinking"].(string)
|
||||
log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, toolUse := range toolUses {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure at least one content block (Claude API requires non-empty content)
|
||||
if len(contentBlocks) == 0 {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
})
|
||||
}
|
||||
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
if stopReason == "" {
|
||||
stopReason = "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
|
||||
}
|
||||
|
||||
// Log warning if response was truncated due to max_tokens
|
||||
if stopReason == "max_tokens" {
|
||||
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "msg_" + uuid.New().String()[:24],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": contentBlocks,
|
||||
"stop_reason": stopReason,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": usageInfo.InputTokens,
|
||||
"output_tokens": usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return result
|
||||
}
|
||||
|
||||
// ExtractThinkingFromContent parses content to extract thinking blocks and text.
|
||||
// Returns a list of content blocks in the order they appear in the content.
|
||||
// Handles interleaved thinking and text blocks correctly.
|
||||
func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||
var blocks []map[string]interface{}
|
||||
|
||||
if content == "" {
|
||||
return blocks
|
||||
}
|
||||
|
||||
// Check if content contains thinking tags at all
|
||||
if !strings.Contains(content, thinkingStartTag) {
|
||||
// No thinking tags, return as plain text
|
||||
return []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content))
|
||||
|
||||
remaining := content
|
||||
|
||||
for len(remaining) > 0 {
|
||||
// Look for <thinking> tag
|
||||
startIdx := strings.Index(remaining, thinkingStartTag)
|
||||
|
||||
if startIdx == -1 {
|
||||
// No more thinking tags, add remaining as text
|
||||
if strings.TrimSpace(remaining) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": remaining,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Add text before thinking tag (if any meaningful content)
|
||||
if startIdx > 0 {
|
||||
textBefore := remaining[:startIdx]
|
||||
if strings.TrimSpace(textBefore) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBefore,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Move past the opening tag
|
||||
remaining = remaining[startIdx+len(thinkingStartTag):]
|
||||
|
||||
// Find closing tag
|
||||
endIdx := strings.Index(remaining, thinkingEndTag)
|
||||
|
||||
if endIdx == -1 {
|
||||
// No closing tag found, treat rest as thinking content (incomplete response)
|
||||
if strings.TrimSpace(remaining) != "" {
|
||||
// Generate signature for thinking content (required by Claude API)
|
||||
signature := generateThinkingSignature(remaining)
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": remaining,
|
||||
"signature": signature,
|
||||
})
|
||||
log.Warnf("kiro: extractThinkingFromContent - missing closing </thinking> tag")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Extract thinking content between tags
|
||||
thinkContent := remaining[:endIdx]
|
||||
if strings.TrimSpace(thinkContent) != "" {
|
||||
// Generate signature for thinking content (required by Claude API)
|
||||
signature := generateThinkingSignature(thinkContent)
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkContent,
|
||||
"signature": signature,
|
||||
})
|
||||
log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
|
||||
}
|
||||
|
||||
// Move past the closing tag
|
||||
remaining = remaining[endIdx+len(thinkingEndTag):]
|
||||
}
|
||||
|
||||
// If no blocks were created (all whitespace), return empty text block
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
})
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
186
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
186
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
@@ -0,0 +1,186 @@
|
||||
// Package claude provides streaming SSE event building for Claude format.
|
||||
// This package handles the construction of Claude-compatible Server-Sent Events (SSE)
|
||||
// for streaming responses from Kiro API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
// BuildClaudeMessageStartEvent creates the message_start SSE event
|
||||
func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": "msg_" + uuid.New().String()[:24],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []interface{}{},
|
||||
"model": model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: message_start\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event
|
||||
func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte {
|
||||
var contentBlock map[string]interface{}
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
contentBlock = map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": toolName,
|
||||
"input": map[string]interface{}{},
|
||||
}
|
||||
case "thinking":
|
||||
contentBlock = map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}
|
||||
default:
|
||||
contentBlock = map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}
|
||||
}
|
||||
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": contentBlock,
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_start\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event
|
||||
func BuildClaudeStreamEvent(contentDelta string, index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": contentDelta,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming
|
||||
func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": partialJSON,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event
|
||||
func BuildClaudeContentBlockStopEvent(index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": index,
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_stop\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeThinkingBlockStopEvent creates a content_block_stop SSE event for thinking blocks.
|
||||
func BuildClaudeThinkingBlockStopEvent(index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": index,
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_stop\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
|
||||
func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
|
||||
deltaEvent := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": usageInfo.InputTokens,
|
||||
"output_tokens": usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
deltaResult, _ := json.Marshal(deltaEvent)
|
||||
return []byte("event: message_delta\ndata: " + string(deltaResult))
|
||||
}
|
||||
|
||||
// BuildClaudeMessageStopOnlyEvent creates only the message_stop event
|
||||
func BuildClaudeMessageStopOnlyEvent() []byte {
|
||||
stopEvent := map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
}
|
||||
stopResult, _ := json.Marshal(stopEvent)
|
||||
return []byte("event: message_stop\ndata: " + string(stopResult))
|
||||
}
|
||||
|
||||
// BuildClaudePingEventWithUsage creates a ping event with embedded usage information.
|
||||
// This is used for real-time usage estimation during streaming.
|
||||
func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "ping",
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
"estimated": true,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: ping\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
|
||||
// This is used when streaming thinking content wrapped in <thinking> tags.
|
||||
func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "thinking_delta",
|
||||
"thinking": thinkingDelta,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag.
|
||||
// Returns the length of the partial match (0 if no match).
|
||||
// Based on amq2api implementation for handling cross-chunk tag boundaries.
|
||||
func PendingTagSuffix(buffer, tag string) int {
|
||||
if buffer == "" || tag == "" {
|
||||
return 0
|
||||
}
|
||||
maxLen := len(buffer)
|
||||
if maxLen > len(tag)-1 {
|
||||
maxLen = len(tag) - 1
|
||||
}
|
||||
for length := maxLen; length > 0; length-- {
|
||||
if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] {
|
||||
return length
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
@@ -0,0 +1,522 @@
|
||||
// Package claude provides tool calling support for Kiro to Claude translation.
|
||||
// This package handles parsing embedded tool calls, JSON repair, and deduplication.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ToolUseState tracks the state of an in-progress tool use during streaming.
|
||||
type ToolUseState struct {
|
||||
ToolUseID string
|
||||
Name string
|
||||
InputBuffer strings.Builder
|
||||
IsComplete bool
|
||||
}
|
||||
|
||||
// Pre-compiled regex patterns for performance
|
||||
var (
|
||||
// embeddedToolCallPattern matches [Called tool_name with args: {...}] format
|
||||
embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`)
|
||||
// trailingCommaPattern matches trailing commas before closing braces/brackets
|
||||
trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
|
||||
)
|
||||
|
||||
// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text.
|
||||
// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent.
|
||||
// Returns the cleaned text (with tool calls removed) and extracted tool uses.
|
||||
func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) {
|
||||
if !strings.Contains(text, "[Called") {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
var toolUses []KiroToolUse
|
||||
cleanText := text
|
||||
|
||||
// Find all [Called markers
|
||||
matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// Process matches in reverse order to maintain correct indices
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
matchStart := matches[i][0]
|
||||
toolNameStart := matches[i][2]
|
||||
toolNameEnd := matches[i][3]
|
||||
|
||||
if toolNameStart < 0 || toolNameEnd < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := text[toolNameStart:toolNameEnd]
|
||||
|
||||
// Find the JSON object start (after "with args:")
|
||||
jsonStart := matches[i][1]
|
||||
if jsonStart >= len(text) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip whitespace to find the opening brace
|
||||
for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') {
|
||||
jsonStart++
|
||||
}
|
||||
|
||||
if jsonStart >= len(text) || text[jsonStart] != '{' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find matching closing bracket
|
||||
jsonEnd := findMatchingBracket(text, jsonStart)
|
||||
if jsonEnd < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON and find the closing bracket of [Called ...]
|
||||
jsonStr := text[jsonStart : jsonEnd+1]
|
||||
|
||||
// Find the closing ] after the JSON
|
||||
closingBracket := jsonEnd + 1
|
||||
for closingBracket < len(text) && text[closingBracket] != ']' {
|
||||
closingBracket++
|
||||
}
|
||||
if closingBracket >= len(text) {
|
||||
continue
|
||||
}
|
||||
|
||||
// End index of the full tool call (closing ']' inclusive)
|
||||
matchEnd := closingBracket + 1
|
||||
|
||||
// Repair and parse JSON
|
||||
repairedJSON := RepairJSON(jsonStr)
|
||||
var inputMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil {
|
||||
log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate unique tool ID
|
||||
toolUseID := "toolu_" + uuid.New().String()[:12]
|
||||
|
||||
// Check for duplicates using name+input as key
|
||||
dedupeKey := toolName + ":" + repairedJSON
|
||||
if processedIDs != nil {
|
||||
if processedIDs[dedupeKey] {
|
||||
log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName)
|
||||
// Still remove from text even if duplicate
|
||||
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||
}
|
||||
continue
|
||||
}
|
||||
processedIDs[dedupeKey] = true
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
|
||||
log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID)
|
||||
|
||||
// Remove from clean text (index-based removal to avoid deleting the wrong occurrence)
|
||||
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||
}
|
||||
}
|
||||
|
||||
return cleanText, toolUses
|
||||
}
|
||||
|
||||
// findMatchingBracket finds the index of the closing brace/bracket that matches
|
||||
// the opening one at startPos. Handles nested objects and strings correctly.
|
||||
func findMatchingBracket(text string, startPos int) int {
|
||||
if startPos >= len(text) {
|
||||
return -1
|
||||
}
|
||||
|
||||
openChar := text[startPos]
|
||||
var closeChar byte
|
||||
switch openChar {
|
||||
case '{':
|
||||
closeChar = '}'
|
||||
case '[':
|
||||
closeChar = ']'
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
|
||||
depth := 1
|
||||
inString := false
|
||||
escapeNext := false
|
||||
|
||||
for i := startPos + 1; i < len(text); i++ {
|
||||
char := text[i]
|
||||
|
||||
if escapeNext {
|
||||
escapeNext = false
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '\\' && inString {
|
||||
escapeNext = true
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString {
|
||||
if char == openChar {
|
||||
depth++
|
||||
} else if char == closeChar {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments.
|
||||
// Conservative repair strategy:
|
||||
// 1. First try to parse JSON directly - if valid, return as-is
|
||||
// 2. Only attempt repair if parsing fails
|
||||
// 3. After repair, validate the result - if still invalid, return original
|
||||
func RepairJSON(jsonString string) string {
|
||||
// Handle empty or invalid input
|
||||
if jsonString == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
str := strings.TrimSpace(jsonString)
|
||||
if str == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
// CONSERVATIVE STRATEGY: First try to parse directly
|
||||
var testParse interface{}
|
||||
if err := json.Unmarshal([]byte(str), &testParse); err == nil {
|
||||
log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged")
|
||||
return str
|
||||
}
|
||||
|
||||
log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair")
|
||||
originalStr := str
|
||||
|
||||
// First, escape unescaped newlines/tabs within JSON string values
|
||||
str = escapeNewlinesInStrings(str)
|
||||
// Remove trailing commas before closing braces/brackets
|
||||
str = trailingCommaPattern.ReplaceAllString(str, "$1")
|
||||
|
||||
// Calculate bracket balance
|
||||
braceCount := 0
|
||||
bracketCount := 0
|
||||
inString := false
|
||||
escape := false
|
||||
lastValidIndex := -1
|
||||
|
||||
for i := 0; i < len(str); i++ {
|
||||
char := str[i]
|
||||
|
||||
if escape {
|
||||
escape = false
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '\\' {
|
||||
escape = true
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
switch char {
|
||||
case '{':
|
||||
braceCount++
|
||||
case '}':
|
||||
braceCount--
|
||||
case '[':
|
||||
bracketCount++
|
||||
case ']':
|
||||
bracketCount--
|
||||
}
|
||||
|
||||
if braceCount >= 0 && bracketCount >= 0 {
|
||||
lastValidIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// If brackets are unbalanced, try to repair
|
||||
if braceCount > 0 || bracketCount > 0 {
|
||||
if lastValidIndex > 0 && lastValidIndex < len(str)-1 {
|
||||
truncated := str[:lastValidIndex+1]
|
||||
// Recount brackets after truncation
|
||||
braceCount = 0
|
||||
bracketCount = 0
|
||||
inString = false
|
||||
escape = false
|
||||
for i := 0; i < len(truncated); i++ {
|
||||
char := truncated[i]
|
||||
if escape {
|
||||
escape = false
|
||||
continue
|
||||
}
|
||||
if char == '\\' {
|
||||
escape = true
|
||||
continue
|
||||
}
|
||||
if char == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
switch char {
|
||||
case '{':
|
||||
braceCount++
|
||||
case '}':
|
||||
braceCount--
|
||||
case '[':
|
||||
bracketCount++
|
||||
case ']':
|
||||
bracketCount--
|
||||
}
|
||||
}
|
||||
str = truncated
|
||||
}
|
||||
|
||||
// Add missing closing brackets
|
||||
for braceCount > 0 {
|
||||
str += "}"
|
||||
braceCount--
|
||||
}
|
||||
for bracketCount > 0 {
|
||||
str += "]"
|
||||
bracketCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Validate repaired JSON
|
||||
if err := json.Unmarshal([]byte(str), &testParse); err != nil {
|
||||
log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original")
|
||||
return originalStr
|
||||
}
|
||||
|
||||
log.Debugf("kiro: repairJSON - successfully repaired JSON")
|
||||
return str
|
||||
}
|
||||
|
||||
// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters
|
||||
// that appear inside JSON string values.
|
||||
func escapeNewlinesInStrings(raw string) string {
|
||||
var result strings.Builder
|
||||
result.Grow(len(raw) + 100)
|
||||
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(raw); i++ {
|
||||
c := raw[i]
|
||||
|
||||
if escaped {
|
||||
result.WriteByte(c)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\\' && inString {
|
||||
result.WriteByte(c)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
result.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
switch c {
|
||||
case '\n':
|
||||
result.WriteString("\\n")
|
||||
case '\r':
|
||||
result.WriteString("\\r")
|
||||
case '\t':
|
||||
result.WriteString("\\t")
|
||||
default:
|
||||
result.WriteByte(c)
|
||||
}
|
||||
} else {
|
||||
result.WriteByte(c)
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream.
|
||||
// It accumulates input fragments and emits tool_use blocks when complete.
|
||||
// Returns events to emit and updated state.
|
||||
func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) {
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
// Extract from nested toolUseEvent or direct format
|
||||
tu := event
|
||||
if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok {
|
||||
tu = nested
|
||||
}
|
||||
|
||||
toolUseID := kirocommon.GetString(tu, "toolUseId")
|
||||
toolName := kirocommon.GetString(tu, "name")
|
||||
isStop := false
|
||||
if stop, ok := tu["stop"].(bool); ok {
|
||||
isStop = stop
|
||||
}
|
||||
|
||||
// Get input - can be string (fragment) or object (complete)
|
||||
var inputFragment string
|
||||
var inputMap map[string]interface{}
|
||||
|
||||
if inputRaw, ok := tu["input"]; ok {
|
||||
switch v := inputRaw.(type) {
|
||||
case string:
|
||||
inputFragment = v
|
||||
case map[string]interface{}:
|
||||
inputMap = v
|
||||
}
|
||||
}
|
||||
|
||||
// New tool use starting
|
||||
if toolUseID != "" && toolName != "" {
|
||||
if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID {
|
||||
log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous",
|
||||
toolUseID, currentToolUse.ToolUseID)
|
||||
if !processedIDs[currentToolUse.ToolUseID] {
|
||||
incomplete := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
}
|
||||
if currentToolUse.InputBuffer.Len() > 0 {
|
||||
raw := currentToolUse.InputBuffer.String()
|
||||
repaired := RepairJSON(raw)
|
||||
|
||||
var input map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repaired), &input); err != nil {
|
||||
log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw)
|
||||
input = make(map[string]interface{})
|
||||
}
|
||||
incomplete.Input = input
|
||||
}
|
||||
toolUses = append(toolUses, incomplete)
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
currentToolUse = nil
|
||||
}
|
||||
|
||||
if currentToolUse == nil {
|
||||
if processedIDs != nil && processedIDs[toolUseID] {
|
||||
log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
currentToolUse = &ToolUseState{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
}
|
||||
log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID)
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate input fragments
|
||||
if currentToolUse != nil && inputFragment != "" {
|
||||
currentToolUse.InputBuffer.WriteString(inputFragment)
|
||||
log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len())
|
||||
}
|
||||
|
||||
// If complete input object provided directly
|
||||
if currentToolUse != nil && inputMap != nil {
|
||||
inputBytes, _ := json.Marshal(inputMap)
|
||||
currentToolUse.InputBuffer.Reset()
|
||||
currentToolUse.InputBuffer.Write(inputBytes)
|
||||
}
|
||||
|
||||
// Tool use complete
|
||||
if isStop && currentToolUse != nil {
|
||||
fullInput := currentToolUse.InputBuffer.String()
|
||||
|
||||
// Repair and parse the accumulated JSON
|
||||
repairedJSON := RepairJSON(fullInput)
|
||||
var finalInput map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
|
||||
log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
|
||||
finalInput = make(map[string]interface{})
|
||||
}
|
||||
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
Input: finalInput,
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
|
||||
if processedIDs != nil {
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
|
||||
log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID)
|
||||
return toolUses, nil
|
||||
}
|
||||
|
||||
return toolUses, currentToolUse
|
||||
}
|
||||
|
||||
// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content.
|
||||
func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
|
||||
seenIDs := make(map[string]bool)
|
||||
seenContent := make(map[string]bool)
|
||||
var unique []KiroToolUse
|
||||
|
||||
for _, tu := range toolUses {
|
||||
if seenIDs[tu.ToolUseID] {
|
||||
log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
inputJSON, _ := json.Marshal(tu.Input)
|
||||
contentKey := tu.Name + ":" + string(inputJSON)
|
||||
|
||||
if seenContent[contentKey] {
|
||||
log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID)
|
||||
continue
|
||||
}
|
||||
|
||||
seenIDs[tu.ToolUseID] = true
|
||||
seenContent[contentKey] = true
|
||||
unique = append(unique, tu)
|
||||
}
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
75
internal/translator/kiro/common/constants.go
Normal file
75
internal/translator/kiro/common/constants.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Package common provides shared constants and utilities for Kiro translator.
|
||||
package common
|
||||
|
||||
const (
|
||||
// KiroMaxToolDescLen is the maximum description length for Kiro API tools.
|
||||
// Kiro API limit is 10240 bytes, leave room for "..."
|
||||
KiroMaxToolDescLen = 10237
|
||||
|
||||
// ThinkingStartTag is the start tag for thinking blocks in responses.
|
||||
ThinkingStartTag = "<thinking>"
|
||||
|
||||
// ThinkingEndTag is the end tag for thinking blocks in responses.
|
||||
ThinkingEndTag = "</thinking>"
|
||||
|
||||
// CodeFenceMarker is the markdown code fence marker.
|
||||
CodeFenceMarker = "```"
|
||||
|
||||
// AltCodeFenceMarker is the alternative markdown code fence marker.
|
||||
AltCodeFenceMarker = "~~~"
|
||||
|
||||
// InlineCodeMarker is the markdown inline code marker (backtick).
|
||||
InlineCodeMarker = "`"
|
||||
|
||||
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
|
||||
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
|
||||
KiroAgenticSystemPrompt = `
|
||||
# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)
|
||||
|
||||
You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.
|
||||
|
||||
## ABSOLUTE LIMITS
|
||||
- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
|
||||
- **RECOMMENDED 300 LINES** or less for optimal performance
|
||||
- **NEVER** write entire files in one operation if >300 lines
|
||||
|
||||
## MANDATORY CHUNKED WRITE STRATEGY
|
||||
|
||||
### For NEW FILES (>300 lines total):
|
||||
1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
|
||||
2. THEN: Append remaining content in 250-300 line chunks using file append operations
|
||||
3. REPEAT: Continue appending until complete
|
||||
|
||||
### For EDITING EXISTING FILES:
|
||||
1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
|
||||
2. NEVER rewrite entire files - use incremental modifications
|
||||
3. Split large refactors into multiple small, focused edits
|
||||
|
||||
### For LARGE CODE GENERATION:
|
||||
1. Generate in logical sections (imports, types, functions separately)
|
||||
2. Write each section as a separate operation
|
||||
3. Use append operations for subsequent sections
|
||||
|
||||
## EXAMPLES OF CORRECT BEHAVIOR
|
||||
|
||||
✅ CORRECT: Writing a 600-line file
|
||||
- Operation 1: Write lines 1-300 (initial file creation)
|
||||
- Operation 2: Append lines 301-600
|
||||
|
||||
✅ CORRECT: Editing multiple functions
|
||||
- Operation 1: Edit function A
|
||||
- Operation 2: Edit function B
|
||||
- Operation 3: Edit function C
|
||||
|
||||
❌ WRONG: Writing 500 lines in single operation → TIMEOUT
|
||||
❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
|
||||
❌ WRONG: Generating massive code blocks without chunking → TIMEOUT
|
||||
|
||||
## WHY THIS MATTERS
|
||||
- Server has 2-3 minute timeout for operations
|
||||
- Large writes exceed timeout and FAIL completely
|
||||
- Chunked writes are FASTER and more RELIABLE
|
||||
- Failed writes waste time and require retry
|
||||
|
||||
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
|
||||
)
|
||||
125
internal/translator/kiro/common/message_merge.go
Normal file
125
internal/translator/kiro/common/message_merge.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Package common provides shared utilities for Kiro translators.
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MergeAdjacentMessages merges adjacent messages with the same role.
|
||||
// This reduces API call complexity and improves compatibility.
|
||||
// Based on AIClient-2-API implementation.
|
||||
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||
if len(messages) <= 1 {
|
||||
return messages
|
||||
}
|
||||
|
||||
var merged []gjson.Result
|
||||
for _, msg := range messages {
|
||||
if len(merged) == 0 {
|
||||
merged = append(merged, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
lastMsg := merged[len(merged)-1]
|
||||
currentRole := msg.Get("role").String()
|
||||
lastRole := lastMsg.Get("role").String()
|
||||
|
||||
if currentRole == lastRole {
|
||||
// Merge content from current message into last message
|
||||
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||
// Create a new merged message JSON
|
||||
mergedMsg := createMergedMessage(lastRole, mergedContent)
|
||||
merged[len(merged)-1] = gjson.Parse(mergedMsg)
|
||||
} else {
|
||||
merged = append(merged, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// mergeMessageContent merges the content of two messages with the same role.
|
||||
// Handles both string content and array content (with text, tool_use, tool_result blocks).
|
||||
func mergeMessageContent(msg1, msg2 gjson.Result) string {
|
||||
content1 := msg1.Get("content")
|
||||
content2 := msg2.Get("content")
|
||||
|
||||
// Extract content blocks from both messages
|
||||
var blocks1, blocks2 []map[string]interface{}
|
||||
|
||||
if content1.IsArray() {
|
||||
for _, block := range content1.Array() {
|
||||
blocks1 = append(blocks1, blockToMap(block))
|
||||
}
|
||||
} else if content1.Type == gjson.String {
|
||||
blocks1 = append(blocks1, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content1.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if content2.IsArray() {
|
||||
for _, block := range content2.Array() {
|
||||
blocks2 = append(blocks2, blockToMap(block))
|
||||
}
|
||||
} else if content2.Type == gjson.String {
|
||||
blocks2 = append(blocks2, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content2.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Merge text blocks if both end/start with text
|
||||
if len(blocks1) > 0 && len(blocks2) > 0 {
|
||||
if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
|
||||
// Merge the last text block of msg1 with the first text block of msg2
|
||||
text1 := blocks1[len(blocks1)-1]["text"].(string)
|
||||
text2 := blocks2[0]["text"].(string)
|
||||
blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2
|
||||
blocks2 = blocks2[1:] // Remove the merged block from blocks2
|
||||
}
|
||||
}
|
||||
|
||||
// Combine all blocks
|
||||
allBlocks := append(blocks1, blocks2...)
|
||||
|
||||
// Convert to JSON
|
||||
result, _ := json.Marshal(allBlocks)
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// blockToMap converts a gjson.Result block to a map[string]interface{}
|
||||
func blockToMap(block gjson.Result) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
block.ForEach(func(key, value gjson.Result) bool {
|
||||
if value.IsObject() {
|
||||
result[key.String()] = blockToMap(value)
|
||||
} else if value.IsArray() {
|
||||
var arr []interface{}
|
||||
for _, item := range value.Array() {
|
||||
if item.IsObject() {
|
||||
arr = append(arr, blockToMap(item))
|
||||
} else {
|
||||
arr = append(arr, item.Value())
|
||||
}
|
||||
}
|
||||
result[key.String()] = arr
|
||||
} else {
|
||||
result[key.String()] = value.Value()
|
||||
}
|
||||
return true
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// createMergedMessage creates a JSON string for a merged message
|
||||
func createMergedMessage(role string, content string) string {
|
||||
msg := map[string]interface{}{
|
||||
"role": role,
|
||||
"content": json.RawMessage(content),
|
||||
}
|
||||
result, _ := json.Marshal(msg)
|
||||
return string(result)
|
||||
}
|
||||
16
internal/translator/kiro/common/utils.go
Normal file
16
internal/translator/kiro/common/utils.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package common provides shared constants and utilities for Kiro translator.
|
||||
package common
|
||||
|
||||
// GetString safely extracts a string from a map.
|
||||
// Returns empty string if the key doesn't exist or the value is not a string.
|
||||
func GetString(m map[string]interface{}, key string) string {
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetStringValue is an alias for GetString for backward compatibility.
|
||||
func GetStringValue(m map[string]interface{}, key string) string {
|
||||
return GetString(m, key)
|
||||
}
|
||||
@@ -1,348 +0,0 @@
|
||||
// Package chat_completions provides request translation from OpenAI to Kiro format.
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// reasoningEffortToBudget maps OpenAI reasoning_effort values to Claude thinking budget_tokens.
|
||||
// OpenAI uses "low", "medium", "high" while Claude uses numeric budget_tokens.
|
||||
var reasoningEffortToBudget = map[string]int{
|
||||
"low": 4000,
|
||||
"medium": 16000,
|
||||
"high": 32000,
|
||||
}
|
||||
|
||||
// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format.
|
||||
// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format.
|
||||
// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result.
|
||||
// Supports reasoning/thinking: OpenAI reasoning_effort -> Claude thinking parameter.
|
||||
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Build Claude-compatible request
|
||||
out := `{"model":"","max_tokens":32000,"messages":[]}`
|
||||
|
||||
// Set model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
|
||||
// Copy max_tokens if present
|
||||
if v := root.Get("max_tokens"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", v.Int())
|
||||
}
|
||||
|
||||
// Copy temperature if present
|
||||
if v := root.Get("temperature"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", v.Float())
|
||||
}
|
||||
|
||||
// Copy top_p if present
|
||||
if v := root.Get("top_p"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", v.Float())
|
||||
}
|
||||
|
||||
// Handle OpenAI reasoning_effort parameter -> Claude thinking parameter
|
||||
// OpenAI format: {"reasoning_effort": "low"|"medium"|"high"}
|
||||
// Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}}
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
effort := v.String()
|
||||
if budget, ok := reasoningEffortToBudget[effort]; ok {
|
||||
thinking := map[string]interface{}{
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget,
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking", thinking)
|
||||
}
|
||||
}
|
||||
|
||||
// Also support direct thinking parameter passthrough (for Claude API compatibility)
|
||||
// Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}}
|
||||
if v := root.Get("thinking"); v.Exists() && v.IsObject() {
|
||||
out, _ = sjson.Set(out, "thinking", v.Value())
|
||||
}
|
||||
|
||||
// Convert OpenAI tools to Claude tools format
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
claudeTools := make([]interface{}, 0)
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() == "function" {
|
||||
fn := tool.Get("function")
|
||||
claudeTool := map[string]interface{}{
|
||||
"name": fn.Get("name").String(),
|
||||
"description": fn.Get("description").String(),
|
||||
}
|
||||
// Convert parameters to input_schema
|
||||
if params := fn.Get("parameters"); params.Exists() {
|
||||
claudeTool["input_schema"] = params.Value()
|
||||
} else {
|
||||
claudeTool["input_schema"] = map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
claudeTools = append(claudeTools, claudeTool)
|
||||
}
|
||||
}
|
||||
if len(claudeTools) > 0 {
|
||||
out, _ = sjson.Set(out, "tools", claudeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages
|
||||
messages := root.Get("messages")
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
claudeMessages := make([]interface{}, 0)
|
||||
var systemPrompt string
|
||||
|
||||
// Track pending tool results to merge with next user message
|
||||
var pendingToolResults []map[string]interface{}
|
||||
|
||||
for _, msg := range messages.Array() {
|
||||
role := msg.Get("role").String()
|
||||
content := msg.Get("content")
|
||||
|
||||
if role == "system" {
|
||||
// Extract system message
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
systemPrompt += part.Get("text").String() + "\n"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
systemPrompt = content.String()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if role == "tool" {
|
||||
// OpenAI tool message -> Claude tool_result content block
|
||||
toolCallID := msg.Get("tool_call_id").String()
|
||||
toolContent := content.String()
|
||||
|
||||
toolResult := map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolCallID,
|
||||
}
|
||||
|
||||
// Handle content - can be string or structured
|
||||
if content.IsArray() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
toolResult["content"] = contentParts
|
||||
} else {
|
||||
toolResult["content"] = toolContent
|
||||
}
|
||||
|
||||
pendingToolResults = append(pendingToolResults, toolResult)
|
||||
continue
|
||||
}
|
||||
|
||||
claudeMsg := map[string]interface{}{
|
||||
"role": role,
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool_calls
|
||||
if role == "assistant" && msg.Get("tool_calls").Exists() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
|
||||
// Add text content if present
|
||||
if content.Exists() && content.String() != "" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Convert tool_calls to tool_use blocks
|
||||
for _, toolCall := range msg.Get("tool_calls").Array() {
|
||||
toolUseID := toolCall.Get("id").String()
|
||||
fnName := toolCall.Get("function.name").String()
|
||||
fnArgs := toolCall.Get("function.arguments").String()
|
||||
|
||||
// Parse arguments JSON
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil {
|
||||
argsMap = map[string]interface{}{"raw": fnArgs}
|
||||
}
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": fnName,
|
||||
"input": argsMap,
|
||||
})
|
||||
}
|
||||
|
||||
claudeMsg["content"] = contentParts
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle user messages - may need to include pending tool results
|
||||
if role == "user" && len(pendingToolResults) > 0 {
|
||||
contentParts := make([]interface{}, 0)
|
||||
|
||||
// Add pending tool results first
|
||||
for _, tr := range pendingToolResults {
|
||||
contentParts = append(contentParts, tr)
|
||||
}
|
||||
pendingToolResults = nil
|
||||
|
||||
// Add user content
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
|
||||
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL format
|
||||
// Format: data:image/png;base64,xxxxx
|
||||
commaIdx := strings.Index(imageURL, ",")
|
||||
if commaIdx != -1 {
|
||||
// Extract media_type (e.g., "image/png")
|
||||
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||
mediaType := header
|
||||
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||
mediaType = header[:semiIdx]
|
||||
}
|
||||
|
||||
// Extract base64 data
|
||||
base64Data := imageURL[commaIdx+1:]
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": mediaType,
|
||||
"data": base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Regular URL format - keep original logic
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "url",
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if content.String() != "" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
claudeMsg["content"] = contentParts
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular content
|
||||
if content.IsArray() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
|
||||
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL format
|
||||
// Format: data:image/png;base64,xxxxx
|
||||
commaIdx := strings.Index(imageURL, ",")
|
||||
if commaIdx != -1 {
|
||||
// Extract media_type (e.g., "image/png")
|
||||
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||
mediaType := header
|
||||
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||
mediaType = header[:semiIdx]
|
||||
}
|
||||
|
||||
// Extract base64 data
|
||||
base64Data := imageURL[commaIdx+1:]
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": mediaType,
|
||||
"data": base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Regular URL format - keep original logic
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "url",
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
claudeMsg["content"] = contentParts
|
||||
} else {
|
||||
claudeMsg["content"] = content.String()
|
||||
}
|
||||
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
}
|
||||
|
||||
// If there are pending tool results without a following user message,
|
||||
// create a user message with just the tool results
|
||||
if len(pendingToolResults) > 0 {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, tr := range pendingToolResults {
|
||||
contentParts = append(contentParts, tr)
|
||||
}
|
||||
claudeMessages = append(claudeMessages, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": contentParts,
|
||||
})
|
||||
}
|
||||
|
||||
out, _ = sjson.Set(out, "messages", claudeMessages)
|
||||
|
||||
if systemPrompt != "" {
|
||||
out, _ = sjson.Set(out, "system", systemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Set stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
@@ -1,404 +0,0 @@
|
||||
// Package chat_completions provides response translation from Kiro to OpenAI format.
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format.
|
||||
// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta,
|
||||
// content_block_stop, message_delta, and message_stop.
|
||||
// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON.
|
||||
func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
raw := string(rawResponse)
|
||||
var results []string
|
||||
|
||||
// Handle SSE format: extract JSON from "data: " lines
|
||||
// Input format: "event: message_start\ndata: {...}"
|
||||
lines := strings.Split(raw, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
jsonPart := strings.TrimPrefix(line, "data: ")
|
||||
chunks := convertClaudeEventToOpenAI(jsonPart, model)
|
||||
results = append(results, chunks...)
|
||||
} else if strings.HasPrefix(line, "{") {
|
||||
// Raw JSON (backward compatibility)
|
||||
chunks := convertClaudeEventToOpenAI(line, model)
|
||||
results = append(results, chunks...)
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format
|
||||
func convertClaudeEventToOpenAI(jsonStr string, model string) []string {
|
||||
root := gjson.Parse(jsonStr)
|
||||
var results []string
|
||||
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// Initial message event - emit initial chunk with role
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
return results
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block (text or tool_use)
|
||||
blockType := root.Get("content_block.type").String()
|
||||
index := int(root.Get("index").Int())
|
||||
|
||||
if blockType == "tool_use" {
|
||||
// Start of tool_use block
|
||||
toolUseID := root.Get("content_block.id").String()
|
||||
toolName := root.Get("content_block.name").String()
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"index": index,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
return results
|
||||
|
||||
case "content_block_delta":
|
||||
index := int(root.Get("index").Int())
|
||||
deltaType := root.Get("delta.type").String()
|
||||
|
||||
if deltaType == "text_delta" {
|
||||
// Text content delta
|
||||
contentDelta := root.Get("delta.text").String()
|
||||
if contentDelta != "" {
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"content": contentDelta,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
} else if deltaType == "thinking_delta" {
|
||||
// Thinking/reasoning content delta - convert to OpenAI reasoning_content format
|
||||
thinkingDelta := root.Get("delta.thinking").String()
|
||||
if thinkingDelta != "" {
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": thinkingDelta,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
} else if deltaType == "input_json_delta" {
|
||||
// Tool input delta (streaming arguments)
|
||||
partialJSON := root.Get("delta.partial_json").String()
|
||||
if partialJSON != "" {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": index,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": partialJSON,
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
}
|
||||
return results
|
||||
|
||||
case "content_block_stop":
|
||||
// End of content block - no output needed for OpenAI format
|
||||
return results
|
||||
|
||||
case "message_delta":
|
||||
// Final message delta with stop_reason and usage
|
||||
stopReason := root.Get("delta.stop_reason").String()
|
||||
if stopReason != "" {
|
||||
finishReason := "stop"
|
||||
if stopReason == "tool_use" {
|
||||
finishReason = "tool_calls"
|
||||
} else if stopReason == "end_turn" {
|
||||
finishReason = "stop"
|
||||
} else if stopReason == "max_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract and include usage information from message_delta event
|
||||
usage := root.Get("usage")
|
||||
if usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
response["usage"] = map[string]interface{}{
|
||||
"prompt_tokens": inputTokens,
|
||||
"completion_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
return results
|
||||
|
||||
case "message_stop":
|
||||
// End of message - could emit [DONE] marker
|
||||
return results
|
||||
}
|
||||
|
||||
// Fallback: handle raw content for backward compatibility
|
||||
var contentDelta string
|
||||
if delta := root.Get("delta.text"); delta.Exists() {
|
||||
contentDelta = delta.String()
|
||||
} else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" {
|
||||
contentDelta = content.String()
|
||||
}
|
||||
|
||||
if contentDelta != "" {
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"content": contentDelta,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
|
||||
// Handle tool_use content blocks (Claude format) - fallback
|
||||
toolUses := root.Get("delta.tool_use")
|
||||
if !toolUses.Exists() {
|
||||
toolUses = root.Get("tool_use")
|
||||
}
|
||||
if toolUses.Exists() && toolUses.IsObject() {
|
||||
inputJSON := toolUses.Get("input").String()
|
||||
if inputJSON == "" {
|
||||
if inputObj := toolUses.Get("input"); inputObj.Exists() {
|
||||
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||
inputJSON = string(inputBytes)
|
||||
}
|
||||
}
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"index": 0,
|
||||
"id": toolUses.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolUses.Get("name").String(),
|
||||
"arguments": inputJSON,
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format.
|
||||
func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
root := gjson.ParseBytes(rawResponse)
|
||||
|
||||
var content string
|
||||
var reasoningContent string
|
||||
var toolCalls []map[string]interface{}
|
||||
|
||||
contentArray := root.Get("content")
|
||||
if contentArray.IsArray() {
|
||||
for _, item := range contentArray.Array() {
|
||||
itemType := item.Get("type").String()
|
||||
if itemType == "text" {
|
||||
content += item.Get("text").String()
|
||||
} else if itemType == "thinking" {
|
||||
// Extract thinking/reasoning content
|
||||
reasoningContent += item.Get("thinking").String()
|
||||
} else if itemType == "tool_use" {
|
||||
// Convert Claude tool_use to OpenAI tool_calls format
|
||||
inputJSON := item.Get("input").String()
|
||||
if inputJSON == "" {
|
||||
// If input is an object, marshal it
|
||||
if inputObj := item.Get("input"); inputObj.Exists() {
|
||||
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||
inputJSON = string(inputBytes)
|
||||
}
|
||||
}
|
||||
toolCall := map[string]interface{}{
|
||||
"id": item.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": item.Get("name").String(),
|
||||
"arguments": inputJSON,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content = root.Get("content").String()
|
||||
}
|
||||
|
||||
inputTokens := root.Get("usage.input_tokens").Int()
|
||||
outputTokens := root.Get("usage.output_tokens").Int()
|
||||
|
||||
message := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
// Add reasoning_content if present (OpenAI reasoning format)
|
||||
if reasoningContent != "" {
|
||||
message["reasoning_content"] = reasoningContent
|
||||
}
|
||||
|
||||
// Add tool_calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": inputTokens,
|
||||
"completion_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return string(result)
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
package chat_completions
|
||||
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||
package openai
|
||||
|
||||
import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
|
||||
func init() {
|
||||
translator.Register(
|
||||
OpenAI,
|
||||
Kiro,
|
||||
OpenAI, // source format
|
||||
Kiro, // target format
|
||||
ConvertOpenAIRequestToKiro,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertKiroResponseToOpenAI,
|
||||
NonStream: ConvertKiroResponseToOpenAINonStream,
|
||||
Stream: ConvertKiroStreamToOpenAI,
|
||||
NonStream: ConvertKiroNonStreamToOpenAI,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
371
internal/translator/kiro/openai/kiro_openai.go
Normal file
371
internal/translator/kiro/openai/kiro_openai.go
Normal file
@@ -0,0 +1,371 @@
|
||||
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||
// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer.
|
||||
//
|
||||
// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response
|
||||
// translation converts from Claude SSE format to OpenAI SSE format.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format.
|
||||
// The Kiro executor emits Claude-compatible SSE events, so this function translates
|
||||
// from Claude SSE format to OpenAI SSE format.
|
||||
//
|
||||
// Claude SSE format:
|
||||
// - event: message_start\ndata: {...}
|
||||
// - event: content_block_start\ndata: {...}
|
||||
// - event: content_block_delta\ndata: {...}
|
||||
// - event: content_block_stop\ndata: {...}
|
||||
// - event: message_delta\ndata: {...}
|
||||
// - event: message_stop\ndata: {...}
|
||||
//
|
||||
// OpenAI SSE format:
|
||||
// - data: {"id":"...","object":"chat.completion.chunk",...}
|
||||
// - data: [DONE]
|
||||
func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
// Initialize state if needed
|
||||
if *param == nil {
|
||||
*param = NewOpenAIStreamState(model)
|
||||
}
|
||||
state := (*param).(*OpenAIStreamState)
|
||||
|
||||
// Parse the Claude SSE event
|
||||
responseStr := string(rawResponse)
|
||||
|
||||
// Handle raw event format (event: xxx\ndata: {...})
|
||||
var eventType string
|
||||
var eventData string
|
||||
|
||||
if strings.HasPrefix(responseStr, "event:") {
|
||||
// Parse event type and data
|
||||
lines := strings.SplitN(responseStr, "\n", 2)
|
||||
if len(lines) >= 1 {
|
||||
eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
}
|
||||
if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") {
|
||||
eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
}
|
||||
} else if strings.HasPrefix(responseStr, "data:") {
|
||||
// Just data line
|
||||
eventData = strings.TrimSpace(strings.TrimPrefix(responseStr, "data:"))
|
||||
} else {
|
||||
// Try to parse as raw JSON
|
||||
eventData = strings.TrimSpace(responseStr)
|
||||
}
|
||||
|
||||
if eventData == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Parse the event data as JSON
|
||||
eventJSON := gjson.Parse(eventData)
|
||||
if !eventJSON.Exists() {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Determine event type from JSON if not already set
|
||||
if eventType == "" {
|
||||
eventType = eventJSON.Get("type").String()
|
||||
}
|
||||
|
||||
var results []string
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// Send first chunk with role
|
||||
firstChunk := BuildOpenAISSEFirstChunk(state)
|
||||
results = append(results, firstChunk)
|
||||
|
||||
case "content_block_start":
|
||||
// Check block type
|
||||
blockType := eventJSON.Get("content_block.type").String()
|
||||
switch blockType {
|
||||
case "text":
|
||||
// Text block starting - nothing to emit yet
|
||||
case "thinking":
|
||||
// Thinking block starting - nothing to emit yet for OpenAI
|
||||
case "tool_use":
|
||||
// Tool use block starting
|
||||
toolUseID := eventJSON.Get("content_block.id").String()
|
||||
toolName := eventJSON.Get("content_block.name").String()
|
||||
chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName)
|
||||
results = append(results, chunk)
|
||||
state.ToolCallIndex++
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
deltaType := eventJSON.Get("delta.type").String()
|
||||
switch deltaType {
|
||||
case "text_delta":
|
||||
textDelta := eventJSON.Get("delta.text").String()
|
||||
if textDelta != "" {
|
||||
chunk := BuildOpenAISSETextDelta(state, textDelta)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Convert thinking to reasoning_content for o1-style compatibility
|
||||
thinkingDelta := eventJSON.Get("delta.thinking").String()
|
||||
if thinkingDelta != "" {
|
||||
chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
case "input_json_delta":
|
||||
// Tool call arguments delta
|
||||
partialJSON := eventJSON.Get("delta.partial_json").String()
|
||||
if partialJSON != "" {
|
||||
// Get the tool index from content block index
|
||||
blockIndex := int(eventJSON.Get("index").Int())
|
||||
chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index
|
||||
results = append(results, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Content block ended - nothing to emit for OpenAI
|
||||
|
||||
case "message_delta":
|
||||
// Message delta with stop_reason
|
||||
stopReason := eventJSON.Get("delta.stop_reason").String()
|
||||
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||
if finishReason != "" {
|
||||
chunk := BuildOpenAISSEFinish(state, finishReason)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
|
||||
// Extract usage if present
|
||||
if eventJSON.Get("usage").Exists() {
|
||||
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: inputTokens + outputTokens,
|
||||
}
|
||||
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
// Final event - do NOT emit [DONE] here
|
||||
// The handler layer (openai_handlers.go) will send [DONE] when the stream closes
|
||||
// Emitting [DONE] here would cause duplicate [DONE] markers
|
||||
|
||||
case "ping":
|
||||
// Ping event with usage - optionally emit usage chunk
|
||||
if eventJSON.Get("usage").Exists() {
|
||||
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: inputTokens + outputTokens,
|
||||
}
|
||||
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format.
|
||||
// The Kiro executor returns Claude-compatible JSON responses, so this function translates
|
||||
// from Claude format to OpenAI format.
|
||||
func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
// Parse the Claude-format response
|
||||
response := gjson.ParseBytes(rawResponse)
|
||||
|
||||
// Extract content
|
||||
var content string
|
||||
var reasoningContent string
|
||||
var toolUses []KiroToolUse
|
||||
var stopReason string
|
||||
|
||||
// Get stop_reason
|
||||
stopReason = response.Get("stop_reason").String()
|
||||
|
||||
// Process content blocks
|
||||
contentBlocks := response.Get("content")
|
||||
if contentBlocks.IsArray() {
|
||||
for _, block := range contentBlocks.Array() {
|
||||
blockType := block.Get("type").String()
|
||||
switch blockType {
|
||||
case "text":
|
||||
content += block.Get("text").String()
|
||||
case "thinking":
|
||||
// Convert thinking blocks to reasoning_content for OpenAI format
|
||||
reasoningContent += block.Get("thinking").String()
|
||||
case "tool_use":
|
||||
toolUseID := block.Get("id").String()
|
||||
toolName := block.Get("name").String()
|
||||
toolInput := block.Get("input")
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if toolInput.IsObject() {
|
||||
inputMap = make(map[string]interface{})
|
||||
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||
inputMap[key.String()] = value.Value()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: response.Get("usage.input_tokens").Int(),
|
||||
OutputTokens: response.Get("usage.output_tokens").Int(),
|
||||
}
|
||||
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
|
||||
|
||||
// Build OpenAI response with reasoning_content support
|
||||
openaiResponse := BuildOpenAIResponseWithReasoning(content, reasoningContent, toolUses, model, usageInfo, stopReason)
|
||||
return string(openaiResponse)
|
||||
}
|
||||
|
||||
// ParseClaudeEvent parses a Claude SSE event and returns the event type and data
|
||||
func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) {
|
||||
lines := bytes.Split(rawEvent, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if bytes.HasPrefix(line, []byte("event:")) {
|
||||
eventType = string(bytes.TrimSpace(bytes.TrimPrefix(line, []byte("event:"))))
|
||||
} else if bytes.HasPrefix(line, []byte("data:")) {
|
||||
eventData = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||
}
|
||||
}
|
||||
return eventType, eventData
|
||||
}
|
||||
|
||||
// ExtractThinkingFromContent parses content to extract thinking blocks.
|
||||
// Returns cleaned content (without thinking tags) and whether thinking was found.
|
||||
func ExtractThinkingFromContent(content string) (string, string, bool) {
|
||||
if !strings.Contains(content, kirocommon.ThinkingStartTag) {
|
||||
return content, "", false
|
||||
}
|
||||
|
||||
var cleanedContent strings.Builder
|
||||
var thinkingContent strings.Builder
|
||||
hasThinking := false
|
||||
remaining := content
|
||||
|
||||
for len(remaining) > 0 {
|
||||
startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag)
|
||||
if startIdx == -1 {
|
||||
cleanedContent.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
|
||||
// Add content before thinking tag
|
||||
cleanedContent.WriteString(remaining[:startIdx])
|
||||
|
||||
// Move past opening tag
|
||||
remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):]
|
||||
|
||||
// Find closing tag
|
||||
endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag)
|
||||
if endIdx == -1 {
|
||||
// No closing tag - treat rest as thinking
|
||||
thinkingContent.WriteString(remaining)
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
|
||||
// Extract thinking content
|
||||
thinkingContent.WriteString(remaining[:endIdx])
|
||||
hasThinking = true
|
||||
remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):]
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking
|
||||
}
|
||||
|
||||
// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format
|
||||
func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
|
||||
for _, tool := range tools {
|
||||
toolType, _ := tool["type"].(string)
|
||||
if toolType != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
fn, ok := tool["function"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name := kirocommon.GetString(fn, "name")
|
||||
description := kirocommon.GetString(fn, "description")
|
||||
parameters := fn["parameters"]
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if description == "" {
|
||||
description = "Tool: " + name
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: parameters},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// OpenAIStreamParams holds parameters for OpenAI streaming conversion
|
||||
type OpenAIStreamParams struct {
|
||||
State *OpenAIStreamState
|
||||
ThinkingState *ThinkingTagState
|
||||
ToolCallsEmitted map[string]bool
|
||||
}
|
||||
|
||||
// NewOpenAIStreamParams creates new streaming parameters
|
||||
func NewOpenAIStreamParams(model string) *OpenAIStreamParams {
|
||||
return &OpenAIStreamParams{
|
||||
State: NewOpenAIStreamState(model),
|
||||
ThinkingState: NewThinkingTagState(),
|
||||
ToolCallsEmitted: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block to OpenAI tool_calls format
|
||||
func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} {
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
return map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": string(inputJSON),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LogStreamEvent logs a streaming event for debugging
|
||||
func LogStreamEvent(eventType, data string) {
|
||||
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
||||
}
|
||||
850
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
850
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
@@ -0,0 +1,850 @@
|
||||
// Package openai provides request translation from OpenAI Chat Completions to Kiro format.
|
||||
// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Kiro API request structs - reuse from kiroclaude package structure
|
||||
|
||||
// KiroPayload is the top-level request structure for Kiro API
|
||||
type KiroPayload struct {
|
||||
ConversationState KiroConversationState `json:"conversationState"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
}
|
||||
|
||||
// KiroInferenceConfig contains inference parameters for the Kiro API.
|
||||
type KiroInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// KiroConversationState holds the conversation context
|
||||
type KiroConversationState struct {
|
||||
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
|
||||
ConversationID string `json:"conversationId"`
|
||||
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
|
||||
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||
}
|
||||
|
||||
// KiroCurrentMessage wraps the current user message
|
||||
type KiroCurrentMessage struct {
|
||||
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||
}
|
||||
|
||||
// KiroHistoryMessage represents a message in the conversation history
|
||||
type KiroHistoryMessage struct {
|
||||
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||
}
|
||||
|
||||
// KiroImage represents an image in Kiro API format
|
||||
type KiroImage struct {
|
||||
Format string `json:"format"`
|
||||
Source KiroImageSource `json:"source"`
|
||||
}
|
||||
|
||||
// KiroImageSource contains the image data
|
||||
type KiroImageSource struct {
|
||||
Bytes string `json:"bytes"` // base64 encoded image data
|
||||
}
|
||||
|
||||
// KiroUserInputMessage represents a user message
|
||||
type KiroUserInputMessage struct {
|
||||
Content string `json:"content"`
|
||||
ModelID string `json:"modelId"`
|
||||
Origin string `json:"origin"`
|
||||
Images []KiroImage `json:"images,omitempty"`
|
||||
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||
}
|
||||
|
||||
// KiroUserInputMessageContext contains tool-related context
|
||||
type KiroUserInputMessageContext struct {
|
||||
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolResult represents a tool execution result
|
||||
type KiroToolResult struct {
|
||||
Content []KiroTextContent `json:"content"`
|
||||
Status string `json:"status"`
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
}
|
||||
|
||||
// KiroTextContent represents text content
|
||||
type KiroTextContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// KiroToolWrapper wraps a tool specification
|
||||
type KiroToolWrapper struct {
|
||||
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
|
||||
}
|
||||
|
||||
// KiroToolSpecification defines a tool's schema
|
||||
type KiroToolSpecification struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema KiroInputSchema `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// KiroInputSchema wraps the JSON schema for tool input
|
||||
type KiroInputSchema struct {
|
||||
JSON interface{} `json:"json"`
|
||||
}
|
||||
|
||||
// KiroAssistantResponseMessage represents an assistant message
|
||||
type KiroAssistantResponseMessage struct {
|
||||
Content string `json:"content"`
|
||||
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolUse represents a tool invocation by the assistant
|
||||
type KiroToolUse struct {
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
// ConvertOpenAIRequestToKiro converts an OpenAI Chat Completions request to Kiro format.
|
||||
// This is the main entry point for request translation.
|
||||
// Note: The actual payload building happens in the executor, this just passes through
|
||||
// the OpenAI format which will be converted by BuildKiroPayloadFromOpenAI.
|
||||
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
// Pass through the OpenAI format - actual conversion happens in BuildKiroPayloadFromOpenAI
|
||||
return inputRawJSON
|
||||
}
|
||||
|
||||
// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format.
|
||||
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
|
||||
// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
var maxTokens int64
|
||||
if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() {
|
||||
maxTokens = mt.Int()
|
||||
if maxTokens == -1 {
|
||||
maxTokens = kiroMaxOutputTokens
|
||||
log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract temperature if specified
|
||||
var temperature float64
|
||||
var hasTemperature bool
|
||||
if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() {
|
||||
temperature = temp.Float()
|
||||
hasTemperature = true
|
||||
}
|
||||
|
||||
// Extract top_p if specified
|
||||
var topP float64
|
||||
var hasTopP bool
|
||||
if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() {
|
||||
topP = tp.Float()
|
||||
hasTopP = true
|
||||
log.Debugf("kiro-openai: extracted top_p: %.2f", topP)
|
||||
}
|
||||
|
||||
// Normalize origin value for Kiro API compatibility
|
||||
origin = normalizeOrigin(origin)
|
||||
log.Debugf("kiro-openai: normalized origin value: %s", origin)
|
||||
|
||||
messages := gjson.GetBytes(openaiBody, "messages")
|
||||
|
||||
// For chat-only mode, don't include tools
|
||||
var tools gjson.Result
|
||||
if !isChatOnly {
|
||||
tools = gjson.GetBytes(openaiBody, "tools")
|
||||
}
|
||||
|
||||
// Extract system prompt from messages
|
||||
systemPrompt := extractSystemPromptFromOpenAI(messages)
|
||||
|
||||
// Inject timestamp context
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = timestampContext
|
||||
}
|
||||
log.Debugf("kiro-openai: injected timestamp context: %s", timestamp)
|
||||
|
||||
// Inject agentic optimization prompt for -agentic model variants
|
||||
if isAgentic {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||
}
|
||||
|
||||
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}}
|
||||
toolChoiceHint := extractToolChoiceHint(openaiBody)
|
||||
if toolChoiceHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += toolChoiceHint
|
||||
log.Debugf("kiro-openai: injected tool_choice hint into system prompt")
|
||||
}
|
||||
|
||||
// Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}}
|
||||
responseFormatHint := extractResponseFormatHint(openaiBody)
|
||||
if responseFormatHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += responseFormatHint
|
||||
log.Debugf("kiro-openai: injected response_format hint into system prompt")
|
||||
}
|
||||
|
||||
// Check for thinking mode
|
||||
// Supports OpenAI reasoning_effort parameter, model name hints, and Anthropic-Beta header
|
||||
thinkingEnabled := checkThinkingModeFromOpenAIWithHeaders(openaiBody, headers)
|
||||
|
||||
// Convert OpenAI tools to Kiro format
|
||||
kiroTools := convertOpenAIToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Debugf("kiro-openai: injected thinking prompt")
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin)
|
||||
|
||||
// Build content with system prompt
|
||||
if currentUserMsg != nil {
|
||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||
|
||||
// Deduplicate currentToolResults
|
||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||
|
||||
// Build userInputMessageContext with tools and tool results
|
||||
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
Tools: kiroTools,
|
||||
ToolResults: currentToolResults,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build payload
|
||||
var currentMessage KiroCurrentMessage
|
||||
if currentUserMsg != nil {
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||
} else {
|
||||
fallbackContent := ""
|
||||
if systemPrompt != "" {
|
||||
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||
}
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||
Content: fallbackContent,
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}}
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
// Note: Kiro API doesn't actually use max_tokens for thinking budget
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
if maxTokens > 0 {
|
||||
inferenceConfig.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if hasTemperature {
|
||||
inferenceConfig.Temperature = temperature
|
||||
}
|
||||
if hasTopP {
|
||||
inferenceConfig.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
payload := KiroPayload{
|
||||
ConversationState: KiroConversationState{
|
||||
ChatTriggerType: "MANUAL",
|
||||
ConversationID: uuid.New().String(),
|
||||
CurrentMessage: currentMessage,
|
||||
History: history,
|
||||
},
|
||||
ProfileArn: profileArn,
|
||||
InferenceConfig: inferenceConfig,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Debugf("kiro-openai: failed to marshal payload: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return result, thinkingEnabled
|
||||
}
|
||||
|
||||
// normalizeOrigin normalizes origin value for Kiro API compatibility
|
||||
func normalizeOrigin(origin string) string {
|
||||
switch origin {
|
||||
case "KIRO_CLI":
|
||||
return "CLI"
|
||||
case "KIRO_AI_EDITOR":
|
||||
return "AI_EDITOR"
|
||||
case "AMAZON_Q":
|
||||
return "CLI"
|
||||
case "KIRO_IDE":
|
||||
return "AI_EDITOR"
|
||||
default:
|
||||
return origin
|
||||
}
|
||||
}
|
||||
|
||||
// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages
|
||||
func extractSystemPromptFromOpenAI(messages gjson.Result) string {
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "system" {
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
systemParts = append(systemParts, content.String())
|
||||
} else if content.IsArray() {
|
||||
// Handle array content format
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
systemParts = append(systemParts, part.Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(systemParts, "\n")
|
||||
}
|
||||
|
||||
// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
|
||||
// MCP tools often have long names like "mcp__server-name__tool-name".
|
||||
// This preserves the "mcp__" prefix and last segment when possible.
|
||||
func shortenToolNameIfNeeded(name string) string {
|
||||
const limit = 64
|
||||
if len(name) <= limit {
|
||||
return name
|
||||
}
|
||||
// For MCP tools, try to preserve prefix and last segment
|
||||
if strings.HasPrefix(name, "mcp__") {
|
||||
idx := strings.LastIndex(name, "__")
|
||||
if idx > 0 {
|
||||
cand := "mcp__" + name[idx+2:]
|
||||
if len(cand) > limit {
|
||||
return cand[:limit]
|
||||
}
|
||||
return cand
|
||||
}
|
||||
}
|
||||
return name[:limit]
|
||||
}
|
||||
|
||||
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
||||
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
if !tools.IsArray() {
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
// OpenAI tools have type "function" with function definition inside
|
||||
if tool.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
fn := tool.Get("function")
|
||||
if !fn.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := fn.Get("name").String()
|
||||
description := fn.Get("description").String()
|
||||
parameters := fn.Get("parameters").Value()
|
||||
|
||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||
originalName := name
|
||||
name = shortenToolNameIfNeeded(name)
|
||||
if name != originalName {
|
||||
log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name)
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty description
|
||||
if strings.TrimSpace(description) == "" {
|
||||
description = fmt.Sprintf("Tool: %s", name)
|
||||
log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Truncate long descriptions
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
description = description[:truncLen] + "... (description truncated)"
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: parameters},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processOpenAIMessages processes OpenAI messages and builds Kiro history
|
||||
func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||
var history []KiroHistoryMessage
|
||||
var currentUserMsg *KiroUserInputMessage
|
||||
var currentToolResults []KiroToolResult
|
||||
|
||||
if !messages.IsArray() {
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// Merge adjacent messages with the same role
|
||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||
|
||||
// Build tool_call_id to name mapping from assistant messages
|
||||
toolCallIDToName := make(map[string]string)
|
||||
for _, msg := range messagesArray {
|
||||
if msg.Get("role").String() == "assistant" {
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if toolCalls.IsArray() {
|
||||
for _, tc := range toolCalls.Array() {
|
||||
if tc.Get("type").String() == "function" {
|
||||
id := tc.Get("id").String()
|
||||
name := tc.Get("function.name").String()
|
||||
if id != "" && name != "" {
|
||||
toolCallIDToName[id] = name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, msg := range messagesArray {
|
||||
role := msg.Get("role").String()
|
||||
isLastMessage := i == len(messagesArray)-1
|
||||
|
||||
switch role {
|
||||
case "system":
|
||||
// System messages are handled separately via extractSystemPromptFromOpenAI
|
||||
continue
|
||||
|
||||
case "user":
|
||||
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
||||
if isLastMessage {
|
||||
currentUserMsg = &userMsg
|
||||
currentToolResults = toolResults
|
||||
} else {
|
||||
// CRITICAL: Kiro API requires content to be non-empty for history messages
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = "Tool results provided."
|
||||
} else {
|
||||
userMsg.Content = "Continue"
|
||||
}
|
||||
}
|
||||
// For history messages, embed tool results in context
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
}
|
||||
history = append(history, KiroHistoryMessage{
|
||||
UserInputMessage: &userMsg,
|
||||
})
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
||||
if isLastMessage {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
// Create a "Continue" user message as currentMessage
|
||||
currentUserMsg = &KiroUserInputMessage{
|
||||
Content: "Continue",
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
} else {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
}
|
||||
|
||||
case "tool":
|
||||
// Tool messages in OpenAI format provide results for tool_calls
|
||||
// These are typically followed by user or assistant messages
|
||||
// Process them and merge into the next user message's tool results
|
||||
toolCallID := msg.Get("tool_call_id").String()
|
||||
content := msg.Get("content").String()
|
||||
|
||||
if toolCallID != "" {
|
||||
toolResult := KiroToolResult{
|
||||
ToolUseID: toolCallID,
|
||||
Content: []KiroTextContent{{Text: content}},
|
||||
Status: "success",
|
||||
}
|
||||
// Tool results should be included in the next user message
|
||||
// For now, collect them and they'll be handled when we build the current message
|
||||
currentToolResults = append(currentToolResults, toolResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolResults []KiroToolResult
|
||||
var images []KiroImage
|
||||
|
||||
// Track seen toolCallIds to deduplicate
|
||||
seenToolCallIDs := make(map[string]bool)
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "image_url":
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL: data:image/png;base64,xxxxx
|
||||
if idx := strings.Index(imageURL, ";base64,"); idx != -1 {
|
||||
mediaType := imageURL[5:idx] // Skip "data:"
|
||||
data := imageURL[idx+8:] // Skip ";base64,"
|
||||
|
||||
format := ""
|
||||
if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 {
|
||||
format = mediaType[lastSlash+1:]
|
||||
}
|
||||
|
||||
if format != "" && data != "" {
|
||||
images = append(images, KiroImage{
|
||||
Format: format,
|
||||
Source: KiroImageSource{
|
||||
Bytes: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if content.Type == gjson.String {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
|
||||
_ = seenToolCallIDs // Used for deduplication if needed
|
||||
|
||||
userMsg := KiroUserInputMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
userMsg.Images = images
|
||||
}
|
||||
|
||||
return userMsg, toolResults
|
||||
}
|
||||
|
||||
// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format
|
||||
func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
// Handle content
|
||||
if content.Type == gjson.String {
|
||||
contentBuilder.WriteString(content.String())
|
||||
} else if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool_calls
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if toolCalls.IsArray() {
|
||||
for _, tc := range toolCalls.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolUseID := tc.Get("id").String()
|
||||
toolName := tc.Get("function.name").String()
|
||||
toolArgs := tc.Get("function.arguments").String()
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil {
|
||||
log.Debugf("kiro-openai: failed to parse tool arguments: %v", err)
|
||||
inputMap = make(map[string]interface{})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
|
||||
// buildFinalContent builds the final content with system prompt
|
||||
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
if systemPrompt != "" {
|
||||
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||
contentBuilder.WriteString(systemPrompt)
|
||||
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||
}
|
||||
|
||||
contentBuilder.WriteString(content)
|
||||
finalContent := contentBuilder.String()
|
||||
|
||||
// CRITICAL: Kiro API requires content to be non-empty
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
finalContent = "Tool results provided."
|
||||
} else {
|
||||
finalContent = "Continue"
|
||||
}
|
||||
log.Debugf("kiro-openai: content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return finalContent
|
||||
}
|
||||
|
||||
// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request.
|
||||
// Returns thinkingEnabled.
|
||||
// Supports:
|
||||
// - reasoning_effort parameter (low/medium/high/auto)
|
||||
// - Model name containing "thinking" or "reason"
|
||||
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||
func checkThinkingModeFromOpenAI(openaiBody []byte) bool {
|
||||
return checkThinkingModeFromOpenAIWithHeaders(openaiBody, nil)
|
||||
}
|
||||
|
||||
// checkThinkingModeFromOpenAIWithHeaders checks if thinking mode is enabled in the OpenAI request.
|
||||
// Returns thinkingEnabled.
|
||||
// Supports:
|
||||
// - Anthropic-Beta header with interleaved-thinking (Claude CLI)
|
||||
// - reasoning_effort parameter (low/medium/high/auto)
|
||||
// - Model name containing "thinking" or "reason"
|
||||
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||
func checkThinkingModeFromOpenAIWithHeaders(openaiBody []byte, headers http.Header) bool {
|
||||
// Check Anthropic-Beta header first (Claude CLI uses this)
|
||||
if kiroclaude.IsThinkingEnabledFromHeader(headers) {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via Anthropic-Beta header")
|
||||
return true
|
||||
}
|
||||
|
||||
// Check OpenAI format: reasoning_effort parameter
|
||||
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||
reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort")
|
||||
if reasoningEffort.Exists() {
|
||||
effort := reasoningEffort.String()
|
||||
if effort != "" && effort != "none" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
bodyStr := string(openaiBody)
|
||||
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||
startTag := "<thinking_mode>"
|
||||
endTag := "</thinking_mode>"
|
||||
startIdx := strings.Index(bodyStr, startTag)
|
||||
if startIdx >= 0 {
|
||||
startIdx += len(startTag)
|
||||
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||
if endIdx >= 0 {
|
||||
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check model name for thinking hints
|
||||
model := gjson.GetBytes(openaiBody, "model").String()
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kiro-openai: no thinking mode detected in OpenAI request")
|
||||
return false
|
||||
}
|
||||
|
||||
// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
|
||||
// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
|
||||
func hasThinkingTagInBody(body []byte) bool {
|
||||
bodyStr := string(body)
|
||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||
}
|
||||
|
||||
|
||||
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||
// OpenAI tool_choice values:
|
||||
// - "none": Don't use any tools
|
||||
// - "auto": Model decides (default, no hint needed)
|
||||
// - "required": Must use at least one tool
|
||||
// - {"type":"function","function":{"name":"..."}} : Must use specific tool
|
||||
func extractToolChoiceHint(openaiBody []byte) string {
|
||||
toolChoice := gjson.GetBytes(openaiBody, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle string values
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "none":
|
||||
// Note: When tool_choice is "none", we should ideally not pass tools at all
|
||||
// But since we can't modify tool passing here, we add a strong hint
|
||||
return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]"
|
||||
case "required":
|
||||
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||
case "auto":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Handle object value: {"type":"function","function":{"name":"..."}}
|
||||
if toolChoice.IsObject() {
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
toolName := toolChoice.Get("function.name").String()
|
||||
if toolName != "" {
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint.
|
||||
// OpenAI response_format values:
|
||||
// - {"type": "text"}: Default, no hint needed
|
||||
// - {"type": "json_object"}: Must respond with valid JSON
|
||||
// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema
|
||||
func extractResponseFormatHint(openaiBody []byte) string {
|
||||
responseFormat := gjson.GetBytes(openaiBody, "response_format")
|
||||
if !responseFormat.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
formatType := responseFormat.Get("type").String()
|
||||
switch formatType {
|
||||
case "json_object":
|
||||
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||
case "json_schema":
|
||||
// Extract schema if provided
|
||||
schema := responseFormat.Get("json_schema.schema")
|
||||
if schema.Exists() {
|
||||
schemaStr := schema.Raw
|
||||
// Truncate if too long
|
||||
if len(schemaStr) > 500 {
|
||||
schemaStr = schemaStr[:500] + "..."
|
||||
}
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr)
|
||||
}
|
||||
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||
case "text":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// deduplicateToolResults removes duplicate tool results
|
||||
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||
if len(toolResults) == 0 {
|
||||
return toolResults
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if !seenIDs[tr.ToolUseID] {
|
||||
seenIDs[tr.ToolUseID] = true
|
||||
unique = append(unique, tr)
|
||||
} else {
|
||||
log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
277
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
277
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
@@ -0,0 +1,277 @@
|
||||
// Package openai provides response translation from Kiro to OpenAI format.
|
||||
// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible
|
||||
// JSON format, transforming streaming events and non-streaming responses.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
var functionCallIDCounter uint64
|
||||
|
||||
// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response.
|
||||
// Supports tool_calls when tools are present in the response.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
return BuildOpenAIResponseWithReasoning(content, "", toolUses, model, usageInfo, stopReason)
|
||||
}
|
||||
|
||||
// BuildOpenAIResponseWithReasoning constructs an OpenAI Chat Completions-compatible response with reasoning_content support.
|
||||
// Supports tool_calls when tools are present in the response.
|
||||
// reasoningContent is included as reasoning_content field in the message when present.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildOpenAIResponseWithReasoning(content, reasoningContent string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
// Build the message object
|
||||
message := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
// Add reasoning_content if present (for thinking/reasoning models)
|
||||
if reasoningContent != "" {
|
||||
message["reasoning_content"] = reasoningContent
|
||||
}
|
||||
|
||||
// Add tool_calls if present
|
||||
if len(toolUses) > 0 {
|
||||
var toolCalls []map[string]interface{}
|
||||
for i, tu := range toolUses {
|
||||
inputJSON, _ := json.Marshal(tu.Input)
|
||||
toolCalls = append(toolCalls, map[string]interface{}{
|
||||
"id": tu.ToolUseID,
|
||||
"type": "function",
|
||||
"index": i,
|
||||
"function": map[string]interface{}{
|
||||
"name": tu.Name,
|
||||
"arguments": string(inputJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
message["tool_calls"] = toolCalls
|
||||
// When tool_calls are present, content should be null according to OpenAI spec
|
||||
if content == "" {
|
||||
message["content"] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
if len(toolUses) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason)
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return result
|
||||
}
|
||||
|
||||
// mapKiroStopReasonToOpenAI converts Kiro/Claude stop_reason to OpenAI finish_reason
|
||||
func mapKiroStopReasonToOpenAI(stopReason string) string {
|
||||
switch stopReason {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
case "content_filtered":
|
||||
return "content_filter"
|
||||
default:
|
||||
return stopReason
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk.
|
||||
// This is the delta format used in streaming responses.
|
||||
func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte {
|
||||
delta := map[string]interface{}{}
|
||||
|
||||
// First chunk should include role
|
||||
if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 {
|
||||
delta["role"] = "assistant"
|
||||
delta["content"] = ""
|
||||
} else if deltaContent != "" {
|
||||
delta["content"] = deltaContent
|
||||
}
|
||||
|
||||
// Add tool_calls delta if present
|
||||
if len(deltaToolCalls) > 0 {
|
||||
delta["tool_calls"] = deltaToolCalls
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
if finishReason != "" {
|
||||
choice["finish_reason"] = finishReason
|
||||
} else {
|
||||
choice["finish_reason"] = nil
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start
|
||||
func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": nil,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta
|
||||
func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": argumentsDelta,
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": nil,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamDoneChunk creates the final [DONE] stream event
|
||||
func BuildOpenAIStreamDoneChunk() []byte {
|
||||
return []byte("data: [DONE]")
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason
|
||||
func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte {
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": finishReason,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage)
|
||||
func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte {
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// GenerateToolCallID generates a unique tool call ID in OpenAI format
|
||||
func GenerateToolCallID(toolName string) string {
|
||||
return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))
|
||||
}
|
||||
|
||||
// min returns the minimum of two integers
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
@@ -0,0 +1,212 @@
|
||||
// Package openai provides streaming SSE event building for OpenAI format.
|
||||
// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE)
|
||||
// for streaming responses from Kiro API.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
// OpenAIStreamState tracks the state of streaming response conversion
|
||||
type OpenAIStreamState struct {
|
||||
ChunkIndex int
|
||||
ToolCallIndex int
|
||||
HasSentFirstChunk bool
|
||||
Model string
|
||||
ResponseID string
|
||||
Created int64
|
||||
}
|
||||
|
||||
// NewOpenAIStreamState creates a new stream state for tracking
|
||||
func NewOpenAIStreamState(model string) *OpenAIStreamState {
|
||||
return &OpenAIStreamState{
|
||||
ChunkIndex: 0,
|
||||
ToolCallIndex: 0,
|
||||
HasSentFirstChunk: false,
|
||||
Model: model,
|
||||
ResponseID: "chatcmpl-" + uuid.New().String()[:24],
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSSEEvent formats a JSON payload for SSE streaming.
|
||||
// Note: This returns raw JSON data without "data:" prefix.
|
||||
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
|
||||
// to maintain architectural consistency and avoid double-prefix issues.
|
||||
func FormatSSEEvent(data []byte) string {
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// BuildOpenAISSETextDelta creates an SSE event for text content delta
|
||||
func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string {
|
||||
delta := map[string]interface{}{
|
||||
"content": textDelta,
|
||||
}
|
||||
|
||||
// Include role in first chunk
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEToolCallStart creates an SSE event for tool call start
|
||||
func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": state.ToolCallIndex,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
// Include role in first chunk if not sent yet
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta
|
||||
func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": argumentsDelta,
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEFinish creates an SSE event with finish_reason
|
||||
func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string {
|
||||
chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEUsage creates an SSE event with usage information
|
||||
func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string {
|
||||
chunk := map[string]interface{}{
|
||||
"id": state.ResponseID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": state.Created,
|
||||
"model": state.Model,
|
||||
"choices": []map[string]interface{}{},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(chunk)
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEDone creates the final [DONE] SSE event.
|
||||
// Note: This returns raw "[DONE]" without "data:" prefix.
|
||||
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
|
||||
// to maintain architectural consistency and avoid double-prefix issues.
|
||||
func BuildOpenAISSEDone() string {
|
||||
return "[DONE]"
|
||||
}
|
||||
|
||||
// buildBaseChunk creates a base chunk structure for streaming
|
||||
func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} {
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
if finishReason != nil {
|
||||
choice["finish_reason"] = *finishReason
|
||||
} else {
|
||||
choice["finish_reason"] = nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"id": state.ResponseID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": state.Created,
|
||||
"model": state.Model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta
|
||||
// This is used for o1/o3 style models that expose reasoning tokens
|
||||
func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string {
|
||||
delta := map[string]interface{}{
|
||||
"reasoning_content": reasoningDelta,
|
||||
}
|
||||
|
||||
// Include role in first chunk
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEFirstChunk creates the first chunk with role only
|
||||
func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string {
|
||||
delta := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
state.HasSentFirstChunk = true
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// ThinkingTagState tracks state for thinking tag detection in streaming
|
||||
type ThinkingTagState struct {
|
||||
InThinkingBlock bool
|
||||
PendingStartChars int
|
||||
PendingEndChars int
|
||||
}
|
||||
|
||||
// NewThinkingTagState creates a new thinking tag state
|
||||
func NewThinkingTagState() *ThinkingTagState {
|
||||
return &ThinkingTagState{
|
||||
InThinkingBlock: false,
|
||||
PendingStartChars: 0,
|
||||
PendingEndChars: 0,
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -60,6 +61,30 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
if thinking := root.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else {
|
||||
// No budget_tokens specified, default to "auto" for enabled thinking
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, -1); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages and system
|
||||
var messagesJSON = "[]"
|
||||
|
||||
|
||||
@@ -128,9 +128,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
param.CreatedAt = root.Get("created").Int()
|
||||
}
|
||||
|
||||
// Check if this is the first chunk (has role)
|
||||
// Emit message_start on the very first chunk, regardless of whether it has a role field.
|
||||
// Some providers (like Copilot) may send tool_calls in the first chunk without a role field.
|
||||
if delta := root.Get("choices.0.delta"); delta.Exists() {
|
||||
if role := delta.Get("role"); role.Exists() && role.String() == "assistant" && !param.MessageStarted {
|
||||
if !param.MessageStarted {
|
||||
// Send message_start event
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -76,6 +77,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stop", stops)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert thinkingBudget to reasoning_effort
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream parameter
|
||||
|
||||
@@ -2,6 +2,7 @@ package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -64,7 +65,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
switch itemType {
|
||||
case "message":
|
||||
case "message", "":
|
||||
// Handle regular message conversion
|
||||
role := item.Get("role").String()
|
||||
message := `{"role":"","content":""}`
|
||||
@@ -106,6 +107,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
if len(toolCalls) > 0 {
|
||||
message, _ = sjson.Set(message, "tool_calls", toolCalls)
|
||||
}
|
||||
} else if content.Type == gjson.String {
|
||||
message, _ = sjson.Set(message, "content", content.String())
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", message)
|
||||
@@ -189,23 +192,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
if reasoningEffort := root.Get("reasoning.effort"); reasoningEffort.Exists() {
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "none")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "medium")
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "high")
|
||||
case "xhigh":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "xhigh")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
effort := strings.ToLower(strings.TrimSpace(reasoningEffort.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ func ApplyClaudeThinkingConfig(body []byte, budget *int) []byte {
|
||||
// It uses the unified ResolveThinkingConfigFromMetadata and normalizes the budget.
|
||||
// Returns the normalized budget (nil if thinking should not be enabled) and whether it matched.
|
||||
func ResolveClaudeThinkingConfig(modelName string, metadata map[string]any) (*int, bool) {
|
||||
if !ModelSupportsThinking(modelName) {
|
||||
return nil, false
|
||||
}
|
||||
budget, include, matched := ResolveThinkingConfigFromMetadata(modelName, metadata)
|
||||
if !matched {
|
||||
return nil, false
|
||||
|
||||
@@ -25,9 +25,15 @@ func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool)
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if includeThoughts != nil {
|
||||
// Default to including thoughts when a budget override is present but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && budget != nil && *budget != 0 {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
@@ -47,9 +53,15 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if includeThoughts != nil {
|
||||
// Default to including thoughts when a budget override is present but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && budget != nil && *budget != 0 {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
@@ -140,6 +152,71 @@ func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
return updated
|
||||
}
|
||||
|
||||
// ReasoningEffortBudgetMapping defines the thinkingBudget values for each reasoning effort level.
|
||||
var ReasoningEffortBudgetMapping = map[string]int{
|
||||
"none": 0,
|
||||
"auto": -1,
|
||||
"minimal": 512,
|
||||
"low": 1024,
|
||||
"medium": 8192,
|
||||
"high": 24576,
|
||||
"xhigh": 32768,
|
||||
}
|
||||
|
||||
// ApplyReasoningEffortToGemini applies OpenAI reasoning_effort to Gemini thinkingConfig
|
||||
// for standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGemini(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyReasoningEffortToGeminiCLI applies OpenAI reasoning_effort to Gemini CLI thinkingConfig
|
||||
// for Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget".
|
||||
// "high" -> 32768
|
||||
|
||||
@@ -25,33 +25,33 @@ func ModelSupportsThinking(model string) bool {
|
||||
// or min (0 if zero is allowed and mid <= 0).
|
||||
func NormalizeThinkingBudget(model string, budget int) int {
|
||||
if budget == -1 { // dynamic
|
||||
if found, min, max, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
||||
if found, minBudget, maxBudget, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
||||
if dynamicAllowed {
|
||||
return -1
|
||||
}
|
||||
mid := (min + max) / 2
|
||||
mid := (minBudget + maxBudget) / 2
|
||||
if mid <= 0 && zeroAllowed {
|
||||
return 0
|
||||
}
|
||||
if mid <= 0 {
|
||||
return min
|
||||
return minBudget
|
||||
}
|
||||
return mid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
if found, min, max, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
||||
if found, minBudget, maxBudget, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
||||
if budget == 0 {
|
||||
if zeroAllowed {
|
||||
return 0
|
||||
}
|
||||
return min
|
||||
return minBudget
|
||||
}
|
||||
if budget < min {
|
||||
return min
|
||||
if budget < minBudget {
|
||||
return minBudget
|
||||
}
|
||||
if budget > max {
|
||||
return max
|
||||
if budget > maxBudget {
|
||||
return maxBudget
|
||||
}
|
||||
return budget
|
||||
}
|
||||
@@ -105,3 +105,96 @@ func NormalizeReasoningEffortLevel(model, effort string) (string, bool) {
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsOpenAICompatibilityModel reports whether the model is registered as an OpenAI-compatibility model.
|
||||
// These models may not advertise Thinking metadata in the registry.
|
||||
func IsOpenAICompatibilityModel(model string) bool {
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps a reasoning effort level to a numeric thinking budget (tokens),
|
||||
// clamping the result to the model's supported range.
|
||||
//
|
||||
// Mappings (values are normalized to model's supported range):
|
||||
// - "none" -> 0
|
||||
// - "auto" -> -1
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 24576
|
||||
// - "xhigh" -> 32768
|
||||
//
|
||||
// Returns false when the effort level is empty or unsupported.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// to a reasoning effort level for level-based models.
|
||||
//
|
||||
// Mappings:
|
||||
// - 0 -> "none" (or lowest supported level if model doesn't support "none")
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Returns false when the budget is unsupported (negative values other than -1).
|
||||
func ThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[0], true
|
||||
}
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,6 +163,11 @@ func ResolveThinkingConfigFromMetadata(model string, metadata map[string]any) (*
|
||||
if !matched {
|
||||
return nil, nil, false
|
||||
}
|
||||
// Level-based models (OpenAI-style) do not accept numeric thinking budgets in
|
||||
// Claude/Gemini-style protocols, so we don't derive budgets for them here.
|
||||
if ModelUsesThinkingLevels(model) {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
if budget == nil && effort != nil {
|
||||
if derived, ok := ThinkingEffortToBudget(model, *effort); ok {
|
||||
@@ -196,36 +201,6 @@ func ReasoningEffortFromMetadata(metadata map[string]any) (string, bool) {
|
||||
return "", true
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps reasoning effort levels to approximate budgets,
|
||||
// clamping the result to the model's supported range.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveOriginalModel returns the original model name stored in metadata (if present),
|
||||
// otherwise falls back to the provided model.
|
||||
func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
|
||||
303
internal/watcher/diff/config_diff.go
Normal file
303
internal/watcher/diff/config_diff.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// BuildConfigChangeDetails computes a redacted, human-readable list of config changes.
|
||||
// Secrets are never printed; only structural or non-sensitive fields are surfaced.
|
||||
func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
changes := make([]string, 0, 16)
|
||||
if oldCfg == nil || newCfg == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Simple scalars
|
||||
if oldCfg.Port != newCfg.Port {
|
||||
changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
|
||||
}
|
||||
if oldCfg.AuthDir != newCfg.AuthDir {
|
||||
changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
|
||||
}
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL)))
|
||||
}
|
||||
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||
}
|
||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
|
||||
// API keys (redacted) and counts
|
||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
|
||||
} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
|
||||
changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
}
|
||||
if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
|
||||
} else {
|
||||
for i := range oldCfg.GeminiKey {
|
||||
o := oldCfg.GeminiKey[i]
|
||||
n := newCfg.GeminiKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].prefix: %s -> %s", i, formatProxyURL(o.Prefix), formatProxyURL(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude keys (do not print key material)
|
||||
if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
|
||||
} else {
|
||||
for i := range oldCfg.ClaudeKey {
|
||||
o := oldCfg.ClaudeKey[i]
|
||||
n := newCfg.ClaudeKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].prefix: %s -> %s", i, formatProxyURL(o.Prefix), formatProxyURL(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codex keys (do not print key material)
|
||||
if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
|
||||
} else {
|
||||
for i := range oldCfg.CodexKey {
|
||||
o := oldCfg.CodexKey[i]
|
||||
n := newCfg.CodexKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, formatProxyURL(o.Prefix), formatProxyURL(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||
}
|
||||
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
}
|
||||
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
|
||||
}
|
||||
oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
|
||||
newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
|
||||
if oldPanelRepo != newPanelRepo {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo))
|
||||
}
|
||||
if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
|
||||
switch {
|
||||
case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
|
||||
changes = append(changes, "remote-management.secret-key: created")
|
||||
case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
|
||||
changes = append(changes, "remote-management.secret-key: deleted")
|
||||
default:
|
||||
changes = append(changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI compatibility providers (summarized)
|
||||
if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
|
||||
changes = append(changes, "openai-compatibility:")
|
||||
for _, c := range compat {
|
||||
changes = append(changes, " "+c)
|
||||
}
|
||||
}
|
||||
|
||||
// Vertex-compatible API keys
|
||||
if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey)))
|
||||
} else {
|
||||
for i := range oldCfg.VertexCompatAPIKey {
|
||||
o := oldCfg.VertexCompatAPIKey[i]
|
||||
n := newCfg.VertexCompatAPIKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].prefix: %s -> %s", i, formatProxyURL(o.Prefix), formatProxyURL(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i))
|
||||
}
|
||||
oldModels := SummarizeVertexModels(o.Models)
|
||||
newModels := SummarizeVertexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
func trimStrings(in []string) []string {
|
||||
out := make([]string, len(in))
|
||||
for i := range in {
|
||||
out[i] = strings.TrimSpace(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalStringMap(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if b[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatProxyURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "<none>"
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Host)
|
||||
scheme := strings.TrimSpace(parsed.Scheme)
|
||||
if host == "" {
|
||||
// Allow host:port style without scheme.
|
||||
parsed2, err2 := url.Parse("http://" + trimmed)
|
||||
if err2 == nil {
|
||||
host = strings.TrimSpace(parsed2.Host)
|
||||
}
|
||||
scheme = ""
|
||||
}
|
||||
if host == "" {
|
||||
return "<redacted>"
|
||||
}
|
||||
if scheme == "" {
|
||||
return host
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
526
internal/watcher/diff/config_diff_test.go
Normal file
526
internal/watcher/diff/config_diff_test.go
Normal file
@@ -0,0 +1,526 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBuildConfigChangeDetails(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 8080,
|
||||
AuthDir: "/tmp/auth-old",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://old-upstream",
|
||||
ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}},
|
||||
RestrictManagementToLocalhost: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
SecretKey: "old",
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "repo-old",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
newCfg := &config.Config{
|
||||
Port: 9090,
|
||||
AuthDir: "/tmp/auth-new",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://new-upstream",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "from-old", To: "to-old"},
|
||||
{From: "from-new", To: "to-new"},
|
||||
},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
SecretKey: "new",
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "repo-new",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1", "m2"},
|
||||
"providerB": {"x"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "compat-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
|
||||
expectContains(t, details, "port: 8080 -> 9090")
|
||||
expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, details, "remote-management.secret-key: updated")
|
||||
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
|
||||
expectContains(t, details, "openai-compatibility:")
|
||||
expectContains(t, details, " provider added: compat-b (api-keys=1, models=0)")
|
||||
expectContains(t, details, " provider updated: compat-a (models 1 -> 2)")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NoChanges(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
}
|
||||
if details := BuildConfigChangeDetails(cfg, cfg); len(details) != 0 {
|
||||
t.Fatalf("expected no change entries, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "gemini[0].headers: updated")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, details, "ampcode.force-model-mappings: false -> true")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "old-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "old-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "old-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "old-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "new-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "new-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "new-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "new-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini[0].prefix: old-g -> new-g")
|
||||
expectContains(t, changes, "claude[0].prefix: old-c -> new-c")
|
||||
expectContains(t, changes, "codex[0].prefix: old-x -> new-x")
|
||||
expectContains(t, changes, "vertex[0].prefix: old-v -> new-v")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NilSafe(t *testing.T) {
|
||||
if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when old nil, got %v", details)
|
||||
}
|
||||
if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when new nil, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a", "b", "c"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new-key",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new-secret",
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "api-keys count: 1 -> 3")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: added")
|
||||
expectContains(t, details, "remote-management.secret-key: created")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1000,
|
||||
AuthDir: "/old",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}},
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2000,
|
||||
AuthDir: "/new",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
{APIKey: "c2"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}},
|
||||
{APIKey: "x2"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "debug: false -> true")
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
expectContains(t, details, "claude-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "codex-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, details, "remote-management.secret-key: deleted")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1,
|
||||
AuthDir: "/a",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-old",
|
||||
UpstreamAPIKey: "old-key",
|
||||
RestrictManagementToLocalhost: false,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "old/repo",
|
||||
SecretKey: "old",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{" keyA "},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"a"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2,
|
||||
AuthDir: "/b",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-new",
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "prov-new",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "port: 1 -> 2")
|
||||
expectContains(t, changes, "auth-dir: /a -> /b")
|
||||
expectContains(t, changes, "debug: false -> true")
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
|
||||
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
|
||||
expectContains(t, changes, "gemini[0].api-key: updated")
|
||||
expectContains(t, changes, "gemini[0].headers: updated")
|
||||
expectContains(t, changes, "gemini[0].excluded-models: updated (0 -> 2 entries)")
|
||||
expectContains(t, changes, "claude[0].base-url: http://c-old -> http://c-new")
|
||||
expectContains(t, changes, "claude[0].proxy-url: http://cp-old -> http://cp-new")
|
||||
expectContains(t, changes, "claude[0].api-key: updated")
|
||||
expectContains(t, changes, "claude[0].headers: updated")
|
||||
expectContains(t, changes, "claude[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "codex[0].base-url: http://x-old -> http://x-new")
|
||||
expectContains(t, changes, "codex[0].proxy-url: http://xp-old -> http://xp-new")
|
||||
expectContains(t, changes, "codex[0].api-key: updated")
|
||||
expectContains(t, changes, "codex[0].headers: updated")
|
||||
expectContains(t, changes, "codex[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].base-url: http://v-old -> http://v-new")
|
||||
expectContains(t, changes, "vertex[0].proxy-url: http://vp-old -> http://vp-new")
|
||||
expectContains(t, changes, "vertex[0].api-key: updated")
|
||||
expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].headers: updated")
|
||||
expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new")
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, changes, "ampcode.force-model-mappings: false -> true")
|
||||
expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
|
||||
expectContains(t, changes, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, changes, "remote-management.secret-key: deleted")
|
||||
expectContains(t, changes, "openai-compatibility:")
|
||||
}
|
||||
|
||||
func TestFormatProxyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", in: "", want: "<none>"},
|
||||
{name: "invalid", in: "http://[::1", want: "<redacted>"},
|
||||
{name: "fullURLRedactsUserinfoAndPath", in: "http://user:pass@example.com:8080/path?x=1#frag", want: "http://example.com:8080"},
|
||||
{name: "socks5RedactsUserinfoAndPath", in: "socks5://user:pass@192.168.1.1:1080/path?x=1", want: "socks5://192.168.1.1:1080"},
|
||||
{name: "socks5HostPort", in: "socks5://proxy.example.com:1080/", want: "socks5://proxy.example.com:1080"},
|
||||
{name: "hostPortNoScheme", in: "example.com:1234/path?x=1", want: "example.com:1234"},
|
||||
{name: "relativePathRedacted", in: "/just/path", want: "<redacted>"},
|
||||
{name: "schemeAndHost", in: "https://example.com", want: "https://example.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := formatProxyURL(tt.in); got != tt.want {
|
||||
t.Fatalf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "old",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "old",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new",
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: updated")
|
||||
expectContains(t, changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_CountBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g"}},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v", BaseURL: "http://v"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "claude-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "codex-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "vertex-api-key count: 0 -> 1")
|
||||
}
|
||||
|
||||
func TestTrimStrings(t *testing.T) {
|
||||
out := trimStrings([]string{" a ", "b", " c"})
|
||||
if len(out) != 3 || out[0] != "a" || out[1] != "b" || out[2] != "c" {
|
||||
t.Fatalf("unexpected trimmed strings: %v", out)
|
||||
}
|
||||
}
|
||||
102
internal/watcher/diff/model_hash.go
Normal file
102
internal/watcher/diff/model_hash.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models.
|
||||
// Used to detect model list changes during hot reload.
|
||||
func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models.
|
||||
func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, _ := json.Marshal(normalized)
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeModelPairs(collect func(out func(key string))) []string {
|
||||
seen := make(map[string]struct{})
|
||||
keys := make([]string, 0)
|
||||
collect(func(key string) {
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
})
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func hashJoined(keys []string) string {
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(keys, "\n")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
159
internal/watcher/diff/model_hash_test.go
Normal file
159
internal/watcher/diff/model_hash_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: "gpt-3.5-turbo"},
|
||||
}
|
||||
hash1 := ComputeOpenAICompatModelsHash(models)
|
||||
hash2 := ComputeOpenAICompatModelsHash(models)
|
||||
if hash1 == "" {
|
||||
t.Fatal("hash should not be empty")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be deterministic, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
changed := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}})
|
||||
if hash1 == changed {
|
||||
t.Fatal("hash should change when model list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
|
||||
a := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: " "},
|
||||
{Name: "GPT-4", Alias: "GPT4"},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
b := []config.OpenAICompatibilityModel{
|
||||
{Alias: "A1"},
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
}
|
||||
h1 := ComputeOpenAICompatModelsHash(a)
|
||||
h2 := ComputeOpenAICompatModelsHash(b)
|
||||
if h1 == "" || h2 == "" {
|
||||
t.Fatal("expected non-empty hashes for non-empty model sets")
|
||||
}
|
||||
if h1 != h2 {
|
||||
t.Fatalf("expected normalized hashes to match, got %s / %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) {
|
||||
models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}}
|
||||
hash1 := ComputeVertexCompatModelsHash(models)
|
||||
hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hashes should not be empty for non-empty models")
|
||||
}
|
||||
if hash1 == hash2 {
|
||||
t.Fatal("hash should differ when model content differs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) {
|
||||
a := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeClaudeModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil models, got %q", got)
|
||||
}
|
||||
if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeClaudeModelsHash(a), ComputeClaudeModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
|
||||
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
|
||||
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hash should not be empty for non-empty input")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
hash3 := ComputeExcludedModelsHash([]string{"c"})
|
||||
if hash1 == hash3 {
|
||||
t.Fatal("hash should differ for different normalized sets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeOpenAICompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeVertexCompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeExcludedModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{" ", ""}); got != "" {
|
||||
t.Fatalf("expected empty hash for whitespace-only entries, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}}
|
||||
h1 := ComputeClaudeModelsHash(models)
|
||||
h2 := ComputeClaudeModelsHash(models)
|
||||
if h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
|
||||
}
|
||||
if h3 := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}}); h3 == h1 {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
151
internal/watcher/diff/oauth_excluded.go
Normal file
151
internal/watcher/diff/oauth_excluded.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type ExcludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeExcludedModels normalizes and hashes an excluded-model list.
|
||||
func SummarizeExcludedModels(list []string) ExcludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return ExcludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return ExcludedModelsSummary{
|
||||
hash: ComputeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider.
|
||||
func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]ExcludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]ExcludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = SummarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthExcludedModelChanges compares OAuth excluded models maps.
|
||||
func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := SummarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
type AmpModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeAmpModelMappings hashes Amp model mappings for change detection.
|
||||
func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return AmpModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) {
|
||||
summary := SummarizeExcludedModels([]string{"A", " a ", "B", "b"})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 unique entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeExcludedModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffOAuthExcludedModelChanges(t *testing.T) {
|
||||
oldMap := map[string][]string{
|
||||
"ProviderA": {"model-1", "model-2"},
|
||||
"providerB": {"x"},
|
||||
}
|
||||
newMap := map[string][]string{
|
||||
"providerA": {"model-1", "model-3"},
|
||||
"providerC": {"y"},
|
||||
}
|
||||
|
||||
changes, affected := DiffOAuthExcludedModelChanges(oldMap, newMap)
|
||||
expectContains(t, changes, "oauth-excluded-models[providera]: updated (2 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerb]: removed")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerc]: added (1 entries)")
|
||||
|
||||
if len(affected) != 3 {
|
||||
t.Fatalf("expected 3 affected providers, got %d", len(affected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeAmpModelMappings(t *testing.T) {
|
||||
summary := SummarizeAmpModelMappings([]config.AmpModelMapping{
|
||||
{From: "a", To: "A"},
|
||||
{From: "b", To: "B"},
|
||||
{From: " ", To: " "}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank mappings ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) {
|
||||
out := SummarizeOAuthExcludedModels(map[string][]string{
|
||||
"ProvA": {"X"},
|
||||
"": {"ignored"},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected only non-empty key summary, got %d", len(out))
|
||||
}
|
||||
if _, ok := out["prova"]; !ok {
|
||||
t.Fatalf("expected normalized key 'prova', got keys %v", out)
|
||||
}
|
||||
if out["prova"].count != 1 || out["prova"].hash == "" {
|
||||
t.Fatalf("unexpected summary %+v", out["prova"])
|
||||
}
|
||||
if outEmpty := SummarizeOAuthExcludedModels(nil); outEmpty != nil {
|
||||
t.Fatalf("expected nil map for nil input, got %v", outEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeVertexModels(t *testing.T) {
|
||||
summary := SummarizeVertexModels([]config.VertexCompatModel{
|
||||
{Name: "m1"},
|
||||
{Name: " ", Alias: "alias"},
|
||||
{}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 vertex models, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeVertexModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank model ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func expectContains(t *testing.T, list []string, target string) {
|
||||
t.Helper()
|
||||
for _, entry := range list {
|
||||
if entry == target {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected list to contain %q, got %#v", target, list)
|
||||
}
|
||||
183
internal/watcher/diff/openai_compat.go
Normal file
183
internal/watcher/diff/openai_compat.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// DiffOpenAICompatibility produces human-readable change descriptions.
|
||||
func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
changes := make([]string, 0)
|
||||
oldMap := make(map[string]config.OpenAICompatibility, len(oldList))
|
||||
oldLabels := make(map[string]string, len(oldList))
|
||||
for idx, entry := range oldList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
oldMap[key] = entry
|
||||
oldLabels[key] = label
|
||||
}
|
||||
newMap := make(map[string]config.OpenAICompatibility, len(newList))
|
||||
newLabels := make(map[string]string, len(newList))
|
||||
for idx, entry := range newList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
newMap[key] = entry
|
||||
newLabels[key] = label
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(oldMap)+len(newMap))
|
||||
for key := range oldMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
for key := range newMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
orderedKeys := make([]string, 0, len(keySet))
|
||||
for key := range keySet {
|
||||
orderedKeys = append(orderedKeys, key)
|
||||
}
|
||||
sort.Strings(orderedKeys)
|
||||
for _, key := range orderedKeys {
|
||||
oldEntry, oldOk := oldMap[key]
|
||||
newEntry, newOk := newMap[key]
|
||||
label := oldLabels[key]
|
||||
if label == "" {
|
||||
label = newLabels[key]
|
||||
}
|
||||
switch {
|
||||
case !oldOk:
|
||||
changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models)))
|
||||
case !newOk:
|
||||
changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models)))
|
||||
default:
|
||||
if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" {
|
||||
changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return changes
|
||||
}
|
||||
|
||||
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
|
||||
oldKeyCount := countAPIKeys(oldEntry)
|
||||
newKeyCount := countAPIKeys(newEntry)
|
||||
oldModelCount := countOpenAIModels(oldEntry.Models)
|
||||
newModelCount := countOpenAIModels(newEntry.Models)
|
||||
details := make([]string, 0, 3)
|
||||
if oldKeyCount != newKeyCount {
|
||||
details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
|
||||
}
|
||||
if oldModelCount != newModelCount {
|
||||
details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
|
||||
}
|
||||
if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
|
||||
details = append(details, "headers updated")
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "(" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
func countAPIKeys(entry config.OpenAICompatibility) int {
|
||||
count := 0
|
||||
for _, keyEntry := range entry.APIKeyEntries {
|
||||
if strings.TrimSpace(keyEntry.APIKey) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
|
||||
count := 0
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name != "" {
|
||||
return "name:" + name, name
|
||||
}
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
if base != "" {
|
||||
return "base:" + base, base
|
||||
}
|
||||
for _, model := range entry.Models {
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if alias != "" {
|
||||
return "alias:" + alias, alias
|
||||
}
|
||||
}
|
||||
sig := openAICompatSignature(entry)
|
||||
if sig == "" {
|
||||
return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
|
||||
}
|
||||
short := sig
|
||||
if len(short) > 8 {
|
||||
short = short[:8]
|
||||
}
|
||||
return "sig:" + sig, "compat-" + short
|
||||
}
|
||||
|
||||
func openAICompatSignature(entry config.OpenAICompatibility) string {
|
||||
var parts []string
|
||||
|
||||
if v := strings.TrimSpace(entry.Name); v != "" {
|
||||
parts = append(parts, "name="+strings.ToLower(v))
|
||||
}
|
||||
if v := strings.TrimSpace(entry.BaseURL); v != "" {
|
||||
parts = append(parts, "base="+v)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(entry.Models))
|
||||
for _, model := range entry.Models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
sort.Strings(models)
|
||||
parts = append(parts, "models="+strings.Join(models, ","))
|
||||
}
|
||||
|
||||
if len(entry.Headers) > 0 {
|
||||
keys := make([]string, 0, len(entry.Headers))
|
||||
for k := range entry.Headers {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
keys = append(keys, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
sort.Strings(keys)
|
||||
parts = append(parts, "headers="+strings.Join(keys, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally exclude API key material; only count non-empty entries.
|
||||
if count := countAPIKeys(entry); count > 0 {
|
||||
parts = append(parts, fmt.Sprintf("api_keys=%d", count))
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(parts, "|")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
187
internal/watcher/diff/openai_compat_test.go
Normal file
187
internal/watcher/diff/openai_compat_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDiffOpenAICompatibility(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
{APIKey: "key-b"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: "m2"},
|
||||
},
|
||||
Headers: map[string]string{"X-Test": "1"},
|
||||
},
|
||||
{
|
||||
Name: "provider-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)")
|
||||
expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)")
|
||||
}
|
||||
|
||||
func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
if changes := DiffOpenAICompatibility(oldList, newList); len(changes) != 0 {
|
||||
t.Fatalf("expected no changes, got %v", changes)
|
||||
}
|
||||
|
||||
newList = nil
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)")
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyFallbacks(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{{Alias: "alias-only"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "base:http://base" || label != "http://base" {
|
||||
t.Fatalf("expected base key, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.BaseURL = ""
|
||||
key, label = openAICompatKey(entry, 1)
|
||||
if key != "alias:alias-only" || label != "alias-only" {
|
||||
t.Fatalf("expected alias fallback, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.Models = nil
|
||||
key, label = openAICompatKey(entry, 2)
|
||||
if key != "index:2" || label != "entry-3" {
|
||||
t.Fatalf("expected index fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_UsesName(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{Name: "My-Provider"}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "name:My-Provider" || label != "My-Provider" {
|
||||
t.Fatalf("expected name key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if !strings.HasPrefix(key, "sig:") || !strings.HasPrefix(label, "compat-") {
|
||||
t.Fatalf("expected signature key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) {
|
||||
if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" {
|
||||
t.Fatalf("expected empty signature, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) {
|
||||
a := config.OpenAICompatibility{
|
||||
Name: " Provider ",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: " "},
|
||||
{Alias: "A1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"X-Test": "1",
|
||||
" ": "ignored",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: " "},
|
||||
},
|
||||
}
|
||||
b := config.OpenAICompatibility{
|
||||
Name: "provider",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Alias: "a1"},
|
||||
{Name: "m1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"x-test": "2",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
}
|
||||
|
||||
sigA := openAICompatSignature(a)
|
||||
sigB := openAICompatSignature(b)
|
||||
if sigA == "" || sigB == "" {
|
||||
t.Fatalf("expected non-empty signatures, got %q / %q", sigA, sigB)
|
||||
}
|
||||
if sigA != sigB {
|
||||
t.Fatalf("expected normalized signatures to match, got %s / %s", sigA, sigB)
|
||||
}
|
||||
|
||||
c := b
|
||||
c.Models = append(c.Models, config.OpenAICompatibilityModel{Name: "m2"})
|
||||
if sigC := openAICompatSignature(c); sigC == sigB {
|
||||
t.Fatalf("expected signature to change when models change, got %s", sigC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: ""},
|
||||
{Alias: ""},
|
||||
{Name: " "},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
if got := countOpenAIModels(models); got != 2 {
|
||||
t.Fatalf("expected 2 counted models, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "model-name"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 5)
|
||||
if key != "alias:model-name" || label != "model-name" {
|
||||
t.Fatalf("expected model-name fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
@@ -21,9 +21,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -187,7 +188,7 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
go w.processEvents(ctx)
|
||||
|
||||
// Perform an initial full reload based on current config and auth dir
|
||||
w.reloadClients(true, nil)
|
||||
w.reloadClients(true, nil, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -202,7 +203,7 @@ func (w *Watcher) watchKiroIDETokenFile() {
|
||||
|
||||
// Kiro IDE stores tokens in ~/.aws/sso/cache/
|
||||
kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
|
||||
// Check if directory exists
|
||||
if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
|
||||
log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
|
||||
@@ -305,7 +306,7 @@ func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState() {
|
||||
func (w *Watcher) refreshAuthState(force bool) {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
@@ -315,12 +316,12 @@ func (w *Watcher) refreshAuthState() {
|
||||
}
|
||||
}
|
||||
}
|
||||
updates := w.prepareAuthUpdatesLocked(auths)
|
||||
updates := w.prepareAuthUpdatesLocked(auths, force)
|
||||
w.clientsMutex.Unlock()
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth) []AuthUpdate {
|
||||
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
|
||||
newState := make(map[string]*coreauth.Auth, len(auths))
|
||||
for _, auth := range auths {
|
||||
if auth == nil || auth.ID == "" {
|
||||
@@ -347,7 +348,7 @@ func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth) []AuthUpdate
|
||||
for id, auth := range newState {
|
||||
if existing, ok := w.currentAuths[id]; !ok {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
||||
} else if !authEqual(existing, auth) {
|
||||
} else if force || !authEqual(existing, auth) {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
|
||||
}
|
||||
}
|
||||
@@ -514,170 +515,6 @@ func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
|
||||
return clone
|
||||
}
|
||||
|
||||
// computeOpenAICompatModelsHash returns a stable hash for the compatibility models so that
|
||||
// changes to the model list trigger auth updates during hot reload.
|
||||
func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(models)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func computeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(models)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// computeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func computeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(models)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func computeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, err := json.Marshal(normalized)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type excludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
func summarizeExcludedModels(list []string) excludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return excludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return excludedModelsSummary{
|
||||
hash: computeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
type ampModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
func summarizeAmpModelMappings(mappings []config.AmpModelMapping) ampModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return ampModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return ampModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return ampModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]excludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = summarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func diffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := summarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := summarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
@@ -706,7 +543,7 @@ func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey
|
||||
combined = append(combined, k)
|
||||
}
|
||||
sort.Strings(combined)
|
||||
hash := computeExcludedModelsHash(combined)
|
||||
hash := diff.ComputeExcludedModelsHash(combined)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
@@ -820,16 +657,16 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
|
||||
// Check for Kiro IDE token file changes
|
||||
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||
|
||||
|
||||
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Handle Kiro IDE token file changes
|
||||
if isKiroIDEToken {
|
||||
w.handleKiroIDETokenChange(event)
|
||||
@@ -860,7 +697,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
@@ -868,7 +705,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
@@ -877,7 +714,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
@@ -928,7 +765,7 @@ func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
||||
log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
|
||||
|
||||
// Trigger auth state refresh to pick up the new token
|
||||
w.refreshAuthState()
|
||||
w.refreshAuthState(true)
|
||||
|
||||
// Notify callback if set
|
||||
w.clientsMutex.RLock()
|
||||
@@ -962,7 +799,7 @@ func (w *Watcher) reloadConfigIfChanged() {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
|
||||
log.Infof("config file changed, reloading: %s", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
@@ -1008,7 +845,7 @@ func (w *Watcher) reloadConfig() bool {
|
||||
|
||||
var affectedOAuthProviders []string
|
||||
if oldConfig != nil {
|
||||
_, affectedOAuthProviders = diffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
|
||||
_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
|
||||
}
|
||||
|
||||
// Always apply the current log level based on the latest config.
|
||||
@@ -1021,7 +858,7 @@ func (w *Watcher) reloadConfig() bool {
|
||||
|
||||
// Log configuration changes in debug mode, only when there are material diffs
|
||||
if oldConfig != nil {
|
||||
details := buildConfigChangeDetails(oldConfig, newConfig)
|
||||
details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
|
||||
if len(details) > 0 {
|
||||
log.Debugf("config changes detected:")
|
||||
for _, d := range details {
|
||||
@@ -1033,15 +870,16 @@ func (w *Watcher) reloadConfig() bool {
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
// Reload clients with new config
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders)
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
return true
|
||||
}
|
||||
|
||||
// reloadClients performs a full scan and reload of all clients.
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
|
||||
log.Debugf("starting full client load process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
@@ -1132,7 +970,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
|
||||
w.refreshAuthState()
|
||||
w.refreshAuthState(forceAuthRefresh)
|
||||
|
||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||
totalNewClients,
|
||||
@@ -1183,7 +1021,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
w.refreshAuthState()
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
@@ -1202,7 +1040,7 @@ func (w *Watcher) removeClient(path string) {
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
w.refreshAuthState()
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
@@ -1231,6 +1069,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(entry.Prefix)
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
id, token := idGen.next("gemini:apikey", key, base)
|
||||
@@ -1246,6 +1085,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: "gemini",
|
||||
Label: "gemini-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
@@ -1263,6 +1103,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
base := strings.TrimSpace(ck.BaseURL)
|
||||
id, token := idGen.next("claude:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
@@ -1272,7 +1113,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := computeClaudeModelsHash(ck.Models); hash != "" {
|
||||
if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
@@ -1281,6 +1122,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: "claude",
|
||||
Label: "claude-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
@@ -1297,6 +1139,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
id, token := idGen.next("codex:apikey", key, ck.BaseURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
@@ -1311,6 +1154,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: "codex",
|
||||
Label: "codex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
@@ -1404,6 +1248,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
}
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||
if providerName == "" {
|
||||
providerName = "openai-compatibility"
|
||||
@@ -1427,7 +1272,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := computeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
@@ -1435,6 +1280,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
@@ -1453,7 +1299,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if hash := computeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
@@ -1461,6 +1307,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
@@ -1478,8 +1325,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
key := strings.TrimSpace(compat.APIKey)
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||
idKind := fmt.Sprintf("vertex:apikey:%s", base)
|
||||
idKind := "vertex:apikey"
|
||||
id, token := idGen.next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
|
||||
@@ -1489,7 +1337,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := computeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
if hash := diff.ComputeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
@@ -1497,6 +1345,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: "vertex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
@@ -1532,7 +1381,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
|
||||
|
||||
// Detect Kiro auth files by auth_method field (they don't have "type" field)
|
||||
if t == "" {
|
||||
if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" {
|
||||
@@ -1540,7 +1389,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if t == "" {
|
||||
log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name)
|
||||
continue
|
||||
@@ -1571,10 +1420,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
@@ -1593,7 +1452,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Apply global preferred endpoint setting if not present in metadata
|
||||
if cfg.KiroPreferredEndpoint != "" {
|
||||
// Check if already set in metadata (which takes precedence in executor)
|
||||
@@ -1682,6 +1541,7 @@ func synthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
Attributes: attrs,
|
||||
Metadata: metadataCopy,
|
||||
ProxyURL: primary.ProxyURL,
|
||||
Prefix: primary.Prefix,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Runtime: geminicli.NewVirtualCredential(projectID, shared),
|
||||
@@ -1795,324 +1655,6 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||
}
|
||||
|
||||
func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
changes := make([]string, 0)
|
||||
oldMap := make(map[string]config.OpenAICompatibility, len(oldList))
|
||||
oldLabels := make(map[string]string, len(oldList))
|
||||
for idx, entry := range oldList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
oldMap[key] = entry
|
||||
oldLabels[key] = label
|
||||
}
|
||||
newMap := make(map[string]config.OpenAICompatibility, len(newList))
|
||||
newLabels := make(map[string]string, len(newList))
|
||||
for idx, entry := range newList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
newMap[key] = entry
|
||||
newLabels[key] = label
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(oldMap)+len(newMap))
|
||||
for key := range oldMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
for key := range newMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
orderedKeys := make([]string, 0, len(keySet))
|
||||
for key := range keySet {
|
||||
orderedKeys = append(orderedKeys, key)
|
||||
}
|
||||
sort.Strings(orderedKeys)
|
||||
for _, key := range orderedKeys {
|
||||
oldEntry, oldOk := oldMap[key]
|
||||
newEntry, newOk := newMap[key]
|
||||
label := oldLabels[key]
|
||||
if label == "" {
|
||||
label = newLabels[key]
|
||||
}
|
||||
switch {
|
||||
case !oldOk:
|
||||
changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models)))
|
||||
case !newOk:
|
||||
changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models)))
|
||||
default:
|
||||
if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" {
|
||||
changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return changes
|
||||
}
|
||||
|
||||
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
|
||||
oldKeyCount := countAPIKeys(oldEntry)
|
||||
newKeyCount := countAPIKeys(newEntry)
|
||||
oldModelCount := countOpenAIModels(oldEntry.Models)
|
||||
newModelCount := countOpenAIModels(newEntry.Models)
|
||||
details := make([]string, 0, 3)
|
||||
if oldKeyCount != newKeyCount {
|
||||
details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
|
||||
}
|
||||
if oldModelCount != newModelCount {
|
||||
details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
|
||||
}
|
||||
if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
|
||||
details = append(details, "headers updated")
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "(" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
func countAPIKeys(entry config.OpenAICompatibility) int {
|
||||
count := 0
|
||||
for _, keyEntry := range entry.APIKeyEntries {
|
||||
if strings.TrimSpace(keyEntry.APIKey) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
|
||||
count := 0
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name != "" {
|
||||
return "name:" + name, name
|
||||
}
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
if base != "" {
|
||||
return "base:" + base, base
|
||||
}
|
||||
for _, model := range entry.Models {
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if alias != "" {
|
||||
return "alias:" + alias, alias
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
|
||||
}
|
||||
|
||||
// buildConfigChangeDetails computes a redacted, human-readable list of config changes.
|
||||
// It avoids printing secrets (like API keys) and focuses on structural or non-sensitive fields.
|
||||
func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
changes := make([]string, 0, 16)
|
||||
if oldCfg == nil || newCfg == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Simple scalars
|
||||
if oldCfg.Port != newCfg.Port {
|
||||
changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
|
||||
}
|
||||
if oldCfg.AuthDir != newCfg.AuthDir {
|
||||
changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
|
||||
}
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
||||
}
|
||||
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
|
||||
// API keys (redacted) and counts
|
||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
|
||||
} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
|
||||
changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
}
|
||||
if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
|
||||
} else {
|
||||
for i := range oldCfg.GeminiKey {
|
||||
if i >= len(newCfg.GeminiKey) {
|
||||
break
|
||||
}
|
||||
o := oldCfg.GeminiKey[i]
|
||||
n := newCfg.GeminiKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude keys (do not print key material)
|
||||
if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
|
||||
} else {
|
||||
for i := range oldCfg.ClaudeKey {
|
||||
if i >= len(newCfg.ClaudeKey) {
|
||||
break
|
||||
}
|
||||
o := oldCfg.ClaudeKey[i]
|
||||
n := newCfg.ClaudeKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codex keys (do not print key material)
|
||||
if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
|
||||
} else {
|
||||
for i := range oldCfg.CodexKey {
|
||||
if i >= len(newCfg.CodexKey) {
|
||||
break
|
||||
}
|
||||
o := oldCfg.CodexKey[i]
|
||||
n := newCfg.CodexKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := summarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := summarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
|
||||
if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
}
|
||||
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
|
||||
}
|
||||
if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
|
||||
switch {
|
||||
case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
|
||||
changes = append(changes, "remote-management.secret-key: created")
|
||||
case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
|
||||
changes = append(changes, "remote-management.secret-key: deleted")
|
||||
default:
|
||||
changes = append(changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI compatibility providers (summarized)
|
||||
if compat := diffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
|
||||
changes = append(changes, "openai-compatibility:")
|
||||
for _, c := range compat {
|
||||
changes = append(changes, " "+c)
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) {
|
||||
if len(headers) == 0 || attrs == nil {
|
||||
return
|
||||
@@ -2126,23 +1668,3 @@ func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string)
|
||||
attrs["header:"+key] = val
|
||||
}
|
||||
}
|
||||
|
||||
func trimStrings(in []string) []string {
|
||||
out := make([]string, len(in))
|
||||
for i := range in {
|
||||
out[i] = strings.TrimSpace(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalStringMap(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if b[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
609
internal/watcher/watcher_test.go
Normal file
609
internal/watcher/watcher_test.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_APIKey(t *testing.T) {
|
||||
auth := &coreauth.Auth{Attributes: map[string]string{}}
|
||||
cfg := &config.Config{}
|
||||
perKey := []string{" Model-1 ", "model-2"}
|
||||
|
||||
applyAuthExcludedModelsMeta(auth, cfg, perKey, "apikey")
|
||||
|
||||
expected := diff.ComputeExcludedModelsHash([]string{"model-1", "model-2"})
|
||||
if got := auth.Attributes["excluded_models_hash"]; got != expected {
|
||||
t.Fatalf("expected hash %s, got %s", expected, got)
|
||||
}
|
||||
if got := auth.Attributes["auth_kind"]; got != "apikey" {
|
||||
t.Fatalf("expected auth_kind=apikey, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_OAuthProvider(t *testing.T) {
|
||||
auth := &coreauth.Auth{
|
||||
Provider: "TestProv",
|
||||
Attributes: map[string]string{},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"testprov": {"A", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
applyAuthExcludedModelsMeta(auth, cfg, nil, "oauth")
|
||||
|
||||
expected := diff.ComputeExcludedModelsHash([]string{"a", "b"})
|
||||
if got := auth.Attributes["excluded_models_hash"]; got != expected {
|
||||
t.Fatalf("expected hash %s, got %s", expected, got)
|
||||
}
|
||||
if got := auth.Attributes["auth_kind"]; got != "oauth" {
|
||||
t.Fatalf("expected auth_kind=oauth, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAPIKeyClientsCounts(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g1"}, {APIKey: "g2"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}, {APIKey: "x2"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "o1"}, {APIKey: "o2"}}},
|
||||
},
|
||||
}
|
||||
|
||||
gemini, vertex, claude, codex, compat := BuildAPIKeyClients(cfg)
|
||||
if gemini != 2 || vertex != 1 || claude != 1 || codex != 2 || compat != 2 {
|
||||
t.Fatalf("unexpected counts: %d %d %d %d %d", gemini, vertex, claude, codex, compat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAuthStripsTemporalFields(t *testing.T) {
|
||||
now := time.Now()
|
||||
auth := &coreauth.Auth{
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
LastRefreshedAt: now,
|
||||
NextRefreshAfter: now,
|
||||
Quota: coreauth.QuotaState{
|
||||
NextRecoverAt: now,
|
||||
},
|
||||
Runtime: map[string]any{"k": "v"},
|
||||
}
|
||||
|
||||
normalized := normalizeAuth(auth)
|
||||
if !normalized.CreatedAt.IsZero() || !normalized.UpdatedAt.IsZero() || !normalized.LastRefreshedAt.IsZero() || !normalized.NextRefreshAfter.IsZero() {
|
||||
t.Fatal("expected time fields to be zeroed")
|
||||
}
|
||||
if normalized.Runtime != nil {
|
||||
t.Fatal("expected runtime to be nil")
|
||||
}
|
||||
if !normalized.Quota.NextRecoverAt.IsZero() {
|
||||
t.Fatal("expected quota.NextRecoverAt to be zeroed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchProvider(t *testing.T) {
|
||||
if _, ok := matchProvider("OpenAI", []string{"openai", "claude"}); !ok {
|
||||
t.Fatal("expected match to succeed ignoring case")
|
||||
}
|
||||
if _, ok := matchProvider("missing", []string{"openai"}); ok {
|
||||
t.Fatal("expected match to fail for unknown provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) {
|
||||
authDir := t.TempDir()
|
||||
metadata := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "user@example.com",
|
||||
"project_id": "proj-a, proj-b",
|
||||
"proxy_url": "https://proxy",
|
||||
}
|
||||
authFile := filepath.Join(authDir, "gemini.json")
|
||||
data, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal metadata: %v", err)
|
||||
}
|
||||
if err = os.WriteFile(authFile, data, 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "g-key",
|
||||
BaseURL: "https://gemini",
|
||||
ExcludedModels: []string{"Model-A", "model-b"},
|
||||
Headers: map[string]string{"X-Req": "1"},
|
||||
},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"gemini-cli": {"Foo", "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
w := &Watcher{authDir: authDir}
|
||||
w.SetConfig(cfg)
|
||||
|
||||
auths := w.SnapshotCoreAuths()
|
||||
if len(auths) != 4 {
|
||||
t.Fatalf("expected 4 auth entries (1 config + 1 primary + 2 virtual), got %d", len(auths))
|
||||
}
|
||||
|
||||
var geminiAPIKeyAuth *coreauth.Auth
|
||||
var geminiPrimary *coreauth.Auth
|
||||
virtuals := make([]*coreauth.Auth, 0)
|
||||
for _, a := range auths {
|
||||
switch {
|
||||
case a.Provider == "gemini" && a.Attributes["api_key"] == "g-key":
|
||||
geminiAPIKeyAuth = a
|
||||
case a.Attributes["gemini_virtual_primary"] == "true":
|
||||
geminiPrimary = a
|
||||
case strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) != "":
|
||||
virtuals = append(virtuals, a)
|
||||
}
|
||||
}
|
||||
if geminiAPIKeyAuth == nil {
|
||||
t.Fatal("expected synthesized Gemini API key auth")
|
||||
}
|
||||
expectedAPIKeyHash := diff.ComputeExcludedModelsHash([]string{"Model-A", "model-b"})
|
||||
if geminiAPIKeyAuth.Attributes["excluded_models_hash"] != expectedAPIKeyHash {
|
||||
t.Fatalf("expected API key excluded hash %s, got %s", expectedAPIKeyHash, geminiAPIKeyAuth.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if geminiAPIKeyAuth.Attributes["auth_kind"] != "apikey" {
|
||||
t.Fatalf("expected auth_kind=apikey, got %s", geminiAPIKeyAuth.Attributes["auth_kind"])
|
||||
}
|
||||
|
||||
if geminiPrimary == nil {
|
||||
t.Fatal("expected primary gemini-cli auth from file")
|
||||
}
|
||||
if !geminiPrimary.Disabled || geminiPrimary.Status != coreauth.StatusDisabled {
|
||||
t.Fatal("expected primary gemini-cli auth to be disabled when virtual auths are synthesized")
|
||||
}
|
||||
expectedOAuthHash := diff.ComputeExcludedModelsHash([]string{"Foo", "bar"})
|
||||
if geminiPrimary.Attributes["excluded_models_hash"] != expectedOAuthHash {
|
||||
t.Fatalf("expected OAuth excluded hash %s, got %s", expectedOAuthHash, geminiPrimary.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if geminiPrimary.Attributes["auth_kind"] != "oauth" {
|
||||
t.Fatalf("expected auth_kind=oauth, got %s", geminiPrimary.Attributes["auth_kind"])
|
||||
}
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtual auths, got %d", len(virtuals))
|
||||
}
|
||||
for _, v := range virtuals {
|
||||
if v.Attributes["gemini_virtual_parent"] != geminiPrimary.ID {
|
||||
t.Fatalf("virtual auth missing parent link to %s", geminiPrimary.ID)
|
||||
}
|
||||
if v.Attributes["excluded_models_hash"] != expectedOAuthHash {
|
||||
t.Fatalf("expected virtual excluded hash %s, got %s", expectedOAuthHash, v.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Fatalf("expected virtual auth to be active, got %s", v.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigIfChanged_TriggersOnChangeAndSkipsUnchanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
writeConfig := func(port int, allowRemote bool) {
|
||||
cfg := &config.Config{
|
||||
Port: port,
|
||||
AuthDir: authDir,
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: allowRemote,
|
||||
},
|
||||
}
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal config: %v", err)
|
||||
}
|
||||
if err = os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
writeConfig(8080, false)
|
||||
|
||||
reloads := 0
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: func(*config.Config) { reloads++ },
|
||||
}
|
||||
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 1 {
|
||||
t.Fatalf("expected first reload to trigger callback once, got %d", reloads)
|
||||
}
|
||||
|
||||
// Same content should be skipped by hash check.
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 1 {
|
||||
t.Fatalf("expected unchanged config to be skipped, callback count %d", reloads)
|
||||
}
|
||||
|
||||
writeConfig(9090, true)
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 2 {
|
||||
t.Fatalf("expected changed config to trigger reload, callback count %d", reloads)
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if w.config == nil || w.config.Port != 9090 || !w.config.RemoteManagement.AllowRemote {
|
||||
t.Fatalf("expected config to be updated after reload, got %+v", w.config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAndStopSuccess(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir), 0o644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w, err := NewWatcher(configPath, authDir, func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create watcher: %v", err)
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := w.Start(ctx); err != nil {
|
||||
t.Fatalf("expected Start to succeed: %v", err)
|
||||
}
|
||||
cancel()
|
||||
if err := w.Stop(); err != nil {
|
||||
t.Fatalf("expected Stop to succeed: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected one reload callback, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartFailsWhenConfigMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "missing-config.yaml")
|
||||
|
||||
w, err := NewWatcher(configPath, authDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create watcher: %v", err)
|
||||
}
|
||||
defer w.Stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := w.Start(ctx); err == nil {
|
||||
t.Fatal("expected Start to fail for missing config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchRuntimeAuthUpdateEnqueuesAndUpdatesState(t *testing.T) {
|
||||
queue := make(chan AuthUpdate, 4)
|
||||
w := &Watcher{}
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
defer w.stopDispatch()
|
||||
|
||||
auth := &coreauth.Auth{ID: "auth-1", Provider: "test"}
|
||||
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: auth}); !ok {
|
||||
t.Fatal("expected DispatchRuntimeAuthUpdate to enqueue")
|
||||
}
|
||||
|
||||
select {
|
||||
case update := <-queue:
|
||||
if update.Action != AuthUpdateActionAdd || update.Auth.ID != "auth-1" {
|
||||
t.Fatalf("unexpected update: %+v", update)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for auth update")
|
||||
}
|
||||
|
||||
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, ID: "auth-1"}); !ok {
|
||||
t.Fatal("expected delete update to enqueue")
|
||||
}
|
||||
select {
|
||||
case update := <-queue:
|
||||
if update.Action != AuthUpdateActionDelete || update.ID != "auth-1" {
|
||||
t.Fatalf("unexpected delete update: %+v", update)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for delete update")
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
if _, exists := w.runtimeAuths["auth-1"]; exists {
|
||||
w.clientsMutex.RUnlock()
|
||||
t.Fatal("expected runtime auth to be cleared after delete")
|
||||
}
|
||||
w.clientsMutex.RUnlock()
|
||||
}
|
||||
|
||||
func TestAddOrUpdateClientSkipsUnchanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
data, _ := os.ReadFile(authFile)
|
||||
sum := sha256.Sum256(data)
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: map[string]string{
|
||||
filepath.Clean(authFile): hexString(sum[:]),
|
||||
},
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload for unchanged file, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo","api_key":"k"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
}
|
||||
normalized := filepath.Clean(authFile)
|
||||
if _, ok := w.lastAuthHashes[normalized]; !ok {
|
||||
t.Fatalf("expected hash to be stored for %s", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveClientRemovesHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
var reloads int32
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: map[string]string{
|
||||
filepath.Clean(authFile): "hash",
|
||||
},
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.removeClient(authFile)
|
||||
if _, ok := w.lastAuthHashes[filepath.Clean(authFile)]; ok {
|
||||
t.Fatal("expected hash to be removed after deletion")
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldDebounceRemove(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
path := filepath.Clean("test.json")
|
||||
|
||||
if w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("first call should not debounce")
|
||||
}
|
||||
if !w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("second call within window should debounce")
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
w.lastRemoveTimes = map[string]time.Time{path: time.Now().Add(-2 * authRemoveDebounceWindow)}
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
if w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("call after window should not debounce")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFileUnchangedUsesHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
content := []byte(`{"type":"demo"}`)
|
||||
if err := os.WriteFile(authFile, content, 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
w := &Watcher{lastAuthHashes: make(map[string]string)}
|
||||
unchanged, err := w.authFileUnchanged(authFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if unchanged {
|
||||
t.Fatal("expected first check to report changed")
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(content)
|
||||
w.lastAuthHashes[filepath.Clean(authFile)] = hexString(sum[:])
|
||||
|
||||
unchanged, err = w.authFileUnchanged(authFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !unchanged {
|
||||
t.Fatal("expected hash match to report unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadClientsCachesAuthHashes(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "one.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: &config.Config{AuthDir: tmpDir},
|
||||
}
|
||||
|
||||
w.reloadClients(true, nil, false)
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if len(w.lastAuthHashes) != 1 {
|
||||
t.Fatalf("expected hash cache for one auth file, got %d", len(w.lastAuthHashes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadClientsLogsConfigDiffs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
oldCfg := &config.Config{AuthDir: tmpDir, Port: 1, Debug: false}
|
||||
newCfg := &config.Config{AuthDir: tmpDir, Port: 2, Debug: true}
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: oldCfg,
|
||||
}
|
||||
w.SetConfig(oldCfg)
|
||||
w.oldConfigYaml, _ = yaml.Marshal(oldCfg)
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
w.config = newCfg
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.reloadClients(false, nil, false)
|
||||
}
|
||||
|
||||
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
queue := make(chan AuthUpdate, 1)
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
if w.dispatchCond == nil || w.dispatchCancel == nil {
|
||||
t.Fatal("expected dispatch to be initialized")
|
||||
}
|
||||
w.SetAuthUpdateQueue(nil)
|
||||
if w.dispatchCancel != nil {
|
||||
t.Fatal("expected dispatch cancel to be cleared when queue nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
w.stopConfigReloadTimer()
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = time.AfterFunc(10*time.Millisecond, func() {})
|
||||
w.configReloadMu.Unlock()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
w.stopConfigReloadTimer()
|
||||
}
|
||||
|
||||
func TestHandleEventRemovesAuthFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "remove.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
if err := os.Remove(authFile); err != nil {
|
||||
t.Fatalf("failed to remove auth file pre-check: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: &config.Config{AuthDir: tmpDir},
|
||||
lastAuthHashes: map[string]string{
|
||||
filepath.Clean(authFile): "hash",
|
||||
},
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[filepath.Clean(authFile)]; ok {
|
||||
t.Fatal("expected hash entry to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) {
|
||||
queue := make(chan AuthUpdate, 4)
|
||||
w := &Watcher{}
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
defer w.stopDispatch()
|
||||
|
||||
w.dispatchAuthUpdates([]AuthUpdate{
|
||||
{Action: AuthUpdateActionAdd, ID: "a"},
|
||||
{Action: AuthUpdateActionModify, ID: "b"},
|
||||
})
|
||||
|
||||
got := make([]AuthUpdate, 0, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case u := <-queue:
|
||||
got = append(got, u)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timed out waiting for update %d", i)
|
||||
}
|
||||
}
|
||||
if len(got) != 2 || got[0].ID != "a" || got[1].ID != "b" {
|
||||
t.Fatalf("unexpected updates order/content: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func hexString(data []byte) string {
|
||||
return strings.ToLower(fmt.Sprintf("%x", data))
|
||||
}
|
||||
@@ -7,7 +7,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
@@ -219,52 +218,24 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
// v6.1: Intelligent Buffered Streamer strategy
|
||||
// Enhanced buffering with larger buffer size (16KB) and longer flush interval (120ms).
|
||||
// Smart flush only when buffer is sufficiently filled (≥50%), dramatically reducing
|
||||
// flush frequency from ~12.5Hz to ~5-8Hz while maintaining low latency.
|
||||
writer := bufio.NewWriterSize(c.Writer, 16*1024) // 4KB → 16KB
|
||||
ticker := time.NewTicker(120 * time.Millisecond) // 80ms → 120ms
|
||||
defer ticker.Stop()
|
||||
|
||||
var chunkIdx int
|
||||
|
||||
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
|
||||
// This guarantees clients see incremental output even for small responses.
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
// Context cancelled, flush any remaining data before exit
|
||||
_ = writer.Flush()
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// Smart flush: only flush when buffer has sufficient data (≥50% full)
|
||||
// This reduces flush frequency while ensuring data flows naturally
|
||||
buffered := writer.Buffered()
|
||||
if buffered >= 8*1024 { // At least 8KB (50% of 16KB buffer)
|
||||
if err := writer.Flush(); err != nil {
|
||||
// Error flushing, cancel and return
|
||||
cancel(err)
|
||||
return
|
||||
}
|
||||
flusher.Flush() // Also flush the underlying http.ResponseWriter
|
||||
}
|
||||
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
// Stream ended, flush remaining data
|
||||
_ = writer.Flush()
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward the complete SSE event block directly (already formatted by the translator).
|
||||
// The translator returns a complete SSE-compliant event block, including event:, data:, and separators.
|
||||
// The handler just needs to forward it without reassembly.
|
||||
if len(chunk) > 0 {
|
||||
_, _ = writer.Write(chunk)
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
flusher.Flush()
|
||||
}
|
||||
chunkIdx++
|
||||
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
@@ -276,21 +247,20 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
// An error occurred: emit as a proper SSE error event
|
||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||
_, _ = writer.WriteString("event: error\n")
|
||||
_, _ = writer.WriteString("data: ")
|
||||
_, _ = writer.Write(errorBytes)
|
||||
_, _ = writer.WriteString("\n\n")
|
||||
_ = writer.Flush()
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,8 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
action := strings.TrimPrefix(request.Action, "/")
|
||||
switch action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
@@ -189,7 +190,7 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
action := strings.Split(strings.TrimPrefix(request.Action, "/"), ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
@@ -5,10 +5,10 @@ package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -49,27 +49,6 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
openAICompatProviders []string
|
||||
openAICompatMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// GetOpenAICompatProviders safely returns a copy of the provider names
|
||||
func (h *BaseAPIHandler) GetOpenAICompatProviders() []string {
|
||||
h.openAICompatMutex.RLock()
|
||||
defer h.openAICompatMutex.RUnlock()
|
||||
result := make([]string, len(h.openAICompatProviders))
|
||||
copy(result, h.openAICompatProviders)
|
||||
return result
|
||||
}
|
||||
|
||||
// SetOpenAICompatProviders safely sets the provider names
|
||||
func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||
h.openAICompatMutex.Lock()
|
||||
defer h.openAICompatMutex.Unlock()
|
||||
h.openAICompatProviders = make([]string, len(providers))
|
||||
copy(h.openAICompatProviders, providers)
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -81,12 +60,11 @@ func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
h := &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
h.SetOpenAICompatProviders(openAICompatProviders)
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -137,6 +115,16 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
if h.Cfg.RequestLog && len(params) == 1 {
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(bytes.TrimSpace(existingBytes)) > 0 {
|
||||
switch params[0].(type) {
|
||||
case error, string:
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
switch data := params[0].(type) {
|
||||
case []byte:
|
||||
@@ -351,30 +339,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
targetModelName := resolvedModelName
|
||||
if isDynamic {
|
||||
targetModelName = extractedModelName
|
||||
}
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(targetModelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -392,30 +369,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.GetOpenAICompatProviders() {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -457,12 +410,53 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
errText := http.StatusText(status)
|
||||
if msg != nil && msg.Error != nil {
|
||||
_, _ = c.Writer.Write([]byte(msg.Error.Error()))
|
||||
} else {
|
||||
_, _ = c.Writer.Write([]byte(http.StatusText(status)))
|
||||
if v := strings.TrimSpace(msg.Error.Error()); v != "" {
|
||||
errText = v
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer preserving upstream JSON error bodies when possible.
|
||||
buildJSONBody := func() []byte {
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
errType := "invalid_request_error"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
body := buildJSONBody()
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
|
||||
if !c.Writer.Written() {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Status(status)
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) LoggingAPIResponseError(ctx context.Context, err *interfaces.ErrorMessage) {
|
||||
|
||||
@@ -107,7 +107,7 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("iflow authentication failed: missing account identifier")
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("iflow-%s.json", email)
|
||||
fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix())
|
||||
metadata := map[string]any{
|
||||
"email": email,
|
||||
"api_key": tokenStorage.APIKey,
|
||||
|
||||
@@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 30 * time.Minute
|
||||
d := 5 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
@@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
"source": "aws-builder-id",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
@@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -375,10 +376,19 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -387,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.Execute(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -411,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -423,10 +436,19 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -435,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -459,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -471,10 +496,19 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -483,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
@@ -508,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
|
||||
if auth == nil || model == "" {
|
||||
return model, metadata
|
||||
}
|
||||
prefix := strings.TrimSpace(auth.Prefix)
|
||||
if prefix == "" {
|
||||
return model, metadata
|
||||
}
|
||||
needle := prefix + "/"
|
||||
if !strings.HasPrefix(model, needle) {
|
||||
return model, metadata
|
||||
}
|
||||
rewritten := strings.TrimPrefix(model, needle)
|
||||
return rewritten, stripPrefixFromMetadata(metadata, needle)
|
||||
}
|
||||
|
||||
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
|
||||
if len(metadata) == 0 || needle == "" {
|
||||
return metadata
|
||||
}
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
value, okStr := raw.(string)
|
||||
if !okStr || !strings.HasPrefix(value, needle) {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = make(map[string]any, len(metadata))
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[key] = strings.TrimPrefix(value, needle)
|
||||
}
|
||||
if out == nil {
|
||||
return metadata
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
@@ -1471,7 +1555,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
@@ -19,6 +19,8 @@ type Auth struct {
|
||||
Index uint64 `json:"-"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
// FileName stores the relative or absolute path of the backing auth file.
|
||||
FileName string `json:"-"`
|
||||
// Storage holds the token persistence implementation used during login flows.
|
||||
@@ -157,6 +159,20 @@ func (m *ModelState) Clone() *ModelState {
|
||||
return ©State
|
||||
}
|
||||
|
||||
func (a *Auth) ProxyInfo() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
proxyStr := strings.TrimSpace(a.ProxyURL)
|
||||
if proxyStr == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.Index(proxyStr, "://"); idx > 0 {
|
||||
return "via " + proxyStr[:idx] + " proxy"
|
||||
}
|
||||
return "via proxy"
|
||||
}
|
||||
|
||||
func (a *Auth) AccountInfo() (string, string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
@@ -796,7 +796,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -816,7 +816,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -996,6 +996,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
return filtered
|
||||
}
|
||||
|
||||
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
|
||||
trimmedPrefix := strings.TrimSpace(prefix)
|
||||
if trimmedPrefix == "" || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
out := make([]*ModelInfo, 0, len(models)*2)
|
||||
seen := make(map[string]struct{}, len(models)*2)
|
||||
|
||||
addModel := func(model *ModelInfo) {
|
||||
if model == nil {
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
baseID := strings.TrimSpace(model.ID)
|
||||
if baseID == "" {
|
||||
continue
|
||||
}
|
||||
if !forceModelPrefix || trimmedPrefix == baseID {
|
||||
addModel(model)
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = trimmedPrefix + "/" + baseID
|
||||
addModel(&clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
|
||||
func matchWildcard(pattern, value string) bool {
|
||||
if pattern == "" {
|
||||
|
||||
@@ -9,6 +9,11 @@ type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
|
||||
109
test/antigravity_claude_signature_test.go
Normal file
109
test/antigravity_claude_signature_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
agclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestAntigravityClaudeRequest_DropsUnsignedThinkingBlocks(t *testing.T) {
|
||||
model := "gemini-claude-sonnet-4-5-thinking"
|
||||
input := []byte(`{
|
||||
"model":"` + model + `",
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"thinking","thinking":"secret without signature"}]},
|
||||
{"role":"user","content":[{"type":"text","text":"hi"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := agclaude.ConvertClaudeRequestToAntigravity(model, input, false)
|
||||
contents := gjson.GetBytes(out, "request.contents")
|
||||
if !contents.Exists() || !contents.IsArray() {
|
||||
t.Fatalf("expected request.contents array, got: %s", string(out))
|
||||
}
|
||||
if got := len(contents.Array()); got != 1 {
|
||||
t.Fatalf("expected 1 content message after dropping unsigned thinking-only assistant message, got %d: %s", got, contents.Raw)
|
||||
}
|
||||
if role := contents.Array()[0].Get("role").String(); role != "user" {
|
||||
t.Fatalf("expected remaining message role=user, got %q", role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityClaudeStreamResponse_EmitsSignatureDeltaForStandaloneSignaturePart(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"response":{
|
||||
"responseId":"resp_1",
|
||||
"modelVersion":"claude-sonnet-4-5-thinking",
|
||||
"candidates":[{
|
||||
"content":{"parts":[
|
||||
{"text":"THOUGHT","thought":true},
|
||||
{"thought":true,"thoughtSignature":"sig123"},
|
||||
{"text":"ANSWER","thought":false}
|
||||
]},
|
||||
"finishReason":"STOP"
|
||||
}],
|
||||
"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"thoughtsTokenCount":1,"totalTokenCount":3}
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
chunks := agclaude.ConvertAntigravityResponseToClaude(context.Background(), "", nil, nil, raw, ¶m)
|
||||
joined := strings.Join(chunks, "")
|
||||
if !strings.Contains(joined, `"type":"signature_delta"`) {
|
||||
t.Fatalf("expected signature_delta in stream output, got: %s", joined)
|
||||
}
|
||||
if !strings.Contains(joined, `"signature":"sig123"`) {
|
||||
t.Fatalf("expected signature sig123 in stream output, got: %s", joined)
|
||||
}
|
||||
// Signature delta must be attached to the thinking content block (index 0 in this minimal stream).
|
||||
if !strings.Contains(joined, `{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"sig123"}}`) {
|
||||
t.Fatalf("expected signature_delta to target thinking block index 0, got: %s", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityClaudeNonStreamResponse_IncludesThinkingSignature(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"response":{
|
||||
"responseId":"resp_1",
|
||||
"modelVersion":"claude-sonnet-4-5-thinking",
|
||||
"candidates":[{
|
||||
"content":{"parts":[
|
||||
{"text":"THOUGHT","thought":true},
|
||||
{"thought":true,"thoughtSignature":"sig123"},
|
||||
{"text":"ANSWER","thought":false}
|
||||
]},
|
||||
"finishReason":"STOP"
|
||||
}],
|
||||
"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"thoughtsTokenCount":1,"totalTokenCount":3}
|
||||
}
|
||||
}`)
|
||||
|
||||
out := agclaude.ConvertAntigravityResponseToClaudeNonStream(context.Background(), "", nil, nil, raw, nil)
|
||||
if !gjson.Valid(out) {
|
||||
t.Fatalf("expected valid JSON output, got: %s", out)
|
||||
}
|
||||
content := gjson.Get(out, "content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
t.Fatalf("expected content array in output, got: %s", out)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() != "thinking" {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
if got := block.Get("signature").String(); got != "sig123" {
|
||||
t.Fatalf("expected thinking.signature=sig123, got %q (block=%s)", got, block.Raw)
|
||||
}
|
||||
if got := block.Get("thinking").String(); got != "THOUGHT" {
|
||||
t.Fatalf("expected thinking.thinking=THOUGHT, got %q (block=%s)", got, block.Raw)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected a thinking block in output, got: %s", out)
|
||||
}
|
||||
}
|
||||
798
test/thinking_conversion_test.go
Normal file
798
test/thinking_conversion_test.go
Normal file
@@ -0,0 +1,798 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// isOpenAICompatModel returns true if the model is configured as an OpenAI-compatible
|
||||
// model that should have reasoning effort passed through even if not in registry.
|
||||
// This simulates the allowCompat behavior from OpenAICompatExecutor.
|
||||
func isOpenAICompatModel(model string) bool {
|
||||
return model == "openai-compat"
|
||||
}
|
||||
|
||||
// registerCoreModels loads representative models across providers into the registry
|
||||
// so NormalizeThinkingBudget and level validation use real ranges.
|
||||
func registerCoreModels(t *testing.T) func() {
|
||||
t.Helper()
|
||||
reg := registry.GetGlobalRegistry()
|
||||
uid := fmt.Sprintf("thinking-core-%d", time.Now().UnixNano())
|
||||
reg.RegisterClient(uid+"-gemini", "gemini", registry.GetGeminiModels())
|
||||
reg.RegisterClient(uid+"-claude", "claude", registry.GetClaudeModels())
|
||||
reg.RegisterClient(uid+"-openai", "codex", registry.GetOpenAIModels())
|
||||
reg.RegisterClient(uid+"-qwen", "qwen", registry.GetQwenModels())
|
||||
// Custom openai-compatible model with forced thinking suffix passthrough.
|
||||
// No Thinking field - simulates an external model added via openai-compat
|
||||
// where the registry has no knowledge of its thinking capabilities.
|
||||
// The allowCompat flag should preserve reasoning effort for such models.
|
||||
customOpenAIModels := []*registry.ModelInfo{
|
||||
{
|
||||
ID: "openai-compat",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "custom-provider",
|
||||
Type: "openai",
|
||||
DisplayName: "OpenAI Compatible Model",
|
||||
Description: "OpenAI-compatible model with forced thinking suffix support",
|
||||
},
|
||||
}
|
||||
reg.RegisterClient(uid+"-custom-openai", "codex", customOpenAIModels)
|
||||
return func() {
|
||||
reg.UnregisterClient(uid + "-gemini")
|
||||
reg.UnregisterClient(uid + "-claude")
|
||||
reg.UnregisterClient(uid + "-openai")
|
||||
reg.UnregisterClient(uid + "-qwen")
|
||||
reg.UnregisterClient(uid + "-custom-openai")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
thinkingTestModels = []string{
|
||||
"gpt-5", // level-based thinking model
|
||||
"gemini-2.5-pro", // numeric-budget thinking model
|
||||
"qwen3-code-plus", // no thinking support
|
||||
"openai-compat", // allowCompat=true (OpenAI-compatible channel)
|
||||
}
|
||||
thinkingTestFromProtocols = []string{"openai", "claude", "gemini", "openai-response"}
|
||||
thinkingTestToProtocols = []string{"gemini", "claude", "openai", "codex"}
|
||||
|
||||
// Numeric budgets and their level equivalents:
|
||||
// -1 -> auto
|
||||
// 0 -> none
|
||||
// 1..1024 -> low
|
||||
// 1025..8192 -> medium
|
||||
// 8193..24576 -> high
|
||||
// >24576 -> model highest level (right-most in Levels)
|
||||
thinkingNumericSamples = []int{-1, 0, 1023, 1025, 8193, 64000}
|
||||
|
||||
// Levels and their numeric equivalents:
|
||||
// auto -> -1
|
||||
// none -> 0
|
||||
// minimal -> 512
|
||||
// low -> 1024
|
||||
// medium -> 8192
|
||||
// high -> 24576
|
||||
// xhigh -> 32768
|
||||
// invalid -> invalid (no mapping)
|
||||
thinkingLevelSamples = []string{"auto", "none", "minimal", "low", "medium", "high", "xhigh", "invalid"}
|
||||
)
|
||||
|
||||
func buildRawPayload(fromProtocol, modelWithSuffix string) []byte {
|
||||
switch fromProtocol {
|
||||
case "gemini":
|
||||
return []byte(fmt.Sprintf(`{"model":"%s","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`, modelWithSuffix))
|
||||
case "openai-response":
|
||||
return []byte(fmt.Sprintf(`{"model":"%s","input":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`, modelWithSuffix))
|
||||
default: // openai / claude and other chat-style payloads
|
||||
return []byte(fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, modelWithSuffix))
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeCodexPayload mirrors codex_executor's reasoning + streaming tweaks.
|
||||
func normalizeCodexPayload(body []byte, upstreamModel string, allowCompat bool) ([]byte, error) {
|
||||
body = executor.NormalizeThinkingConfig(body, upstreamModel, allowCompat)
|
||||
if err := executor.ValidateThinkingConfig(body, upstreamModel); err != nil {
|
||||
return body, err
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// buildBodyForProtocol runs a minimal request through the same translation and
|
||||
// thinking pipeline used in executors for the given target protocol.
|
||||
func buildBodyForProtocol(t *testing.T, fromProtocol, toProtocol, modelWithSuffix string) ([]byte, error) {
|
||||
t.Helper()
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(modelWithSuffix)
|
||||
upstreamModel := util.ResolveOriginalModel(normalizedModel, metadata)
|
||||
raw := buildRawPayload(fromProtocol, modelWithSuffix)
|
||||
stream := fromProtocol != toProtocol
|
||||
|
||||
body := sdktranslator.TranslateRequest(
|
||||
sdktranslator.FromString(fromProtocol),
|
||||
sdktranslator.FromString(toProtocol),
|
||||
normalizedModel,
|
||||
raw,
|
||||
stream,
|
||||
)
|
||||
|
||||
var err error
|
||||
allowCompat := isOpenAICompatModel(normalizedModel)
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
body = executor.ApplyThinkingMetadata(body, metadata, normalizedModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(normalizedModel, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(normalizedModel, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(normalizedModel, body)
|
||||
case "claude":
|
||||
if budget, ok := util.ResolveClaudeThinkingConfig(normalizedModel, metadata); ok {
|
||||
body = util.ApplyClaudeThinkingConfig(body, budget)
|
||||
}
|
||||
case "openai":
|
||||
body = executor.ApplyReasoningEffortMetadata(body, metadata, normalizedModel, "reasoning_effort", allowCompat)
|
||||
body = executor.NormalizeThinkingConfig(body, upstreamModel, allowCompat)
|
||||
err = executor.ValidateThinkingConfig(body, upstreamModel)
|
||||
case "codex": // OpenAI responses / codex
|
||||
// Codex does not support allowCompat; always use false.
|
||||
body = executor.ApplyReasoningEffortMetadata(body, metadata, normalizedModel, "reasoning.effort", false)
|
||||
// Mirror CodexExecutor final normalization and model override so tests log the final body.
|
||||
body, err = normalizeCodexPayload(body, upstreamModel, false)
|
||||
default:
|
||||
}
|
||||
|
||||
// Mirror executor behavior: final payload uses the upstream (base) model name.
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
|
||||
// For tests we only keep model + thinking-related fields to avoid noise.
|
||||
body = filterThinkingBody(toProtocol, body, upstreamModel, normalizedModel)
|
||||
return body, err
|
||||
}
|
||||
|
||||
// filterThinkingBody projects the translated payload down to only model and
|
||||
// thinking-related fields for the given target protocol.
|
||||
func filterThinkingBody(toProtocol string, body []byte, upstreamModel, normalizedModel string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
out := []byte(`{}`)
|
||||
|
||||
// Preserve model if present, otherwise fall back to upstream/normalized model.
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "model", m.Value())
|
||||
} else if upstreamModel != "" {
|
||||
out, _ = sjson.SetBytes(out, "model", upstreamModel)
|
||||
} else if normalizedModel != "" {
|
||||
out, _ = sjson.SetBytes(out, "model", normalizedModel)
|
||||
}
|
||||
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
if tc := gjson.GetBytes(body, "generationConfig.thinkingConfig"); tc.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "generationConfig.thinkingConfig", []byte(tc.Raw))
|
||||
}
|
||||
case "claude":
|
||||
if tcfg := gjson.GetBytes(body, "thinking"); tcfg.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "thinking", []byte(tcfg.Raw))
|
||||
}
|
||||
case "openai":
|
||||
if re := gjson.GetBytes(body, "reasoning_effort"); re.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "reasoning_effort", re.Value())
|
||||
}
|
||||
case "codex":
|
||||
if re := gjson.GetBytes(body, "reasoning.effort"); re.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", re.Value())
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
modelSuffix string
|
||||
}
|
||||
|
||||
numericName := func(budget int) string {
|
||||
if budget < 0 {
|
||||
return "numeric-neg1"
|
||||
}
|
||||
return fmt.Sprintf("numeric-%d", budget)
|
||||
}
|
||||
|
||||
for _, model := range thinkingTestModels {
|
||||
_ = registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
|
||||
for _, from := range thinkingTestFromProtocols {
|
||||
// Scenario selection follows protocol semantics:
|
||||
// - OpenAI-style protocols (openai/openai-response) express thinking as levels.
|
||||
// - Claude/Gemini-style protocols express thinking as numeric budgets.
|
||||
cases := []scenario{
|
||||
{name: "no-suffix", modelSuffix: model},
|
||||
}
|
||||
if from == "openai" || from == "openai-response" {
|
||||
for _, lvl := range thinkingLevelSamples {
|
||||
cases = append(cases, scenario{
|
||||
name: "level-" + lvl,
|
||||
modelSuffix: fmt.Sprintf("%s(%s)", model, lvl),
|
||||
})
|
||||
}
|
||||
} else { // claude or gemini
|
||||
for _, budget := range thinkingNumericSamples {
|
||||
budget := budget
|
||||
cases = append(cases, scenario{
|
||||
name: numericName(budget),
|
||||
modelSuffix: fmt.Sprintf("%s(%d)", model, budget),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, to := range thinkingTestToProtocols {
|
||||
if from == to {
|
||||
continue
|
||||
}
|
||||
t.Logf("─────────────────────────────────────────────────────────────────────────────────")
|
||||
t.Logf(" %s -> %s | model: %s", from, to, model)
|
||||
t.Logf("─────────────────────────────────────────────────────────────────────────────────")
|
||||
for _, cs := range cases {
|
||||
from := from
|
||||
to := to
|
||||
cs := cs
|
||||
testName := fmt.Sprintf("%s->%s/%s/%s", from, to, model, cs.name)
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(cs.modelSuffix)
|
||||
expectPresent, expectValue, expectErr := func() (bool, string, bool) {
|
||||
switch to {
|
||||
case "gemini":
|
||||
budget, include, ok := util.ResolveThinkingConfigFromMetadata(normalizedModel, metadata)
|
||||
if !ok || !util.ModelSupportsThinking(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
if include != nil && !*include {
|
||||
return false, "", false
|
||||
}
|
||||
if budget == nil {
|
||||
return false, "", false
|
||||
}
|
||||
norm := util.NormalizeThinkingBudget(normalizedModel, *budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
case "claude":
|
||||
if !util.ModelSupportsThinking(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
budget, ok := util.ResolveClaudeThinkingConfig(normalizedModel, metadata)
|
||||
if !ok || budget == nil {
|
||||
return false, "", false
|
||||
}
|
||||
return true, fmt.Sprintf("%d", *budget), false
|
||||
case "openai":
|
||||
allowCompat := isOpenAICompatModel(normalizedModel)
|
||||
if !util.ModelSupportsThinking(normalizedModel) && !allowCompat {
|
||||
return false, "", false
|
||||
}
|
||||
// For allowCompat models, pass through effort directly without validation
|
||||
if allowCompat {
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if ok && strings.TrimSpace(effort) != "" {
|
||||
return true, strings.ToLower(strings.TrimSpace(effort)), false
|
||||
}
|
||||
// Check numeric budget fallback for allowCompat
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
if !util.ModelUsesThinkingLevels(normalizedModel) {
|
||||
// Non-levels models don't support effort strings in openai
|
||||
return false, "", false
|
||||
}
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
effort = mapped
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
return false, "", false
|
||||
}
|
||||
effort = strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, effort); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true // validation would fail
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; require thinking-capable level models.
|
||||
if !util.ModelSupportsThinking(normalizedModel) || !util.ModelUsesThinkingLevels(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if ok && strings.TrimSpace(effort) != "" {
|
||||
effort = strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, effort); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
mapped = strings.ToLower(strings.TrimSpace(mapped))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, mapped); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
}
|
||||
if from != "openai-response" {
|
||||
// Codex translators default reasoning.effort to "medium" when
|
||||
// no explicit thinking suffix/metadata is provided.
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
default:
|
||||
return false, "", false
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := buildBodyForProtocol(t, from, to, cs.modelSuffix)
|
||||
actualPresent, actualValue := func() (bool, string) {
|
||||
path := ""
|
||||
switch to {
|
||||
case "gemini":
|
||||
path = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
case "claude":
|
||||
path = "thinking.budget_tokens"
|
||||
case "openai":
|
||||
path = "reasoning_effort"
|
||||
case "codex":
|
||||
path = "reasoning.effort"
|
||||
}
|
||||
if path == "" {
|
||||
return false, ""
|
||||
}
|
||||
val := gjson.GetBytes(body, path)
|
||||
if to == "codex" && !val.Exists() {
|
||||
reasoning := gjson.GetBytes(body, "reasoning")
|
||||
if reasoning.Exists() {
|
||||
val = reasoning.Get("effort")
|
||||
}
|
||||
}
|
||||
if !val.Exists() {
|
||||
return false, ""
|
||||
}
|
||||
if val.Type == gjson.Number {
|
||||
return true, fmt.Sprintf("%d", val.Int())
|
||||
}
|
||||
return true, val.String()
|
||||
}()
|
||||
|
||||
t.Logf("from=%s to=%s model=%s suffix=%s present(expect=%v got=%v) value(expect=%s got=%s) err(expect=%v got=%v) body=%s",
|
||||
from, to, model, cs.modelSuffix, expectPresent, actualPresent, expectValue, actualValue, expectErr, err != nil, string(body))
|
||||
|
||||
if expectErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error but got none, body=%s", string(body))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v body=%s", err, string(body))
|
||||
}
|
||||
|
||||
if expectPresent != actualPresent {
|
||||
t.Fatalf("presence mismatch: expect %v got %v body=%s", expectPresent, actualPresent, string(body))
|
||||
}
|
||||
if expectPresent && expectValue != actualValue {
|
||||
t.Fatalf("value mismatch: expect %s got %s body=%s", expectValue, actualValue, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildRawPayloadWithThinking creates a payload with thinking parameters already in the body.
|
||||
// This tests the path where thinking comes from the raw payload, not model suffix.
|
||||
func buildRawPayloadWithThinking(fromProtocol, model string, thinkingParam any) []byte {
|
||||
switch fromProtocol {
|
||||
case "gemini":
|
||||
base := fmt.Sprintf(`{"model":"%s","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`, model)
|
||||
if budget, ok := thinkingParam.(int); ok {
|
||||
base, _ = sjson.Set(base, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
}
|
||||
return []byte(base)
|
||||
case "openai-response":
|
||||
base := fmt.Sprintf(`{"model":"%s","input":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`, model)
|
||||
if effort, ok := thinkingParam.(string); ok && effort != "" {
|
||||
base, _ = sjson.Set(base, "reasoning.effort", effort)
|
||||
}
|
||||
return []byte(base)
|
||||
case "openai":
|
||||
base := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
if effort, ok := thinkingParam.(string); ok && effort != "" {
|
||||
base, _ = sjson.Set(base, "reasoning_effort", effort)
|
||||
}
|
||||
return []byte(base)
|
||||
case "claude":
|
||||
base := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
if budget, ok := thinkingParam.(int); ok {
|
||||
base, _ = sjson.Set(base, "thinking.type", "enabled")
|
||||
base, _ = sjson.Set(base, "thinking.budget_tokens", budget)
|
||||
}
|
||||
return []byte(base)
|
||||
default:
|
||||
return []byte(fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model))
|
||||
}
|
||||
}
|
||||
|
||||
// buildBodyForProtocolWithRawThinking translates payload with raw thinking params.
|
||||
func buildBodyForProtocolWithRawThinking(t *testing.T, fromProtocol, toProtocol, model string, thinkingParam any) ([]byte, error) {
|
||||
t.Helper()
|
||||
raw := buildRawPayloadWithThinking(fromProtocol, model, thinkingParam)
|
||||
stream := fromProtocol != toProtocol
|
||||
|
||||
body := sdktranslator.TranslateRequest(
|
||||
sdktranslator.FromString(fromProtocol),
|
||||
sdktranslator.FromString(toProtocol),
|
||||
model,
|
||||
raw,
|
||||
stream,
|
||||
)
|
||||
|
||||
var err error
|
||||
allowCompat := isOpenAICompatModel(model)
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
case "claude":
|
||||
// For raw payload, Claude thinking is passed through by translator
|
||||
// No additional processing needed as thinking is already in body
|
||||
case "openai":
|
||||
body = executor.NormalizeThinkingConfig(body, model, allowCompat)
|
||||
err = executor.ValidateThinkingConfig(body, model)
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; always use false.
|
||||
body, err = normalizeCodexPayload(body, model, false)
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body = filterThinkingBody(toProtocol, body, model, model)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
thinkingParam any // int for budget, string for effort level
|
||||
}
|
||||
|
||||
numericName := func(budget int) string {
|
||||
if budget < 0 {
|
||||
return "budget-neg1"
|
||||
}
|
||||
return fmt.Sprintf("budget-%d", budget)
|
||||
}
|
||||
|
||||
for _, model := range thinkingTestModels {
|
||||
supportsThinking := util.ModelSupportsThinking(model)
|
||||
usesLevels := util.ModelUsesThinkingLevels(model)
|
||||
allowCompat := isOpenAICompatModel(model)
|
||||
|
||||
for _, from := range thinkingTestFromProtocols {
|
||||
var cases []scenario
|
||||
switch from {
|
||||
case "openai", "openai-response":
|
||||
cases = []scenario{
|
||||
{name: "no-thinking", thinkingParam: nil},
|
||||
}
|
||||
for _, lvl := range thinkingLevelSamples {
|
||||
cases = append(cases, scenario{
|
||||
name: "effort-" + lvl,
|
||||
thinkingParam: lvl,
|
||||
})
|
||||
}
|
||||
case "gemini", "claude":
|
||||
cases = []scenario{
|
||||
{name: "no-thinking", thinkingParam: nil},
|
||||
}
|
||||
for _, budget := range thinkingNumericSamples {
|
||||
budget := budget
|
||||
cases = append(cases, scenario{
|
||||
name: numericName(budget),
|
||||
thinkingParam: budget,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, to := range thinkingTestToProtocols {
|
||||
if from == to {
|
||||
continue
|
||||
}
|
||||
t.Logf("═══════════════════════════════════════════════════════════════════════════════")
|
||||
t.Logf(" RAW PAYLOAD: %s -> %s | model: %s", from, to, model)
|
||||
t.Logf("═══════════════════════════════════════════════════════════════════════════════")
|
||||
|
||||
for _, cs := range cases {
|
||||
from := from
|
||||
to := to
|
||||
cs := cs
|
||||
testName := fmt.Sprintf("raw/%s->%s/%s/%s", from, to, model, cs.name)
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
expectPresent, expectValue, expectErr := func() (bool, string, bool) {
|
||||
if cs.thinkingParam == nil {
|
||||
if to == "codex" && from != "openai-response" && supportsThinking && usesLevels {
|
||||
// Codex translators default reasoning.effort to "medium" for thinking-capable level models
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
|
||||
switch to {
|
||||
case "gemini":
|
||||
if !supportsThinking || usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
// Gemini expects numeric budget (only for non-level models)
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
norm := util.NormalizeThinkingBudget(model, budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
}
|
||||
// Convert effort level to budget for non-level models only
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
// "none" disables thinking - no thinkingBudget in output
|
||||
if strings.ToLower(effort) == "none" {
|
||||
return false, "", false
|
||||
}
|
||||
if budget, okB := util.ThinkingEffortToBudget(model, effort); okB {
|
||||
// ThinkingEffortToBudget already returns normalized budget
|
||||
return true, fmt.Sprintf("%d", budget), false
|
||||
}
|
||||
// Invalid effort does not map to a budget
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
case "claude":
|
||||
if !supportsThinking || usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
// Claude expects numeric budget (only for non-level models)
|
||||
if budget, ok := cs.thinkingParam.(int); ok && budget > 0 {
|
||||
norm := util.NormalizeThinkingBudget(model, budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
}
|
||||
// Convert effort level to budget for non-level models only
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
// "none" and "auto" don't produce budget_tokens
|
||||
lower := strings.ToLower(effort)
|
||||
if lower == "none" || lower == "auto" {
|
||||
return false, "", false
|
||||
}
|
||||
if budget, okB := util.ThinkingEffortToBudget(model, effort); okB {
|
||||
// ThinkingEffortToBudget already returns normalized budget
|
||||
return true, fmt.Sprintf("%d", budget), false
|
||||
}
|
||||
// Invalid effort - claude sets thinking.type:enabled but no budget_tokens
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
case "openai":
|
||||
if allowCompat {
|
||||
if effort, ok := cs.thinkingParam.(string); ok && strings.TrimSpace(effort) != "" {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
return true, normalized, false
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
if !supportsThinking || !usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
if normalized, okN := util.NormalizeReasoningEffortLevel(model, effort); okN {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true // invalid level
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
}
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; require thinking-capable level models.
|
||||
if !supportsThinking || !usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
if normalized, okN := util.NormalizeReasoningEffortLevel(model, effort); okN {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
}
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
if from != "openai-response" {
|
||||
// Codex translators default reasoning.effort to "medium" for thinking-capable models
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
}()
|
||||
|
||||
body, err := buildBodyForProtocolWithRawThinking(t, from, to, model, cs.thinkingParam)
|
||||
actualPresent, actualValue := func() (bool, string) {
|
||||
path := ""
|
||||
switch to {
|
||||
case "gemini":
|
||||
path = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
case "claude":
|
||||
path = "thinking.budget_tokens"
|
||||
case "openai":
|
||||
path = "reasoning_effort"
|
||||
case "codex":
|
||||
path = "reasoning.effort"
|
||||
}
|
||||
if path == "" {
|
||||
return false, ""
|
||||
}
|
||||
val := gjson.GetBytes(body, path)
|
||||
if to == "codex" && !val.Exists() {
|
||||
reasoning := gjson.GetBytes(body, "reasoning")
|
||||
if reasoning.Exists() {
|
||||
val = reasoning.Get("effort")
|
||||
}
|
||||
}
|
||||
if !val.Exists() {
|
||||
return false, ""
|
||||
}
|
||||
if val.Type == gjson.Number {
|
||||
return true, fmt.Sprintf("%d", val.Int())
|
||||
}
|
||||
return true, val.String()
|
||||
}()
|
||||
|
||||
t.Logf("from=%s to=%s model=%s param=%v present(expect=%v got=%v) value(expect=%s got=%s) err(expect=%v got=%v) body=%s",
|
||||
from, to, model, cs.thinkingParam, expectPresent, actualPresent, expectValue, actualValue, expectErr, err != nil, string(body))
|
||||
|
||||
if expectErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error but got none, body=%s", string(body))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v body=%s", err, string(body))
|
||||
}
|
||||
|
||||
if expectPresent != actualPresent {
|
||||
t.Fatalf("presence mismatch: expect %v got %v body=%s", expectPresent, actualPresent, string(body))
|
||||
}
|
||||
if expectPresent && expectValue != actualValue {
|
||||
t.Fatalf("value mismatch: expect %s got %s body=%s", expectValue, actualValue, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingBudgetToEffort(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
budget int
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{name: "dynamic-auto", model: "gpt-5", budget: -1, want: "auto", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "minimal", ok: true},
|
||||
{name: "low-min", model: "gpt-5", budget: 1, want: "low", ok: true},
|
||||
{name: "low-max", model: "gpt-5", budget: 1024, want: "low", ok: true},
|
||||
{name: "medium-min", model: "gpt-5", budget: 1025, want: "medium", ok: true},
|
||||
{name: "medium-max", model: "gpt-5", budget: 8192, want: "medium", ok: true},
|
||||
{name: "high-min", model: "gpt-5", budget: 8193, want: "high", ok: true},
|
||||
{name: "high-max", model: "gpt-5", budget: 24576, want: "high", ok: true},
|
||||
{name: "over-max-clamps-to-highest", model: "gpt-5", budget: 64000, want: "high", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 64000, want: "xhigh", ok: true},
|
||||
{name: "negative-unsupported", model: "gpt-5", budget: -5, want: "", ok: false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s budget=%d: expect %v got %v", cs.model, cs.budget, cs.ok, ok)
|
||||
}
|
||||
if got != cs.want {
|
||||
t.Fatalf("value mismatch for model=%s budget=%d: expect %q got %q", cs.model, cs.budget, cs.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingEffortToBudget(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
effort string
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{name: "none", model: "gemini-2.5-pro", effort: "none", want: 0, ok: true},
|
||||
{name: "auto", model: "gemini-2.5-pro", effort: "auto", want: -1, ok: true},
|
||||
{name: "minimal", model: "gemini-2.5-pro", effort: "minimal", want: 512, ok: true},
|
||||
{name: "low", model: "gemini-2.5-pro", effort: "low", want: 1024, ok: true},
|
||||
{name: "medium", model: "gemini-2.5-pro", effort: "medium", want: 8192, ok: true},
|
||||
{name: "high", model: "gemini-2.5-pro", effort: "high", want: 24576, ok: true},
|
||||
{name: "xhigh", model: "gemini-2.5-pro", effort: "xhigh", want: 32768, ok: true},
|
||||
{name: "empty-unsupported", model: "gemini-2.5-pro", effort: "", want: 0, ok: false},
|
||||
{name: "invalid-unsupported", model: "gemini-2.5-pro", effort: "ultra", want: 0, ok: false},
|
||||
{name: "case-insensitive", model: "gemini-2.5-pro", effort: "LOW", want: 1024, ok: true},
|
||||
{name: "case-insensitive-medium", model: "gemini-2.5-pro", effort: "MEDIUM", want: 8192, ok: true},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingEffortToBudget(cs.model, cs.effort)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s effort=%s: expect %v got %v", cs.model, cs.effort, cs.ok, ok)
|
||||
}
|
||||
if got != cs.want {
|
||||
t.Fatalf("value mismatch for model=%s effort=%s: expect %d got %d", cs.model, cs.effort, cs.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user